In this post, I describe how a handwritten digit classifier can be trained in Kotlin/Native using Torch. Like other Torch clients, most prominently PyTorch, this example is built on top of the ATen C API, showing what a Torch client for Kotlin/Native could look like. Since my last post already explained how the TensorFlow backend can be used from Kotlin/Native, I will mainly focus on how Torch differs.

Torch differs from TensorFlow in that its backend does not split compute graph definition and execution into separate concepts, as TensorFlow does with graphs and sessions. Defining graphs separately from running them creates some inflexibility, since we have to define a new graph whenever we want to perform different operations. (TensorFlow Eager is being developed to tackle this.) Instead, Torch directly provides the following set of functions for each network module that maps an input $x$ to an output $y$ using parameters $\theta$:

  • updateOutput returns the output $y$ of the module given the input $x$
  • updateGradInput returns the gradient $\partial L / \partial x$ of some loss $L$ with respect to the input, given $x$ and the gradient $\partial L / \partial y$ with respect to the output
  • [If the module has parameters] accGradParameters accumulates the gradient $\partial L / \partial \theta$ of $L$ with respect to the parameters, given $x$, $\theta$ and $\partial L / \partial y$.
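
As a worked example, take a linear module with parameters $\theta = (W, b)$, where the weight matrix $W$ has shape outputSize × inputSize, mapping an input vector $x$ to an output $y$, and some loss $L$. By the chain rule, the three functions compute the following for a single sample (for a batch, the parameter gradients of the individual samples are summed):

```latex
\begin{aligned}
y &= W x + b && \text{(updateOutput)} \\
\frac{\partial L}{\partial x} &= W^\top \frac{\partial L}{\partial y} && \text{(updateGradInput)} \\
\frac{\partial L}{\partial W} &= \frac{\partial L}{\partial y}\, x^\top, \qquad
\frac{\partial L}{\partial b} = \frac{\partial L}{\partial y} && \text{(accGradParameters)}
\end{aligned}
```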

To know in which direction we need to nudge all network parameters to decrease some loss $L$, we are interested in the gradient of $L$ with respect to all network parameters $\theta$, which we can obtain using the backpropagation algorithm. It consists of two passes:

  • In the forward pass, we calculate all layer outputs starting from the network's inputs, sequentially applying updateOutput to the respective previous output.
  • In the backward pass, we start from the loss $L$ and move backwards through the layers, sequentially applying updateGradInput to get the gradient of $L$ with respect to each layer's input. We get the desired parameter gradients by calling accGradParameters for all modules that have parameters.
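
The two passes can be sketched in plain Kotlin without Torch. In this minimal sketch, the Module interface only mirrors the three THNN functions by name; it, the Scale and Relu modules and the backpropagate function are hypothetical, and accGradParameters here returns a fresh gradient instead of accumulating into a caller-provided buffer:

```kotlin
// Hypothetical stand-in for the THNN module functions, operating on plain DoubleArrays.
interface Module {
    fun updateOutput(input: DoubleArray): DoubleArray
    fun updateGradInput(input: DoubleArray, gradOutput: DoubleArray): DoubleArray
    // Gradient of the loss w.r.t. this module's parameters; empty for parameter-free modules.
    fun accGradParameters(input: DoubleArray, gradOutput: DoubleArray): DoubleArray = DoubleArray(0)
}

// Element-wise y = a * x with a single scalar parameter a.
class Scale(var a: Double) : Module {
    override fun updateOutput(input: DoubleArray) = DoubleArray(input.size) { a * input[it] }
    override fun updateGradInput(input: DoubleArray, gradOutput: DoubleArray) =
        DoubleArray(input.size) { a * gradOutput[it] }
    // dL/da = sum_i x_i * dL/dy_i
    override fun accGradParameters(input: DoubleArray, gradOutput: DoubleArray) =
        doubleArrayOf(input.indices.sumOf { input[it] * gradOutput[it] })
}

// Element-wise y = max(0, x); the gradient only passes through where x > 0.
object Relu : Module {
    override fun updateOutput(input: DoubleArray) = DoubleArray(input.size) { maxOf(0.0, input[it]) }
    override fun updateGradInput(input: DoubleArray, gradOutput: DoubleArray) =
        DoubleArray(input.size) { if (input[it] > 0.0) gradOutput[it] else 0.0 }
}

// Returns the parameter gradients of every module (in order), given dL/d(network output).
fun backpropagate(
    modules: List<Module>,
    input: DoubleArray,
    lossGradient: (output: DoubleArray) -> DoubleArray
): List<DoubleArray> {
    // Forward pass: remember each module's input, since the backward pass needs it.
    val inputs = mutableListOf<DoubleArray>()
    var x = input
    for (module in modules) {
        inputs += x
        x = module.updateOutput(x)
    }
    // Backward pass: walk the modules in reverse, turning dL/dy into dL/dx.
    var grad = lossGradient(x)
    val parameterGradients = MutableList(modules.size) { DoubleArray(0) }
    for (i in modules.indices.reversed()) {
        parameterGradients[i] = modules[i].accGradParameters(inputs[i], grad)
        grad = modules[i].updateGradInput(inputs[i], grad)
    }
    return parameterGradients
}
```

For example, with the modules Scale(2.0), Relu and Scale(3.0), the input (1, -1) and the loss L = sum of all outputs (so dL/d(output) is all ones), the two Scale parameters receive the gradients 3 and 2.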

In reality, these functions take additional parameters, e. g. for the library state, buffers and configuration. For example, the function signatures for the linear module are:

void THNN_Linear_updateOutput(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          THTensor *weight,
          THTensor *bias,
          THTensor *addBuffer);

void THNN_Linear_updateGradInput(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput,
          THTensor *weight);

void THNN_Linear_accGradParameters(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput,
          THTensor *weight,
          THTensor *bias,
          THTensor *gradWeight,
          THTensor *gradBias,
          THTensor *addBuffer,
          accreal scale_);

Note that outputs are also passed in as parameters, since allocating memory for the result is the responsibility of the caller.
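
To illustrate this caller-allocates convention, here is a hypothetical pure-Kotlin analogue of THNN_Linear_updateOutput for a batch of row vectors, leaving out the state and addBuffer arguments. The function name and array layout are assumptions for illustration; the point is that it only writes into an output array the caller has already allocated, so the caller decides when memory is created and can reuse the same buffer across iterations:

```kotlin
// Hypothetical analogue of a caller-allocates module function (not part of Torch).
// Computes output = input * weight^T + bias for a batch of row vectors and writes the
// result into `output`, which the caller must have allocated as [batchSize][outputSize].
fun linearUpdateOutput(
    input: Array<DoubleArray>,    // [batchSize][inputSize]
    output: Array<DoubleArray>,   // [batchSize][outputSize], allocated by the caller
    weight: Array<DoubleArray>,   // [outputSize][inputSize]
    bias: DoubleArray             // [outputSize]
) {
    require(output.size == input.size && output.all { it.size == weight.size }) {
        "output must be allocated as [batchSize][outputSize]"
    }
    for (n in input.indices) {
        for (o in weight.indices) {
            var sum = bias[o]
            for (i in weight[o].indices) sum += weight[o][i] * input[n][i]
            output[n][o] = sum
        }
    }
}
```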

The Torch API is split into multiple headers, of which we will use two (specified in torch.def):

headers = TH/TH.h THNN/THNN.h

TH contains basic tensor types and operations and THNN contains the functions for network modules described above. I used the Torch version provided in the PyTorch repository, called ATen, as it comes with a streamlined process to build all necessary libraries at once, with a built-in switch to support both CPU and GPU.

The rest of the build process is nearly identical to the one for the TensorFlow demo from the last post.

Writing code

For convenience, we wrap the raw THFloatTensor into a class, which is straightforward:

abstract class FloatTensor(val raw: CPointer<THFloatTensor>) : Disposable {
    private val storage: CPointer<THFloatStorage> get() = raw.pointed.storage!!
    private val elements get() = storage.pointed
    private val data: CPointer<FloatVar> get() = elements.data!!
    private val size: CPointer<LongVar> get() = raw.pointed.size!!
    protected val nDimension: Int get() = raw.pointed.nDimension

    val shape: List<Int> get() = (0 until nDimension).map { size[it].toInt() }

    operator fun plus(other: FloatTensor) = initializedTensor(shape) {
        THFloatTensor_cadd(it.raw, raw, 1f, other.raw)
    }

    operator fun minus(other: FloatTensor) = initializedTensor(shape) {
        THFloatTensor_cadd(it.raw, raw, -1f, other.raw)
    }

    open operator fun times(factor: Float) = initializedTensor(shape) {
        THFloatTensor_mul(it.raw, raw, factor)
    }

    fun sum() = THFloatTensor_sumall(raw)
    fun flatten() = (0 until elements.size).map { data[it] }.toTypedArray()

    override fun dispose() {
        THFloatTensor_free(raw)
    }
    
    ...
}

To enable compile-time dimensionality checks and better tooling (auto-completion, …), I defined the tensor class as abstract and created vector and matrix subtypes with dimensionality-specific functions:

class FloatVector(raw: CPointer<THFloatTensor>) : FloatTensor(raw) {
    operator fun times(other: FloatVector) = THFloatTensor_dot(raw, other.raw)
    
    ...
}


class FloatMatrix(raw: CPointer<THFloatTensor>) : FloatTensor(raw) {
    operator fun times(vector: FloatVector) = initializedTensor(shape[0]) {
        THFloatTensor_addmv(it.raw, 0f, it.raw, 1f, raw, vector.raw)
    }

    operator fun times(matrix: FloatMatrix) = initializedTensor(shape[0], matrix.shape[1]) {
        THFloatTensor_addmm(it.raw, 0f, it.raw, 1f, raw, matrix.raw)
    }
    
    ...
}
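
The benefit of separate vector and matrix types can be seen without Torch. In this hypothetical sketch, DenseVector and DenseMatrix are made-up stand-ins for FloatVector and FloatMatrix: the return type of each times operator depends on the operand types, and an operation with no defined operator (such as vector * matrix) is rejected by the compiler rather than failing at run time:

```kotlin
// Hypothetical pure-Kotlin analogue of the FloatVector/FloatMatrix split (no Torch involved).
class DenseVector(val data: DoubleArray) {
    // Dot product: vector * vector yields a scalar.
    operator fun times(other: DenseVector): Double {
        require(data.size == other.data.size) { "size mismatch" }
        return data.indices.sumOf { data[it] * other.data[it] }
    }
}

class DenseMatrix(val rows: Array<DoubleArray>) {
    // Matrix-vector product: yields a vector with one entry per row.
    operator fun times(vector: DenseVector): DenseVector =
        DenseVector(DoubleArray(rows.size) { r -> DenseVector(rows[r]) * vector })

    // Matrix-matrix product: (n x k) * (k x m) yields an (n x m) matrix.
    operator fun times(other: DenseMatrix): DenseMatrix {
        val columns = other.rows.firstOrNull()?.size ?: 0
        return DenseMatrix(Array(rows.size) { r ->
            DoubleArray(columns) { c -> rows[r].indices.sumOf { k -> rows[r][k] * other.rows[k][c] } }
        })
    }
}
// No operator is defined for DenseVector * DenseMatrix, so writing it does not compile.
```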

Creating vectors is implemented as follows (matrices work similarly):

fun uninitializedTensor(size: Int) =
    FloatVector(THFloatTensor_newWithSize1d(size.toLong())!!)

fun <T> initializedTensor(size: Int, initializer: (FloatVector) -> T) =
    uninitializedTensor(size).apply { initializer(this) }
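
The allocate-then-initialize pattern can be tried without Torch. In this hypothetical sketch, Buffer stands in for FloatVector. Because initializer returns an arbitrary type T that apply discards, any in-place routine can be passed regardless of its return value, which is convenient when the initializer is a single C call:

```kotlin
// Hypothetical stand-in for FloatVector, backed by a Kotlin FloatArray.
class Buffer(size: Int) {
    // Kotlin arrays start out zero-filled; the name of the Torch-backed original
    // suggests it makes no such guarantee.
    val data = FloatArray(size)
}

fun uninitializedBuffer(size: Int) = Buffer(size)

// Allocates a buffer, lets `initializer` fill it in place, and returns the buffer.
// The initializer's own result (of any type T) is discarded by apply.
fun <T> initializedBuffer(size: Int, initializer: (Buffer) -> T): Buffer =
    uninitializedBuffer(size).apply { initializer(this) }
```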

To allow a general implementation of backpropagation as described above, I defined a generic base class Module<Input, Output, Gradient>, from which all types of network modules inherit, for example the linear layer:

data class Linear(
        var weight: FloatMatrix,
        var bias: FloatVector) : Module<FloatMatrix, FloatMatrix, Pair<FloatMatrix, FloatVector>>() {
    val inputSize = weight.shape[1]
    val outputSize = weight.shape[0]
    val addBuffer = uninitializedTensor(outputSize)

    override operator fun invoke(input: FloatMatrix) = initializedTensor(input.shape[0], outputSize) {
        THNN_FloatLinear_updateOutput(null, input.raw, it.raw, weight.raw, bias.raw, addBuffer.raw)
    }

    override fun inputGradient(input: FloatMatrix, outputGradient: FloatMatrix, output: FloatMatrix) =
            initializedTensor(input.shape[0], inputSize) {
                THNN_FloatLinear_updateGradInput(null, input.raw, outputGradient.raw, it.raw, weight.raw)
            }

    override fun parameterGradient(
            input: FloatMatrix,
            outputGradient: FloatMatrix,
            inputGradient: FloatMatrix
    ): Pair<FloatMatrix, FloatVector> {
        val biasGradient = zeros(outputSize)
        val weightGradient = zeros(weight.shape[0], weight.shape[1]).also {
            THNN_FloatLinear_accGradParameters(
                null, input.raw, outputGradient.raw, inputGradient.raw, weight.raw,
                bias.raw, it.raw, biasGradient.raw, addBuffer.raw, 1.0)
        }

        return weightGradient to biasGradient
    }
    
    ...
}

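The post does not show the Module base class itself. A hypothetical minimal reconstruction, consistent with the overrides in Linear (invoke, inputGradient, parameterGradient) and with the ParameterFreeModule base used for modules without parameters, might look like this; the scalar Affine and Square modules are made-up examples so that the sketch runs without Torch:

```kotlin
// Hypothetical reconstruction; the post does not show its Module base class.
abstract class Module<Input, Output, Gradient> {
    abstract operator fun invoke(input: Input): Output
    abstract fun inputGradient(input: Input, outputGradient: Output, output: Output): Input
    abstract fun parameterGradient(input: Input, outputGradient: Output, inputGradient: Input): Gradient
}

// Parameter-free modules fix the gradient type to Unit and get parameterGradient for free.
abstract class ParameterFreeModule<Input, Output> : Module<Input, Output, Unit>() {
    override fun parameterGradient(input: Input, outputGradient: Output, inputGradient: Input) = Unit
}

// Scalar y = w * x + b; its gradient type is the pair (dL/dw, dL/db).
class Affine(var w: Double, var b: Double) : Module<Double, Double, Pair<Double, Double>>() {
    override operator fun invoke(input: Double) = w * input + b
    override fun inputGradient(input: Double, outputGradient: Double, output: Double) = w * outputGradient
    override fun parameterGradient(input: Double, outputGradient: Double, inputGradient: Double) =
        (input * outputGradient) to outputGradient
}

// Scalar y = x^2 without parameters.
object Square : ParameterFreeModule<Double, Double>() {
    override operator fun invoke(input: Double) = input * input
    override fun inputGradient(input: Double, outputGradient: Double, output: Double) = 2 * input * outputGradient
}
```

Encoding the gradient type as a type parameter lets the compiler check that, for example, a linear layer's gradient is always a (weight, bias) pair, while parameter-free modules cannot accidentally produce one.
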
Relu, Softmax and CrossEntropyLoss are implemented similarly, but don’t have parameters, e. g.

object Relu : ParameterFreeModule<FloatMatrix, FloatMatrix>() {
    override operator fun invoke(input: FloatMatrix) = initializedTensor(input.shape[0], input.shape[1]) {
        THNN_FloatLeakyReLU_updateOutput(null, input.raw, it.raw, 0.0, false)
    }

    override fun inputGradient(input: FloatMatrix, outputGradient: FloatMatrix, output: FloatMatrix) =
        initializedTensor(input.shape[0], input.shape[1]) {
            THNN_FloatLeakyReLU_updateGradInput(null, input.raw, outputGradient.raw, it.raw, 0.0, false)
        }
}
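
Softmax and CrossEntropyLoss are not shown here. As a hedged, pure-Kotlin sketch of the math such modules compute for a single sample (not the THNN implementation, and ignoring batching), the code below implements softmax with the maximum subtracted for numerical stability, its backward pass as a Jacobian-vector product, and cross-entropy with its gradient. All function names are made up for illustration:

```kotlin
import kotlin.math.exp
import kotlin.math.ln

// Softmax over one sample; subtracting the maximum avoids overflow in exp.
fun softmax(z: DoubleArray): DoubleArray {
    val max = z.maxOrNull() ?: return DoubleArray(0)
    val e = DoubleArray(z.size) { exp(z[it] - max) }
    val sum = e.sum()
    return DoubleArray(z.size) { e[it] / sum }
}

// Backward pass of softmax given y = softmax(z): dL/dz_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j).
fun softmaxGradInput(y: DoubleArray, gradOutput: DoubleArray): DoubleArray {
    val dot = y.indices.sumOf { y[it] * gradOutput[it] }
    return DoubleArray(y.size) { y[it] * (gradOutput[it] - dot) }
}

// Cross-entropy L = -sum_i t_i * ln(y_i) for a target distribution t (e.g. one-hot labels).
// Terms with t_i = 0 are skipped so that y_i = 0 does not produce 0 * (-infinity) = NaN.
fun crossEntropy(y: DoubleArray, t: DoubleArray): Double =
    -y.indices.sumOf { if (t[it] == 0.0) 0.0 else t[it] * ln(y[it]) }

// Gradient of the cross-entropy w.r.t. the softmax output: dL/dy_i = -t_i / y_i.
fun crossEntropyGradInput(y: DoubleArray, t: DoubleArray) =
    DoubleArray(y.size) { -t[it] / y[it] }
```

Chaining the two backward passes yields the well-known result dL/dz = y - t for a one-hot target t.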

After adding some convenience functions, we can define a two-layer network through:

fun linear(inputSize: Int, outputSize: Int) = 
    Linear(randomInit(outputSize, inputSize), randomInit(outputSize))
fun twoLayerClassifier(dataset: Dataset, hiddenSize: Int = 64) =
    linear(dataset.inputs[0].size, hiddenSize) before Relu before
    linear(hiddenSize, dataset.labels[0].size) before Softmax

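The definition of before is not shown. One hypothetical way to build such a combinator is an infix extension that composes the forward functions and, via the chain rule, the backward functions. The sketch below uses scalar functions and made-up names (Differentiable, forward, backward) rather than the post's Backpropagatable type:

```kotlin
// Hypothetical scalar "module": a forward function plus a map from dL/dy back to dL/dx.
class Differentiable(
    val forward: (Double) -> Double,
    val backward: (input: Double, outputGradient: Double) -> Double
)

// Composes `this` with `next`: forward runs this, then next; backward runs next, then this.
infix fun Differentiable.before(next: Differentiable): Differentiable {
    val first = this
    return Differentiable(
        forward = { x -> next.forward(first.forward(x)) },
        // Chain rule: push the gradient back through `next` first, then through `first`.
        backward = { x, outputGradient ->
            first.backward(x, next.backward(first.forward(x), outputGradient))
        }
    )
}
```

This sketch recomputes the first module's output during the backward pass; the training loop below appears to avoid that by keeping the forwardResults of the forward pass.
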
Finally, the function for training an MNIST classifier looks like this (slightly simplified):

fun Backpropagatable<FloatMatrix, FloatMatrix>.trainClassifier(
        dataset: Dataset,
        lossByLabels: (FloatMatrix) -> Backpropagatable<FloatMatrix, FloatVector>  = { CrossEntropyLoss(labels = it) },
        learningRateByProgress: (Float) -> Float = { 5f * kotlin.math.exp(-it * 3) },
        batchSize: Int = 64,
        iterations: Int = 500) {
    for (i in 0 until iterations) {
        disposeScoped {
            val (inputBatch, labelBatch) = dataset.sampleBatch(batchSize)
            val errorNetwork = this@trainClassifier before lossByLabels(labelBatch)
            val forwardResults = use { errorNetwork.forwardPass(inputBatch) }
            val progress = i.toFloat() / iterations  // fraction of training completed, in [0, 1)
            val learningRate = learningRateByProgress(progress)
            val backpropResults = use { forwardResults.backpropagate(outputGradient = tensor(learningRate)) }
            backpropResults.descend()
        }
    }
}
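
Note that the backward pass is seeded with the learning rate (outputGradient = tensor(learningRate)) rather than with 1. Since backpropagation is linear in this seed, every parameter gradient comes out already scaled by the learning rate. The sketch below illustrates this with a hypothetical one-parameter model y = w * x whose output is used directly as the loss; it assumes, which the post does not show, that descend simply subtracts the returned gradients from the parameters:

```kotlin
import kotlin.math.exp

// Default schedule from trainClassifier: 5 * e^(-3 * progress), with progress in [0, 1).
fun learningRateByProgress(progress: Float): Float = 5f * exp(-progress * 3)

// Hypothetical model y = w * x with loss L = y, so dL/dw = x.
fun descentStep(w: Float, x: Float, learningRate: Float): Float {
    val outputGradient = learningRate        // seed dL/dy with the learning rate instead of 1
    val weightGradient = outputGradient * x  // chain rule: learningRate * dL/dw
    return w - weightGradient                // the descent step is a plain subtraction
}
```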

Kotlin/Native automatically manages memory using automatic reference counting with cycle collection. Memory owned by external C libraries, though, has to be allocated and released manually. To make this easier, I implemented the following mechanism (powered by Kotlin's type-safe builders): all expressions registered with use are disposed when the enclosing scope, marked by a disposeScoped call, ends. In our case, this ensures that all forward and backward pass results are released after every iteration, keeping memory usage low.
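
A minimal sketch of how such a mechanism can be built follows. It is a hypothetical reconstruction: only the names use and disposeScoped come from the code above, and the actual implementation may differ. A scope object collects everything registered with use and disposes it in reverse order when the block finishes, even if it throws:

```kotlin
// Hypothetical reconstruction of the use/disposeScoped mechanism.
interface Disposable {
    fun dispose()
}

class DisposeScope {
    private val resources = mutableListOf<Disposable>()

    // Evaluates `block`, registers its result for disposal, and returns it unchanged.
    fun <T : Disposable> use(block: () -> T): T = block().also { resources += it }

    // Releases resources in reverse order of registration.
    fun disposeAll() {
        resources.asReversed().forEach { it.dispose() }
        resources.clear()
    }
}

// Runs `action` with a fresh scope as receiver and disposes everything registered in it.
inline fun <R> disposeScoped(action: DisposeScope.() -> R): R {
    val scope = DisposeScope()
    try {
        return scope.action()
    } finally {
        scope.disposeAll()
    }
}
```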

Now we can run the code to train a classifier for handwritten digits from the MNIST dataset:

val trainingDataset = MNIST.labeledTrainingImages()
val predictionNetwork = twoLayerClassifier(trainingDataset)
predictionNetwork.trainClassifier(trainingDataset)

The dataset is downloaded by the build script and parsed through the MNIST class.

Even on a CPU, training should only take a few minutes, and you should observe a classification accuracy of about 95% on the test dataset.
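
Accuracy here means the fraction of test images whose highest-scoring class matches the label. A hypothetical helper for computing it, assuming one-hot label vectors (as suggested by dataset.labels[0].size being used as the output size) and network outputs given as one FloatArray per sample:

```kotlin
// Index of the largest entry (-1 for an empty array).
fun argmax(values: FloatArray): Int = values.indices.maxByOrNull { values[it] } ?: -1

// Fraction of samples whose predicted class matches the one-hot label.
fun accuracy(predictions: List<FloatArray>, labels: List<FloatArray>): Float {
    require(predictions.size == labels.size && predictions.isNotEmpty())
    val correct = predictions.indices.count { n -> argmax(predictions[n]) == argmax(labels[n]) }
    return correct.toFloat() / predictions.size
}
```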

Is this viable?

While this project was aimed at exploring the internals of TensorFlow and Torch as well as development in Kotlin/Native, you might wonder whether it would be useful to extend this demo into a fully-fledged framework for machine learning. Here are some points to consider:

  • There is a massive existing community behind TensorFlow's Python client and PyTorch, which creates strong pressure for further development. Potential alternatives would lack this, which reduces their viability. Additionally, Python has a mature ecosystem of libraries for data processing and machine learning, which further strengthens this effect.
  • On the positive side, native binaries can be deployed on a wider range of platforms and hardware, enabling embedded (robotics, …), web (WebAssembly, …) and mobile applications. But because training large models is hardware-intensive anyway (it is often performed on server GPUs), the main benefit here would be deploying pre-trained networks. This use case is already covered, e. g., by TensorFlow Lite.
  • Parallel processing in Python itself is cumbersome and typically requires multiple processes due to the global interpreter lock. This makes large-scale data processing, a task particularly important for machine learning applications, harder to optimize within the language itself. Performance is further degraded because Python is dynamically typed and typically interpreted rather than compiled to native binaries. This makes native bindings necessary for highly optimized code (such as that used by Torch and TensorFlow). Kotlin/Native would offer a modern language for developing native code without these restrictions, although other modern languages (such as Rust) have similar benefits.
  • Even for the development of this small demo, static typing was helpful to me, eliminating a large class of errors at compile time and enabling strong tooling (auto-complete, refactoring, …). It is not obvious to me, though, whether these benefits would scale to a larger implementation. Also, keeping track of fine-grained types (e. g. vectors vs. matrices, float vs. double) might require more thought when designing generic APIs.
  • Finally, some of Python's language features were designed with numeric processing in mind, and Kotlin might need some syntactic sugar to keep up (e. g. for array literals, float vs. double literals and array slicing; Python recently even introduced the @ operator for matrix multiplication).

Having said all that: To me, building this demo felt like C development on a rocket booster, and I hope that Kotlin/Native will find its use case in machine learning.

The full code is available in the Kotlin/Native repository along with instructions for how to run it. If you have questions or feedback, please comment below.