From Python to Elixir Machine Learning
Python PyTorch to Elixir-Nx

Moving on from Python Machine Learning might seem impossible. Let me break down why and how you can do it.

As Elixir's Machine Learning (ML) ecosystem grows, many Elixir enthusiasts who wish to adopt the new machine learning libraries in their projects find themselves at a crossroads: they want to move away from their existing ML stack (typically Python) but have no clear path for doing so. I would like to take some time to talk about WHY I believe now is a good time to start porting Machine Learning code to Elixir, and HOW I went about doing just this for two libraries I wrote: EXGBoost (from Python XGBoost) and Mockingjay (from Python Hummingbird).

Why is Python not Sufficient?

There's a common saying in programming that no language is perfect, but that different languages are suited for different jobs. Languages such as C, Rust, and now even Zig are known for targeting systems development, while languages such as C++, C#, and Java are more commonly used for application development. Then, of course, there are the web languages such as JavaScript/TypeScript, PHP, Ruby (on Rails), and more. There are gradations to these rules, but more often than not there are good reasons that languages tend to stay within the confines of particular use cases.

Languages such as Elixir and Go tend to be used in large distributed systems because they place an emphasis on great support for common concurrency patterns, which can come at the cost of supporting other domains. Go, for example, has little, if any, support for machine learning libraries, but it is also not trying to cater to that domain. For a long time, the same could have been said about Elixir, but over the past two or so years, there has been a massive concerted push from the Elixir community not only to support machine learning, but to push the envelope by maintaining state-of-the-art libraries that are beginning to compete with the other dominant machine learning languages, namely Python.

Python has long been the gold standard in the realm of machine learning. The breadth of libraries and the low barrier to entry make Python a great language to work with, but it does create a bit of a bottleneck. Any application that wishes to integrate machine learning has historically had only a couple of options: have a Python component, or reach directly into the underlying libraries that power many of the Python libraries. Despite all the good parts of Python I mentioned before, speed and support for concurrency are not on that list. Elixir-Nx is striving to give another option, one that can take advantage of the native distributed support that Elixir and the BEAM VM have to offer. Nx's Nx.Serving construct is a drop-in solution for serving distributed machine-learning models.

How to Proceed

Sean Moriarity, the co-creator of Nx, creator of Axon, and author of Machine Learning in Elixir, has talked many times about how the initial creation of Nx and Axon involved hours upon hours of reading source code from reference implementations of libraries in Python and C++, namely the TensorFlow source code. While I was writing EXGBoost and Mockingjay, much of my time, especially towards the beginning, was spent referencing the Python and C++ implementations of the original libraries. This built a great fundamental understanding of the libraries and taught me how to identify patterns in Python and C++ and find the Elixir pattern that could express the same ideas. This skill is invaluable, and the better I got at it the faster I could write. Below is a summary of, and the key takeaways from, my process of porting Python / PyTorch to Elixir / Nx.

Workflow Overview

Before I get to the examples from the code bases, I would like to briefly explain the high-level cyclical workflow I established while working on this effort, and what I would recommend to anyone pursuing a similar endeavor.

Understand the Macro System

Much like how there's a common strategy for reading comprehension which involves reading through the entire document once to get a high-level understanding, and then doing subsequent shorter reads to gain more in-depth understanding with the added context of the entire piece, you can consider doing the same when reading code. My first step was to follow the logical flow from the top-level conversion call to the final result. You can use tools such as function tracers and callgraph generators to accelerate this part of the process, or trace manually, depending on the extent of the codebase. I felt in my case that it was manageable to trace myself.
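As a rough illustration of the tracing idea, here is a minimal Python sketch that records the order in which functions are called, using the standard library's sys.setprofile hook. The helper name trace_calls and the toy load/scale/pipeline functions are hypothetical stand-ins; dedicated tracers and callgraph generators give much richer output.

```python
import sys


def trace_calls(func, *args, **kwargs):
    """Run func and return (result, names of Python functions called, in order)."""
    calls = []

    def profiler(frame, event, arg):
        # "call" fires for Python-level functions; C builtins use "c_call".
        if event == "call":
            calls.append(frame.f_code.co_name)

    sys.setprofile(profiler)
    try:
        result = func(*args, **kwargs)
    finally:
        sys.setprofile(None)  # always remove the hook again
    return result, calls


# Hypothetical stand-ins for the library code you want to understand.
def load():
    return [1, 2, 3]


def scale(xs):
    return sum(xs) * 2


def pipeline():
    return scale(load())


result, calls = trace_calls(pipeline)
```

The resulting list of names gives a first-pass map of the call flow that you can then read through in detail.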

Read the Documentation

Once you have a general understanding of the flow and process of the original system, you can start referring to the documentation for some additional context. In my case, this led me to the academic paper Taming Model Serving Complexity, Performance and Cost: A Compilation to Tensor Computations Approach, which was the groundwork and basis for their implementation. I could write a whole other blog post about the process of transcribing algorithms and code from academic papers and pseudocode, but for now just know that these are some of the most important pieces you can refer to while re-implementing or porting over a piece of source code.

Read the Source Code in Detail

This is the point at which you want to disambiguate the higher-level ideas from the first step and really gain a fine, high-resolution understanding of what is happening. There might even be some points at which you need to deconflict the source code with its documentation and/or paper reference. In those cases, the source code almost always wins, and if not, then you likely have a bug report you can file. If you see things you don't fully understand, you don't necessarily need to resolve them here, but you should make note of them and keep them in mind while working in case new details help resolve them.

Implement the New Code

At this point, you should feel comfortable enough to start implementing the code. I found this to be a very iterative process, meaning I would think I had a grasp on something, then would start working on implementing it, then would realize I did not understand it as well as I had thought and would work my way back through the previous steps.


In case you would like to follow along going forward, the Python code I will be referencing is the Microsoft Hummingbird source code (specifically their implementation of Decision Tree Compilation), and the Elixir code is from the Mockingjay source code.

Class vs. Behaviour

As a result of reading through the Hummingbird code base, I realized fairly early on that my library was going to have some key differences. One of the main reasons for these differences was that Hummingbird was built as a retroactive library that needed to cater to APIs that already existed throughout the Python ecosystem. They chose to only add support for converting decision trees according to the SKLearn API. I, conversely, chose to write Mockingjay in such a way that it would be incumbent upon the authors of decision tree libraries to implement a protocol to interface with Mockingjay's convert function. This difference meant that I could establish a Mockingjay.Tree data structure that I would use throughout my library, rather than having to reconstruct tree features from various other APIs as is done in Hummingbird.

Next, Hummingbird approaches its pipeline in a very object-oriented manner, as makes sense when using Python. Here we are focusing on the implementation of the three decision tree conversion strategies: GEMM, Tree Traversal, and Perfect Tree Traversal. It implements the following base classes for tree conversions as well as PyTorch networks.

Since they're inheriting from torch.nn.Module, they must also implement the forward method.
class AbstracTreeImpl(PhysicalOperator):
    """
    Abstract class defining the basic structure for tree-based models.
    """

    def __init__(self, logical_operator, **kwargs):
        super().__init__(logical_operator, **kwargs)

    def aggregation(self, x):
        """
        Method defining the aggregation operation to execute after the model is evaluated.

        Args:
            x: An input tensor

        Returns:
            The tensor result of the aggregation
        """


class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
    """
    Abstract class defining the basic structure for tree-based models implemented in PyTorch.
    """

    def __init__(
        self, logical_operator, tree_parameters, n_features, classes, n_classes, decision_cond="<=", extra_config={}, **kwargs
    ):
        """
        Args:
            tree_parameters: The parameters defining the tree structure
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            n_classes: The total number of used classes
            decision_cond: The condition of the decision nodes in the x <cond> threshold order. Default '<='. Values can be <=, <, >=, >
        """
        super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)

They then proceed to inherit from these base classes and have different classes for each of the three decision tree strategies as well as their gradient-boosted counterparts, leaving them with three classes for each strategy (1 base class per strategy, 1 for ensemble implementations, and 1 for normal implementations) and nine total classes.

I chose to approach this using a behaviour:

defmodule Mockingjay.Strategy do
  @moduledoc false
  @type t :: Nx.Container.t()

  @callback init(data :: any(), opts :: Keyword.t()) :: term()
  @callback forward(x :: Nx.Container.t(), term()) :: Nx.Tensor.t()
end

init performs setup depending on the strategy and returns the parameters that will need to be passed to forward later on. This allows for a very simple top-level API. The whole top-level mockingjay.ex file can fit here:

  def convert(data, opts \\ []) do
    {strategy, opts} = Keyword.pop(opts, :strategy, :auto)

    strategy =
      case strategy do
        :gemm ->
          Mockingjay.Strategies.GEMM

        :tree_traversal ->
          Mockingjay.Strategies.TreeTraversal

        :perfect_tree_traversal ->
          Mockingjay.Strategies.PerfectTreeTraversal

        :auto ->
          Mockingjay.Strategy.get_strategy(data, opts)

        _ ->
          raise ArgumentError,
                "strategy must be one of :gemm, :tree_traversal, :perfect_tree_traversal, or :auto"
      end

    {post_transform, opts} = Keyword.pop(opts, :post_transform, nil)
    state = strategy.init(data, opts)

    fn data ->
      result = strategy.forward(data, state)
      {_, n_trees, n_classes} = Nx.shape(result)

      result
      |> aggregate(n_trees, n_classes)
      |> post_transform(post_transform, n_classes)
    end
  end

As you can see, the use of a behaviour here allows a strategy-agnostic approach to generating a prediction pipeline. In the object-oriented implementation, each class implements init, forward, aggregate, and post_transform. We get the same result from a functional pipeline approach, where each step generates the needed information as input parameters for the next step. So, instead of storing intermediate results as object properties or values in an object's __dict__, we just pass them along in the pipeline. I would argue this creates a much simpler and easier to follow implementation (but I am also quite biased).
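For readers coming from Python, the same init/forward contract can be sketched with typing.Protocol and a closure. This is only an analogy, not code from either library: the Strategy protocol, the toy Threshold strategy, and this convert function are all hypothetical names chosen for illustration.

```python
from typing import Any, Callable, Protocol


class Strategy(Protocol):
    """Hypothetical Python counterpart of an init/forward behaviour."""

    def init(self, data: Any, opts: dict) -> Any: ...

    def forward(self, x: list, state: Any) -> list: ...


class Threshold:
    """Toy strategy: labels a value 1 if it exceeds the training mean, else 0."""

    def init(self, data, opts):
        # Return the state explicitly instead of storing it on self.
        return {"cutoff": sum(data) / len(data)}

    def forward(self, x, state):
        return [1 if v > state["cutoff"] else 0 for v in x]


def convert(data, strategy: Strategy, opts=None) -> Callable[[list], list]:
    state = strategy.init(data, opts or {})
    # The returned closure carries the state; no object attributes are mutated.
    return lambda x: strategy.forward(x, state)


predict = convert([1.0, 2.0, 3.0], Threshold())
```

Calling predict([0.5, 2.5]) then runs forward with the state that init produced, which is the same hand-off the Elixir pipeline performs.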

PyTorch to Nx

For these examples, we will be looking at porting the implementations of the forward function for the three conversion strategies from Python to Nx.


Next, let's look at the forward function implementation for GEMM, one of the three conversion strategies. In Hummingbird, they implemented the forward step in the base class for each strategy. So given three GEMM classes with the signatures of GEMMTreeImpl(AbstractPyTorchTreeImpl), GEMMDecisionTreeImpl(GEMMTreeImpl), and GEMMGBDTImpl(GEMMTreeImpl), the forward function is defined in the GEMMTreeImpl class, since both ensemble and non-ensemble decision tree models share the same forward step.

def forward(self, x):
    x = x.t()
    x = self.decision_cond(torch.mm(self.weight_1, x), self.bias_1)
    x = x.view(self.n_trees, self.hidden_one_size, -1)
    x = x.float()

    x = torch.matmul(self.weight_2, x)

    x = x.view(self.n_trees * self.hidden_two_size, -1) == self.bias_2
    x = x.view(self.n_trees, self.hidden_two_size, -1)
    if self.tree_op_precision_dtype == "float32":
        x = x.float()
    else:
        x = x.double()

    x = torch.matmul(self.weight_3, x)
    x = x.view(self.n_trees, self.hidden_three_size, -1)

Now, here is the Nx implementation:

  @impl true
  deftransform forward(x, {arg, opts}) do
    opts =
      Keyword.validate!(opts, [
        :condition,
        :n_trees,
        :n_classes,
        :max_decision_nodes,
        :max_leaf_nodes,
        :n_weak_learner_classes
      ])

    _forward(x, arg, opts)
  end

  defnp _forward(x, arg, opts \\ []) do
    %{mat_A: mat_A, mat_B: mat_B, mat_C: mat_C, mat_D: mat_D, mat_E: mat_E} = arg

    condition = opts[:condition]
    n_trees = opts[:n_trees]
    n_classes = opts[:n_classes]
    max_decision_nodes = opts[:max_decision_nodes]
    max_leaf_nodes = opts[:max_leaf_nodes]
    n_weak_learner_classes = opts[:n_weak_learner_classes]

    mat_A
    |> Nx.dot([1], x, [1])
    |> condition.(mat_B)
    |> Nx.reshape({n_trees, max_decision_nodes, :auto})
    |> then(&Nx.dot(mat_C, [2], [0], &1, [1], [0]))
    |> Nx.reshape({n_trees * max_leaf_nodes, :auto})
    |> Nx.equal(mat_D)
    |> Nx.reshape({n_trees, max_leaf_nodes, :auto})
    |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0]))
    |> Nx.reshape({n_trees, n_weak_learner_classes, :auto})
    |> Nx.transpose()
    |> Nx.reshape({:auto, n_trees, n_classes})
  end

Do not be distracted by the length of this code snippet, as many of the lines are taken up by validating and unpacking arguments. Let's look at a more stripped-down version without that:

  @impl true
  deftransform forward(x, {arg, opts}) do
    _forward(x, arg, opts)
  end

  defnp _forward(x, arg, opts \\ []) do
    # Unpacking of arg and opts omitted for brevity
    mat_A
    |> Nx.dot([1], x, [1])
    |> condition.(mat_B)
    |> Nx.reshape({n_trees, max_decision_nodes, :auto})
    |> then(&Nx.dot(mat_C, [2], [0], &1, [1], [0]))
    |> Nx.reshape({n_trees * max_leaf_nodes, :auto})
    |> Nx.equal(mat_D)
    |> Nx.reshape({n_trees, max_leaf_nodes, :auto})
    |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0]))
    |> Nx.reshape({n_trees, n_weak_learner_classes, :auto})
    |> Nx.transpose()
    |> Nx.reshape({:auto, n_trees, n_classes})
  end

Let's take a look at some obvious differences:

  • The Nx code does not have to transpose in the first step since Nx.dot/4 allows you to specify the contracting axes.
  • You can use Nx.dot/6 to get the same behavior as torch.matmul
    • torch.matmul does a lot of wizardry with broadcasting to make this instance work
  • We use functions such as Nx.equal to fit into the pipeline rather than using the == operator (which would work outside of a pipeline)
  • The view method on a PyTorch tensor plays the same role here as Nx.reshape
  • Nx uses the :auto atom where torch uses -1 to infer the size of an axis

Outside of these differences, the code translates fairly easily. Let's take a look at a bit of a more complex instance.
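To make the sequence of matrix steps in the GEMM forward pass more concrete, here is a pure-Python sketch that runs the same compare, multiply, and match steps on a single hand-built tree. The tree, the matrix contents, and the helper names are hypothetical and chosen for illustration. The letters A through E loosely mirror mat_A through mat_E, but this sketch uses a plain row-per-sample layout with no leading trees axis, so the shapes do not match Mockingjay's or Hummingbird's.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Hypothetical tree: node 0 tests x[0] < 0.5, node 1 tests x[1] < 2.0.
#   node 0 --true--> leaf 0
#          --false-> node 1 --true--> leaf 1
#                           --false-> leaf 2
A = [[1, 0], [0, 1]]                   # features x internal nodes: feature read by each node
B = [0.5, 2.0]                         # threshold of each internal node
C = [[1, -1, -1], [0, 1, -1]]          # internal nodes x leaves: +1 left subtree, -1 right, 0 unrelated
D = [1, 1, 0]                          # per leaf: number of ancestors entered via the left branch
E = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # leaves x classes (one class per leaf here)


def gemm_predict(X):
    """Return one row of class scores per input row of X."""
    T = matmul(X, A)
    T = [[int(v < b) for v, b in zip(row, B)] for row in T]   # decision condition
    T = matmul(T, C)
    T = [[int(v == d) for v, d in zip(row, D)] for row in T]  # pick the leaf whose path matches
    return matmul(T, E)


scores = gemm_predict([[0.2, 5.0], [0.9, 1.0], [0.9, 3.0]])
```

Each input row ends up with a one-hot row selecting its leaf's class, with no branching on the data: the whole tree is evaluated through comparisons and matrix products.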

Tree Traversal

Here is the Python implementation:

def _expand_indexes(self, batch_size):
    indexes = self.nodes_offset
    indexes = indexes.expand(batch_size, self.num_trees)
    return indexes.reshape(-1)

def forward(self, x):
    batch_size = x.size(0)
    indexes = self.nodes_offset
    indexes = indexes.expand(batch_size, self.num_trees).reshape(-1)

    for _ in range(self.max_tree_depth):
        tree_nodes = indexes
        feature_nodes = torch.index_select(self.features, 0, tree_nodes).view(-1, self.num_trees)
        feature_values = torch.gather(x, 1, feature_nodes)

        thresholds = torch.index_select(self.thresholds, 0, indexes).view(-1, self.num_trees)
        lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
        rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)

        indexes = torch.where(self.decision_cond(feature_values, thresholds), lefts, rights).long()
        indexes = indexes + self.nodes_offset
        indexes = indexes.view(-1)

    output = torch.index_select(self.values, 0, indexes).view(-1, self.num_trees, self.n_classes)

And here is the Nx implementation:

  defn _forward(x, features, lefts, rights, thresholds, nodes_offset, values, opts \\ []) do
    max_tree_depth = opts[:max_tree_depth]
    num_trees = opts[:num_trees]
    n_classes = opts[:n_classes]
    condition = opts[:condition]
    unroll = opts[:unroll]

    batch_size = Nx.axis_size(x, 0)

    indices =
      nodes_offset
      |> Nx.broadcast({batch_size, num_trees})
      |> Nx.reshape({:auto})

    {indices, _} =
      while {tree_nodes = indices, {features, lefts, rights, thresholds, nodes_offset, x}},
            _ <- 1..max_tree_depth,
            unroll: unroll do
        feature_nodes = Nx.take(features, tree_nodes) |> Nx.reshape({:auto, num_trees})
        feature_values = Nx.take_along_axis(x, feature_nodes, axis: 1)
        local_thresholds = Nx.take(thresholds, tree_nodes) |> Nx.reshape({:auto, num_trees})
        local_lefts = Nx.take(lefts, tree_nodes) |> Nx.reshape({:auto, num_trees})
        local_rights = Nx.take(rights, tree_nodes) |> Nx.reshape({:auto, num_trees})

        result =
          Nx.select(
            condition.(feature_values, local_thresholds),
            local_lefts,
            local_rights
          )
          |> Nx.add(nodes_offset)
          |> Nx.reshape({:auto})

        {result, {features, lefts, rights, thresholds, nodes_offset, x}}
      end

    values
    |> Nx.take(indices)
    |> Nx.reshape({:auto, num_trees, n_classes})
  end

Here there are some much more striking differences, namely the use of Nx's while expression compared to a for loop in Python. We use while in this case since it achieves the same purpose as the Python for loop and is supported by Nx within a defn expression. Otherwise, we might have to perform some of the calculations within a deftransform, as we will see in the next example. Another obvious difference is that in the Nx implementation, we have to pass the required variables around throughout these operations, whereas Python can use stored class attributes.

Still, the conversion is quite straightforward. I hope you are beginning to see that this is not an impossible effort, and can be accomplished given you have a firm understanding of the source material.
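If the index juggling above is hard to follow, a scalar version of the same idea may help. The sketch below walks a single sample through one hypothetical five-node tree stored as flat arrays. The array contents are made up for illustration, and the sketch assumes leaves point back to themselves so that a fixed number of iterations is always safe.

```python
# Flat node arrays for a hypothetical 5-node tree.
# Node 0 tests x[0] < 0.5; node 2 tests x[1] < 2.0; nodes 1, 3, and 4 are leaves.
features = [0, 0, 1, 0, 0]
thresholds = [0.5, 0.0, 2.0, 0.0, 0.0]
lefts = [1, 1, 3, 3, 4]    # assumption: each leaf lists itself as its own child
rights = [2, 1, 4, 3, 4]
values = [None, "a", None, "b", "c"]
MAX_DEPTH = 2


def traverse(x):
    """Follow left/right pointers for a fixed number of steps and return the leaf value."""
    index = 0
    for _ in range(MAX_DEPTH):  # fixed trip count, no data-dependent early exit
        go_left = x[features[index]] < thresholds[index]
        index = lefts[index] if go_left else rights[index]
    return values[index]


shallow = traverse([0.2, 5.0])  # reaches leaf 1 after one step, then stays there
deep = traverse([0.9, 1.0])     # needs both steps to reach leaf 3
```

Because the loop always runs MAX_DEPTH times, it maps naturally onto a counted loop such as Nx's while with a range generator, and the batched versions above simply do this for every sample and tree at once.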

Perfect Tree Traversal

Finally, let's look at the last conversion strategy. This conversion is slightly more complex again, but hopefully seeing this example will help you in your case:

def forward(self, x):
    prev_indices = (self.decision_cond(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
    prev_indices = prev_indices + self.tree_indices
    prev_indices = prev_indices.view(-1)

    factor = 2
    for nodes, biases in zip(self.nodes, self.biases):
        gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
        features = torch.gather(x, 1, gather_indices).view(-1)
        prev_indices = (
            factor * prev_indices + self.decision_cond(features, torch.index_select(biases, 0, prev_indices)).long()
        )

    output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)

And the Elixir implementation:

  defnp _forward(
          x,
          root_features,
          root_thresholds,
          features,
          thresholds,
          values,
          indices,
          opts \\ []
        ) do
    prev_indices =
      x
      |> Nx.take(root_features, axis: 1)
      |> opts[:condition].(root_thresholds)
      |> Nx.add(indices)
      |> Nx.reshape({:auto})
      |> forward_reduce_features(x, features, thresholds, opts)

    Nx.take(values, prev_indices)
    |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]})
  end

  deftransformp forward_reduce_features(prev_indices, x, features, thresholds, opts \\ []) do
    Enum.zip_reduce(
      Tuple.to_list(features),
      Tuple.to_list(thresholds),
      prev_indices,
      fn nodes, biases, acc ->
        gather_indices = nodes |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]})
        features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto})

        acc
        |> Nx.multiply(@factor)
        |> Nx.add(opts[:condition].(features, Nx.take(biases, acc)))
      end
    )
  end

You can see that in this case, we have a function defined in a deftransform within our forward pipeline. Why is this so? When writing definitions within defn, you give up the default Elixir Kernel in favor of Nx.Defn.Kernel. If you want full access to all of the normal Elixir modules, you need to use a deftransform. We needed to use Enum.zip_reduce in this instance (rather than Nx's while, as before) since the features and thresholds lists are not of uniform shape. Each entry holds the nodes at a given depth of a binary tree, so the entries have lengths [1, 2, 4, 8, ...]. This is an optimization over normal Tree Traversal, but it required a somewhat different approach from the Python implementation, which took advantage of torch.nn.ParameterList to build out the same lists. You might also notice the use of Tuple.to_list when calling Enum.zip_reduce. This was required since features and thresholds need to be stored in an Nx.Container when passed into the deftransform, and Tuple implements the Nx.Container protocol, while lists do not. Even so, once you know these intricacies of defn and deftransform, the final ported solution is very similar to the reference solution.
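The per-depth lists and the doubling of the index can be seen in a small scalar sketch. The tree below is a hypothetical perfect tree of depth 2, and the convention that a true condition selects the odd child is an assumption of this sketch rather than something taken from either library.

```python
# One list per depth of a hypothetical perfect binary tree.
# The lists have lengths 1 and 2 (followed by 4 leaves), so they cannot be
# stacked into a single rectangular array without padding.
level_features = [[0], [1, 1]]
level_thresholds = [[0.5], [1.0, 2.0]]
leaf_values = ["a", "b", "c", "d"]
FACTOR = 2


def perfect_traverse(x):
    """Walk one sample down the tree, doubling the index at every level."""
    index = 0
    for feats, thrs in zip(level_features, level_thresholds):
        go_right = x[feats[index]] >= thrs[index]
        # Assumed convention: children of node i at the next level are 2i and 2i + 1.
        index = FACTOR * index + int(go_right)
    return leaf_values[index]


leaf = perfect_traverse([0.9, 1.5])
```

Since every level is a differently sized list, iterating with zip over the levels is the natural fit in Python, and Enum.zip_reduce plays the same role on the Elixir side.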


In this post, I tried to accomplish several things at once, and perhaps that led to a cluttered article, but I felt the need to address all of these points together. I do not mean to suggest that Machine Learning has no place in Python or that Python will not continue to be the most dominant player in Machine Learning. I do think some healthy competition is a good thing, and that Python has some shortcomings that give other languages valid reasons to coexist in the space.

Next, I wanted to address some specifics as to what Elixir has to offer to the machine learning space. I think it is uniquely positioned to be quite competitive considering the large community push to support more and more libraries, as well as the large application development community that can benefit from an in-house solution.

Lastly, I wanted to share some practical tips for those looking to move on from Python to Elixir, but feeling somewhat helpless in the process. I think that Sean Moriarity's book that I mentioned earlier in this article is an invaluable resource and a great step in the education of machine learning for Elixir developers, but it can nonetheless feel daunting to seemingly throw out existing working solutions for new-fangled, perhaps less established ones. I hope I showed how anybody can approach this problem, and that any existing Elixir developer can be a machine learning developer going forward. The groundwork has been laid, and the tools are available. Thank you for reading (especially if you made it to the end)!