In this post, I describe how a handwritten digit classifier can be trained in Kotlin/Native using Torch. Like other Torch clients, most prominently PyTorch, this example is built on top of the ATen C API, showing what a Torch client for Kotlin/Native could look like. Since my last post already explained how the TensorFlow backend can be used from Kotlin/Native, I will mainly focus on how Torch differs.
Torch differs from TensorFlow in that its backend does not split compute graph definition and execution into separate concepts, as TensorFlow does with graphs and sessions. Defining graphs separately from running them costs some flexibility, since we have to define a new graph whenever we want to perform different operations. (TensorFlow Eager is being developed to tackle this.) Instead, Torch directly provides the following set of functions for each network operation (module) $f$ with parameters $\theta$:
- updateOutput returns the output $f(x)$ of the module given the input $x$.
- updateGradInput returns the gradient $\frac{\partial L}{\partial x}$ of some loss $L$ with respect to the input, given $x$ and the gradient $\frac{\partial L}{\partial f(x)}$ with respect to the output.
- If the module has parameters, accGradParameters computes the gradient $\frac{\partial L}{\partial \theta}$ of $L$ with respect to the parameters, given $x$, $\frac{\partial L}{\partial f(x)}$ and $\theta$, accumulating it into gradient buffers.
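To make the three functions concrete, here is a minimal pure-Kotlin sketch for a hypothetical elementwise module $f(x) = w \cdot x$ with a single scalar parameter $w$. It is not the THNN C API (which works on native tensors and takes extra state and buffer arguments); plain FloatArrays stand in for tensors.

```kotlin
// Hypothetical sketch: an elementwise "Scale" module f(x) = w * x with one scalar
// parameter w, exposing the three THNN-style functions described above.
class Scale(var weight: Float) {
    // Parameter gradient buffer; accGradParameters adds to it (hence "acc").
    var gradWeight = 0f

    // updateOutput: f(x)_i = w * x_i
    fun updateOutput(input: FloatArray): FloatArray =
        FloatArray(input.size) { i -> weight * input[i] }

    // updateGradInput: dL/dx_i = dL/df_i * w
    fun updateGradInput(input: FloatArray, gradOutput: FloatArray): FloatArray =
        FloatArray(input.size) { i -> gradOutput[i] * weight }

    // accGradParameters: dL/dw = sum_i dL/df_i * x_i, accumulated into gradWeight
    fun accGradParameters(input: FloatArray, gradOutput: FloatArray) {
        for (i in input.indices) gradWeight += gradOutput[i] * input[i]
    }
}
```

For example, with $w = 2$ and $x = (1, 3)$, updateOutput yields $(2, 6)$, and an upstream gradient of $(1, 1)$ gives an input gradient of $(2, 2)$ and a parameter gradient of $4$.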
To know in which direction we need to nudge all network parameters to decrease some loss $L$, we need the gradient of $L$ with respect to all network parameters $\theta$, which we can obtain using the backpropagation algorithm. It consists of two passes:
- In the forward pass, we calculate all layer outputs starting from the network's inputs, sequentially applying updateOutput to the respective previous output.
- In the backward pass, we calculate the gradients of $L$ with respect to all layer inputs, starting from the loss $L$ and applying updateGradInput to each layer in reverse order. We then get the desired parameter gradients by calling accGradParameters for all modules that have parameters.
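The two passes can be sketched in plain Kotlin as follows. The Layer interface and the two example layers are hypothetical stand-ins for the native modules. The forward pass stores each layer's input, because the backward functions need it. The backward pass then walks the layers in reverse order.

```kotlin
// Hypothetical minimal interface mirroring the three THNN-style functions.
interface Layer {
    fun updateOutput(input: FloatArray): FloatArray
    fun updateGradInput(input: FloatArray, gradOutput: FloatArray): FloatArray
    // Parameter-free layers keep this default no-op.
    fun accGradParameters(input: FloatArray, gradOutput: FloatArray) {}
}

// f(x) = w * x elementwise, with one scalar parameter w.
class ScaleLayer(var weight: Float) : Layer {
    var gradWeight = 0f
    override fun updateOutput(input: FloatArray) = FloatArray(input.size) { weight * input[it] }
    override fun updateGradInput(input: FloatArray, gradOutput: FloatArray) =
        FloatArray(input.size) { gradOutput[it] * weight }
    override fun accGradParameters(input: FloatArray, gradOutput: FloatArray) {
        for (i in input.indices) gradWeight += gradOutput[i] * input[i]
    }
}

// f(x) = x^2 elementwise, parameter-free.
class Square : Layer {
    override fun updateOutput(input: FloatArray) = FloatArray(input.size) { input[it] * input[it] }
    override fun updateGradInput(input: FloatArray, gradOutput: FloatArray) =
        FloatArray(input.size) { 2f * input[it] * gradOutput[it] }
}

// Runs both passes; gradLoss maps the network output to dL/d(output).
// Returns the network output and dL/d(network input).
fun backprop(
    layers: List<Layer>,
    input: FloatArray,
    gradLoss: (FloatArray) -> FloatArray
): Pair<FloatArray, FloatArray> {
    // Forward pass: remember every layer's input.
    val inputs = ArrayList<FloatArray>()
    var x = input
    for (layer in layers) {
        inputs.add(x)
        x = layer.updateOutput(x)
    }
    // Backward pass: start from dL/d(output) and go through the layers in reverse.
    var grad = gradLoss(x)
    for (i in layers.indices.reversed()) {
        layers[i].accGradParameters(inputs[i], grad)
        grad = layers[i].updateGradInput(inputs[i], grad)
    }
    return x to grad
}
```

For $L = (w x)^2$ with $w = 3$ and $x = 2$, this yields $\partial L / \partial w = 2 w x^2 = 24$ and $\partial L / \partial x = 2 w^2 x = 36$, matching the analytic chain rule.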
In reality, these functions can have additional parameters, e.g. for the library state, buffers and configuration. For example, the function signatures for the linear module are:
Note that outputs are also passed in as parameters, since allocating memory for the result is the responsibility of the caller.
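The caller-allocates convention can be mimicked in plain Kotlin. The hypothetical linearUpdateOutput below is not a THNN function. It only writes into an output array that the caller has allocated, so the same buffer can be reused across iterations instead of allocating a new result each time.

```kotlin
// Hypothetical Kotlin analogue of the C convention: the caller allocates `output`,
// and the function only fills it with output = weight * input + bias.
// `weight` is stored row-major with shape [output.size, input.size].
fun linearUpdateOutput(
    input: FloatArray,
    output: FloatArray,
    weight: FloatArray,
    bias: FloatArray
) {
    val inFeatures = input.size
    require(weight.size == output.size * inFeatures && bias.size == output.size) {
        "shape mismatch"
    }
    for (o in output.indices) {
        var sum = bias[o]
        for (i in 0 until inFeatures) sum += weight[o * inFeatures + i] * input[i]
        output[o] = sum
    }
}
```

Usage: allocate `val out = FloatArray(2)` once, then call `linearUpdateOutput(x, out, w, b)` for every new input `x`.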
The Torch API is split into multiple headers, of which we will use two: TH contains basic tensor types and operations, and THNN contains the functions for network modules described above. I used the Torch version provided in the PyTorch repository, called ATen, as it comes with a streamlined process to build all necessary libraries at once, with a built-in switch to support both CPU and GPU.
The rest of the build process is nearly identical to the one for the TensorFlow demo from the last post.
For convenience, we wrap the raw THFloatTensor into a class, which is straightforward:
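The following is a pure-Kotlin stand-in that shows the shape of such a wrapper. It is a hypothetical sketch, not the author's class. A FloatArray plays the role of the native THFloatTensor storage, whereas a real binding would hold a pointer obtained through cinterop and release it through the TH library. The dispose method models the explicit release that native memory requires.

```kotlin
// Hypothetical stand-in for a wrapper around a raw native tensor handle.
// A FloatArray simulates the native storage; dispose() simulates freeing it.
class FloatTensor(vararg dims: Int) {
    val shape: IntArray = dims.copyOf()
    private var storage: FloatArray? = FloatArray(dims.fold(1) { acc, d -> acc * d })

    val isDisposed: Boolean get() = storage == null
    val elementCount: Int get() = data().size

    // Fails loudly on use-after-dispose instead of reading freed memory.
    private fun data(): FloatArray = checkNotNull(storage) { "tensor has already been disposed" }

    // Flat (row-major) element access.
    operator fun get(flatIndex: Int): Float = data()[flatIndex]
    operator fun set(flatIndex: Int, value: Float) { data()[flatIndex] = value }

    // Explicit release, mirroring manual memory management of native tensors.
    fun dispose() { storage = null }
}
```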
To enable compile-time dimensionality checks and better tooling (auto-completion, …), I defined the tensor class as abstract and created vector and matrix subtypes with dimensionality-specific functions:
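A minimal sketch of this idea in plain Kotlin follows. FloatVector, FloatMatrix and floatVectorOf are hypothetical names, not the original classes. Because only FloatMatrix defines `times(FloatVector)`, the compiler accepts `matrix * vector` but rejects `vector * matrix`. Exact sizes are still checked at runtime.

```kotlin
// Hypothetical typed tensor hierarchy: the abstract base holds flat storage,
// and the subtypes add dimensionality-specific operations.
abstract class Tensor(val data: FloatArray) {
    val elements: List<Float> get() = data.toList()
}

class FloatVector(val size: Int, data: FloatArray = FloatArray(size)) : Tensor(data) {
    init { require(data.size == size) { "expected $size elements" } }

    operator fun get(i: Int): Float = data[i]

    operator fun plus(other: FloatVector): FloatVector {
        require(size == other.size) { "size mismatch" }
        return FloatVector(size, FloatArray(size) { data[it] + other.data[it] })
    }
}

class FloatMatrix(val rows: Int, val cols: Int, data: FloatArray = FloatArray(rows * cols)) :
    Tensor(data) {
    init { require(data.size == rows * cols) { "expected ${rows * cols} elements" } }

    operator fun get(r: Int, c: Int): Float = data[r * cols + c]

    // Matrix-vector product: a (rows x cols) matrix times a cols-vector gives a rows-vector.
    operator fun times(v: FloatVector): FloatVector {
        require(v.size == cols) { "expected a vector of size $cols" }
        return FloatVector(rows, FloatArray(rows) { r ->
            var sum = 0f
            for (c in 0 until cols) sum += this[r, c] * v[c]
            sum
        })
    }
}

// Convenience factory in the spirit of listOf/floatArrayOf.
fun floatVectorOf(vararg values: Float) = FloatVector(values.size, values.copyOf())
```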
Here, creating vectors (and similarly matrices) is implemented through
To allow a general implementation of backpropagation as described above, I defined a generic base class Module<Input, Output, Gradient>, from which all types of network modules inherit, for example the linear layer:
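The following sketch shows one possible reading of such a base class in plain Kotlin. It assumes that Gradient is the type of the module's parameter gradient. The method names and the LinearGradient class are hypothetical, and FloatArray and Array<FloatArray> stand in for native vectors and matrices.

```kotlin
// Hypothetical generic base class: Input/Output are the module's tensor types,
// Gradient is the type of its parameter gradient.
abstract class Module<Input, Output, Gradient> {
    abstract fun updateOutput(input: Input): Output
    abstract fun updateGradInput(input: Input, gradOutput: Output): Input
    abstract fun parameterGradient(input: Input, gradOutput: Output): Gradient
}

class LinearGradient(val weight: Array<FloatArray>, val bias: FloatArray)

// Linear layer y = W x + b, with W stored as one FloatArray per output row.
class Linear(val weight: Array<FloatArray>, val bias: FloatArray) :
    Module<FloatArray, FloatArray, LinearGradient>() {

    // y_o = b_o + sum_i W_oi * x_i
    override fun updateOutput(input: FloatArray): FloatArray = FloatArray(bias.size) { o ->
        var sum = bias[o]
        for (i in input.indices) sum += weight[o][i] * input[i]
        sum
    }

    // dL/dx_i = sum_o W_oi * dL/dy_o
    override fun updateGradInput(input: FloatArray, gradOutput: FloatArray): FloatArray =
        FloatArray(input.size) { i ->
            var sum = 0f
            for (o in gradOutput.indices) sum += weight[o][i] * gradOutput[o]
            sum
        }

    // dL/dW_oi = dL/dy_o * x_i and dL/db_o = dL/dy_o
    override fun parameterGradient(input: FloatArray, gradOutput: FloatArray) = LinearGradient(
        weight = Array(gradOutput.size) { o -> FloatArray(input.size) { i -> gradOutput[o] * input[i] } },
        bias = gradOutput.copyOf()
    )
}
```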
Modules without parameters, such as CrossEntropyLoss, are implemented similarly.
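For illustration, here are two parameter-free building blocks written as plain Kotlin objects: a ReLU activation and a softmax cross-entropy loss. The names are hypothetical, not the author's classes. The loss gradient uses the standard identity that the gradient of $-\log \mathrm{softmax}(z)_y$ with respect to the logits $z$ is $\mathrm{softmax}(z) - e_y$, where $e_y$ is the one-hot vector of the label $y$.

```kotlin
import kotlin.math.exp
import kotlin.math.ln

// ReLU: y = max(0, x); the gradient only passes through where x > 0.
object Relu {
    fun updateOutput(input: FloatArray) = FloatArray(input.size) { maxOf(0f, input[it]) }
    fun updateGradInput(input: FloatArray, gradOutput: FloatArray) =
        FloatArray(input.size) { if (input[it] > 0f) gradOutput[it] else 0f }
}

// Softmax followed by cross-entropy against an integer class label.
object SoftmaxCrossEntropy {
    fun softmax(logits: FloatArray): FloatArray {
        val max = logits.maxOrNull() ?: 0f // subtract the max for numerical stability
        val exps = FloatArray(logits.size) { exp(logits[it] - max) }
        val total = exps.sum()
        return FloatArray(exps.size) { exps[it] / total }
    }

    fun loss(logits: FloatArray, label: Int): Float = -ln(softmax(logits)[label])

    // dL/dz_i = softmax(z)_i - [i == label]
    fun gradient(logits: FloatArray, label: Int): FloatArray {
        val p = softmax(logits)
        p[label] -= 1f
        return p
    }
}
```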
After adding some convenience functions, we can define a two-layer network through:
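One way such convenience functions could work is a typed infix composition, sketched below with hypothetical names (Differentiable, then, Dense). Composing a module from A to B with one from B to C yields a module from A to C, so mismatched layer types fail to compile. For brevity, this sketch only propagates input gradients, leaves out biases and parameter gradients, and recomputes intermediate outputs in the backward pass instead of caching them.

```kotlin
// Hypothetical differentiable building block with typed input and output.
interface Differentiable<In, Out> {
    fun forward(input: In): Out
    // Given the input and dL/d(output), returns dL/d(input).
    fun backward(input: In, gradOutput: Out): In
}

// (first then next)(x) = next(first(x)); the types must line up at compile time.
infix fun <A, B, C> Differentiable<A, B>.then(next: Differentiable<B, C>): Differentiable<A, C> {
    val first = this
    return object : Differentiable<A, C> {
        override fun forward(input: A): C = next.forward(first.forward(input))
        override fun backward(input: A, gradOutput: C): A {
            val hidden = first.forward(input) // recomputed for simplicity
            return first.backward(input, next.backward(hidden, gradOutput))
        }
    }
}

// Dense layer without bias: y_o = sum_i W_oi * x_i.
class Dense(private val weight: Array<FloatArray>) : Differentiable<FloatArray, FloatArray> {
    override fun forward(input: FloatArray) = FloatArray(weight.size) { o ->
        var sum = 0f
        for (i in input.indices) sum += weight[o][i] * input[i]
        sum
    }
    // dL/dx_i = sum_o W_oi * dL/dy_o
    override fun backward(input: FloatArray, gradOutput: FloatArray) = FloatArray(input.size) { i ->
        var sum = 0f
        for (o in weight.indices) sum += weight[o][i] * gradOutput[o]
        sum
    }
}

object Relu : Differentiable<FloatArray, FloatArray> {
    override fun forward(input: FloatArray) = FloatArray(input.size) { maxOf(0f, input[it]) }
    override fun backward(input: FloatArray, gradOutput: FloatArray) =
        FloatArray(input.size) { if (input[it] > 0f) gradOutput[it] else 0f }
}
```

A two-layer network then reads `val network = Dense(w1) then Relu then Dense(w2)`.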
The function for training an MNIST classifier then looks like this (slightly simplified):
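As a self-contained stand-in for the training function, the following sketch trains a single linear layer with softmax cross-entropy using plain stochastic gradient descent on toy data. It is a hypothetical sketch, not the original code. Data loading, the two-layer network and mini-batching are omitted.

```kotlin
import kotlin.math.exp

// Hypothetical softmax regression model: logits z = W x + b.
class SoftmaxRegression(private val inputs: Int, private val classes: Int) {
    val weight = Array(classes) { FloatArray(inputs) }
    val bias = FloatArray(classes)

    fun logits(x: FloatArray): FloatArray = FloatArray(classes) { c ->
        var s = bias[c]
        for (i in 0 until inputs) s += weight[c][i] * x[i]
        s
    }

    // Predicted class = index of the largest logit.
    fun predict(x: FloatArray): Int {
        val z = logits(x)
        return z.indices.maxByOrNull { z[it] } ?: 0
    }

    // One SGD step on a single example, using dL/dz = softmax(z) - onehot(label).
    fun trainStep(x: FloatArray, label: Int, learningRate: Float) {
        val z = logits(x)
        val max = z.maxOrNull() ?: 0f
        val e = FloatArray(classes) { exp(z[it] - max) }
        val total = e.sum()
        for (c in 0 until classes) {
            val grad = e[c] / total - (if (c == label) 1f else 0f)
            for (i in 0 until inputs) weight[c][i] -= learningRate * grad * x[i]
            bias[c] -= learningRate * grad
        }
    }
}

fun train(model: SoftmaxRegression, data: List<Pair<FloatArray, Int>>, epochs: Int, learningRate: Float) {
    repeat(epochs) {
        for ((x, label) in data) model.trainStep(x, label, learningRate)
    }
}

fun accuracy(model: SoftmaxRegression, data: List<Pair<FloatArray, Int>>): Double =
    data.count { (x, label) -> model.predict(x) == label }.toDouble() / data.size
```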
Kotlin/Native automatically manages memory based on automated reference counting with cycle collection. When using external C libraries, though, the corresponding memory has to be allocated and released manually. To make this easier, I implemented the following mechanism (powered by Kotlin's type-safe builders): all expressions registered with use are disposed when the enclosing scope, marked by a disposeScoped function call, ends. In our case, this makes sure that all forward and backward pass results are released after every iteration, keeping memory usage low.
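A minimal sketch of such a mechanism in plain Kotlin is shown below. The Disposable interface and DisposeScope class are hypothetical, and in the real code dispose would free native tensor memory. The key ingredient is a lambda with receiver: inside `disposeScoped { ... }`, the member extension `use()` is available and registers objects for disposal when the block exits, even if it exits via an exception.

```kotlin
// Hypothetical sketch of the scoping mechanism.
interface Disposable {
    fun dispose()
}

class DisposeScope {
    private val resources = mutableListOf<Disposable>()

    // Member extension: only callable inside a DisposeScope (the builder receiver).
    // Registers the receiver for disposal and returns it, so it can be used inline.
    fun <T : Disposable> T.use(): T {
        resources.add(this)
        return this
    }

    // Disposes everything in reverse registration order.
    fun disposeAll() {
        for (resource in resources.asReversed()) resource.dispose()
        resources.clear()
    }
}

// Runs `block` with a fresh scope as receiver and always cleans up afterwards.
fun <R> disposeScoped(block: DisposeScope.() -> R): R {
    val scope = DisposeScope()
    try {
        return scope.block()
    } finally {
        scope.disposeAll()
    }
}
```

In a training loop, each iteration's body would be wrapped in `disposeScoped { ... }`, with every intermediate tensor created as `someOperation(...).use()`.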
Finally, we run the code to train a classifier for handwritten digits from the MNIST dataset:
The dataset is downloaded by the build script and parsed through the
Even on a CPU, training should take only a few minutes, and you should observe a classification accuracy of about 95% on the test dataset.
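The MNIST files use the simple IDX format. A file starts with a 4-byte magic number: two zero bytes, a type code (0x08 for unsigned bytes) and the number of dimensions. This is followed by one big-endian 32-bit size per dimension and then the raw data. Below is a hypothetical pure-Kotlin parser sketch for the unsigned-byte case, with names that are not from the original code.

```kotlin
// Hypothetical IDX parser sketch (unsigned-byte data only).
class IdxData(val dimensions: IntArray, val values: ByteArray)

private fun ByteArray.readIntBigEndian(offset: Int): Int =
    ((this[offset].toInt() and 0xFF) shl 24) or
        ((this[offset + 1].toInt() and 0xFF) shl 16) or
        ((this[offset + 2].toInt() and 0xFF) shl 8) or
        (this[offset + 3].toInt() and 0xFF)

fun parseIdx(bytes: ByteArray): IdxData {
    require(bytes.size >= 4 && bytes[0] == 0.toByte() && bytes[1] == 0.toByte()) { "not an IDX file" }
    require(bytes[2] == 0x08.toByte()) { "only unsigned-byte data is supported here" }
    val rank = bytes[3].toInt() and 0xFF
    val headerSize = 4 + 4 * rank
    require(bytes.size >= headerSize) { "truncated IDX header" }
    val dims = IntArray(rank) { bytes.readIntBigEndian(4 + 4 * it) }
    val count = dims.fold(1) { acc, d -> acc * d }
    require(bytes.size >= headerSize + count) { "truncated IDX data" }
    return IdxData(dims, bytes.copyOfRange(headerSize, headerSize + count))
}

// Converts image i of an [n, rows, cols] IDX tensor to floats in [0, 1].
fun IdxData.imageAsFloats(i: Int): FloatArray {
    val size = dimensions[1] * dimensions[2]
    return FloatArray(size) { (values[i * size + it].toInt() and 0xFF) / 255f }
}
```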
Is this viable?
While this project was aimed at exploring the internals of TensorFlow and Torch as well as development in Kotlin/Native, you might wonder whether it would be useful to extend this demo into a fully-fledged framework for machine learning. Here are some points to consider:
- There is a massive existing community behind TensorFlow's Python client and PyTorch, which drives their further development. Potential alternatives would lack this, which reduces their viability. Additionally, Python has a mature ecosystem of libraries for data processing and machine learning, which further amplifies this effect.
- On the positive side, native binaries can be deployed on a wider range of platforms and hardware, enabling embedded (robotics, …), web (WebAssembly, …) and mobile applications. But because training large models is hardware-intensive anyway (and often performed on server GPUs), the main benefit here would be deploying pre-trained networks. This use case is already covered by, e.g., TensorFlow Lite.
- Parallel processing in Python itself is cumbersome and typically requires multiple processes due to the global interpreter lock. This makes large-scale data processing, a task particularly important for machine learning applications, harder to optimize within the language itself. Performance is further degraded because Python is dynamically typed and typically interpreted rather than compiled into native binaries. This makes native bindings necessary for highly optimized code (such as that used by Torch and TensorFlow). Kotlin/Native would offer a modern language for developing native code without these restrictions, although other modern languages (such as Rust) have similar benefits.
- Even for the development of this small demo, static typing was helpful to me, eliminating a large class of errors at compile time and enabling strong tooling (auto-completion, refactoring, …). It is not obvious to me, though, whether these benefits would carry over to a larger implementation. Also, keeping track of fine-grained types (e.g. vectors vs. matrices, float vs. double) might arguably require more thought when designing generic APIs.
- Finally, some of Python's language features were designed with numeric processing in mind, and Kotlin might benefit from some syntactic sugar to keep up (e.g. for defining arrays, float vs. double literals, array slicing, recently: introducing the
Having said all that: To me, building this demo felt like C development on a rocket booster, and I hope that Kotlin/Native will find its use case in machine learning.
The full code is available in the Kotlin/Native repository along with instructions for how to run it. If you have questions or feedback, please comment below.