I get confused with tensor computation libraries (or computational graph libraries, or symbolic algebra libraries, or whatever they’re marketing themselves as these days).

I was first introduced to PyTorch and TensorFlow and, having no other reference, thought they were prototypical examples of tensor computation libraries. Then I learnt about Theano - an older and less popular project, but different from PyTorch and TensorFlow and better in some meaningful ways. This was followed by JAX, which seemed to be basically NumPy with more bells and whistles (although I couldn’t articulate what exactly they were). Then came the announcement by the PyMC developers that Theano would have a new JAX backend.

Anyways, this confusion prompted a lot of research and eventually, this blog post.

Similar to my previous post on the anatomy of probabilistic programming frameworks, I’ll first discuss tensor computation libraries in general - what they are and how they can differ from one another. Then I’ll discuss some libraries in detail, and finally offer an observation on the future of Theano in the context of contemporary tensor computation libraries.

Dissecting Tensor Computation Libraries

First, a characterization: what do tensor computation libraries even do?

  1. They provide ways of specifying and building computational graphs,
  2. They run the computation itself (duh), but also run “related” computations that either (a) use the computational graph, or (b) operate directly on the computational graph itself,
    • The most salient example of the former is computing gradients via autodifferentiation,
    • A good example of the latter is optimizing the computation itself: think symbolic simplifications (e.g. xy/x = y) or rewrites for numerical stability (e.g. replacing log(1 + x) with log1p(x), which stays accurate for small values of x).
  3. And they provide “best execution” for the computation: whether it’s changing the execution by JIT (just-in-time) compiling it, by utilizing special hardware (GPUs/TPUs), by vectorizing the computation, or in any other way.
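To make the two kinds of graph optimization in point 2 concrete, here is a minimal sketch that uses a made-up expression representation (variables are strings, operations are tuples like ("mul", "x", "y")). It is not how any particular library stores its graph. The rewrite pass applies exactly the two rules above: (x * y) / x → y, and log(1 + x) → log1p(x).

```python
import math

# Hypothetical toy representation: a variable is a string, a constant is a
# number, and an operation is a tuple (op_name, *args).

OPS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "div": lambda a, b: a / b,
    "log": math.log,
    "log1p": math.log1p,
}


def rewrite(expr):
    """Rewrite the graph bottom-up with two rules:
    (x * y) / x -> y         symbolic simplification (only valid if x != 0)
    log(1 + x)  -> log1p(x)  numerical stability for small x
    """
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [rewrite(a) for a in args]
    if op == "div":
        num, den = args
        if isinstance(num, tuple) and num[0] == "mul" and num[1] == den:
            return num[2]
    if op == "log":
        (arg,) = args
        if isinstance(arg, tuple) and arg[0] == "add" and arg[1] == 1:
            return ("log1p", arg[2])
    return (op, *args)


def evaluate(expr, env):
    """Run the computation, looking up variable values in env."""
    if isinstance(expr, str):
        return env[expr]
    if not isinstance(expr, tuple):
        return expr  # a numeric constant
    op, *args = expr
    return OPS[op](*(evaluate(a, env) for a in args))


print(rewrite(("div", ("mul", "x", "y"), "x")))  # y

expr = ("log", ("add", 1, "x"))
env = {"x": 1e-17}
print(evaluate(expr, env))           # 0.0, because 1 + 1e-17 rounds to 1.0
print(evaluate(rewrite(expr), env))  # 1e-17
```

Note that the rewrite never looks at any data: it operates on the graph alone, which is what makes it possible to do such optimizations once, before the computation runs.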

“Tensor Computation Library” - Maybe Not The Best Name

As an aside: I realize that the name “tensor computation library” is too broad, and that the characterization above excludes some libraries that might also justifiably be called “tensor computation libraries”. Better names might be “graph computation library” (although that might get mixed up with libraries like networkx) or “computational graph management library” or even “symbolic tensor algebra libraries”.

So for the avoidance of doubt, here is a list of libraries that this blog post is not about:

  • NumPy and SciPy
    • These libraries don’t have a concept of a computational graph - they’re more like a toolbox of functions, called from Python and executed in C or Fortran.
    • However, this might be a controversial distinction - as we’ll see later, JAX doesn’t build an explicit computational graph either, and I definitely want to include JAX as a “tensor computation library”… ¯\_(ツ)_/¯
  • Numba and Cython
    • These libraries provide best execution for code (and in fact some tensor computation libraries, such as Theano, make good use of them), but like NumPy and SciPy, they do not actually manage the computational graph itself.
  • Keras, Trax, Flax and PyTorch-Lightning
    • These libraries are high-level wrappers around tensor computation libraries - they basically provide abstractions and a user-facing API to utilize tensor computation libraries in a friendlier way.

(Some) Differences Between Tensor Computation Libraries

Anyways, back to tensor computation libraries.

All three aforementioned goals are ambitious undertakings with sophisticated solutions, so it shouldn’t be surprising to learn that decisions in pursuit of one goal can have implications for (or even incur a trade-off with!) other goals. Here’s a list of common differences along all three axes:

  1. Tensor computation libraries can differ in how they represent the computational graph, and how it is built.
    • Static or dynamic graphs: do we first define the graph completely and then inject data to run (a.k.a. define-and-run), or is the graph defined on-the-fly via the actual forward computation (a.k.a. define-by-run)?
      • TensorFlow 1.x was (in)famous for its static graphs, which made users feel like they were “working with their computational graph through a keyhole”, especially when compared to PyTorch’s dynamic graphs.
    • Lazy or eager execution: do we evaluate variables as soon as they are defined, or only when a dependent variable is evaluated? Usually these two choices go hand in hand (dynamic graphs with eager execution, static graphs with lazy execution), though some libraries offer both pairings - for example, TensorFlow 2.0 supports both modes.
    • Interestingly, some tensor computation libraries (e.g. Thinc) don’t even construct an explicit computational graph: they represent it as chained higher-order functions.
  2. Tensor computation libraries can also differ in what they want to use the computational graph for - for example, are we aiming to do things that basically amount to running the computational graph in a “different mode”, or are we aiming to modify the computational graph itself?
  3. Finally, tensor computation libraries can also differ in how they execute code.
    • All tensor computation libraries run on CPU, but the strength of GPU and TPU support is a major differentiator among tensor computation libraries.
    • Another differentiator is how tensor computation libraries compile code to be executed on hardware. For example, do they use JIT compilation or not? Do they use “vanilla” C or CUDA compilers, or the XLA compiler for machine-learning-specific code?
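The three ways of representing a computation mentioned under point 1 can be contrasted in a few lines of plain Python. This is a toy sketch: the class and function names (Node, Placeholder, Op, scale, square, chain) are hypothetical and do not correspond to any library's API. The first part builds a static graph that is evaluated lazily once data is fed in (define-and-run), the second simply computes eagerly (define-by-run), and the third represents the model as chained higher-order functions, loosely in the spirit of Thinc, where each layer returns its output together with a callback for the backward pass.

```python
# 1. Define-and-run: a static graph, evaluated lazily once data is injected.
class Node:
    """Base class for graph nodes; arithmetic on nodes builds new nodes."""
    def __add__(self, other):
        return Op("add", lambda a, b: a + b, self, other)

    def __mul__(self, other):
        return Op("mul", lambda a, b: a * b, self, other)


class Placeholder(Node):
    """A graph input whose value is only supplied at run time."""
    def __init__(self, name):
        self.name = name

    def run(self, feed):
        return feed[self.name]


class Op(Node):
    """An operation node: stores its inputs and computes nothing until run()."""
    def __init__(self, name, fn, *inputs):
        self.name, self.fn, self.inputs = name, fn, inputs

    def run(self, feed):
        return self.fn(*(node.run(feed) for node in self.inputs))


x, y = Placeholder("x"), Placeholder("y")
z_graph = x * y + x                       # builds the graph; no arithmetic yet
print(z_graph.name)                       # add
print(z_graph.run({"x": 2.0, "y": 3.0}))  # 8.0

# 2. Define-by-run: eager execution, the "graph" is just the code that ran.
x_val, y_val = 2.0, 3.0
z_eager = x_val * y_val + x_val           # already the number 8.0


# 3. Chained higher-order functions: no graph object at all. Each layer maps
# an input to (output, backprop), where backprop maps d(output) to d(input).
def scale(factor):
    def forward(inp):
        return factor * inp, lambda d_out: factor * d_out
    return forward


def square():
    def forward(inp):
        return inp * inp, lambda d_out: 2 * inp * d_out
    return forward


def chain(*layers):
    def forward(inp):
        callbacks = []
        for layer in layers:
            inp, backprop = layer(inp)
            callbacks.append(backprop)

        def backprop_all(d_out):
            # Apply the layers' callbacks in reverse order (chain rule).
            for backprop in reversed(callbacks):
                d_out = backprop(d_out)
            return d_out

        return inp, backprop_all
    return forward


model = chain(scale(3.0), square())  # computes (3x)^2
out, backprop = model(2.0)
print(out, backprop(1.0))            # 36.0 36.0
```

Only the first style leaves you with an object (z_graph) that can be inspected or rewritten before any data arrives; in the other two, the structure of the computation exists only implicitly in the Python code and closures.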

A Zoo of Tensor Computation Libraries

Having outlined the basic similarities and differences of tensor computation libraries, I think it’ll be helpful to go through several of the popular libraries as examples. I’ve tried to link to the relevant documentation where possible.¹

PyTorch

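  1. How is the computational graph represented and built?
    • PyTorch dynamically builds (and eagerly evaluates) an explicit computational graph: the graph is recorded as a side effect of running the forward computation.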
  2. What is the computational graph used for?
    • To quote the PyTorch docs, “PyTorch is an optimized tensor library for deep learning using GPUs and CPUs” - as such, the main focus is on autodifferentiation.
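  3. How does the library ensure “best execution” for computation?
    • By default, PyTorch executes each operation eagerly as it is called, dispatching it to precompiled kernels for the CPU or (via CUDA) the GPU. It also offers a JIT compiler, TorchScript, for compiling models before running them.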
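The define-by-run approach can be illustrated with a toy scalar class that records its own graph while computing values, and a backward() method that walks that graph in reverse to accumulate gradients. This is a hypothetical sketch of reverse-mode autodifferentiation, not torch.Tensor or PyTorch's actual autograd machinery.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical toy class)."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # nodes this value was computed from
        self._local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # Topologically sort the recorded graph: each node is appended only
        # after every node it depends on.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        # Walk the graph from the output back to the inputs (chain rule).
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += local * node.grad


x = Value(2.0)
y = Value(3.0)
z = x * y + x      # the graph is recorded as a side effect of computing z
z.backward()
print(z.data, x.grad, y.grad)  # 8.0 4.0 2.0
```

The graph here is never declared up front: it only exists because the forward pass happened to run, which is why a new graph gets recorded on every call.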

JAX

  1. How is the computational graph represented and built?
    • Instead of building an explicit computational graph to compute gradients, JAX simply supplies a grad() function that returns the gradient function of any supplied function. As such, there is technically no concept of a computational graph - only pure (i.e. stateless and side-effect-free) functions and their gradients.
    • Sabrina Mielke summarizes the situation very well:

      PyTorch builds up a graph as you compute the forward pass, and one call to backward() on some “result” node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX on the other hand makes you express your computation as a Python function, and by transforming it with grad() gives you a gradient function that you can evaluate like your computation function — but instead of the output it gives you the gradient of the output with respect to (by default) the first parameter that your function took as input.

  2. What is the computational graph used for?
    • According to the JAX quickstart, JAX bills itself as “NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research”. Hence, its focus is heavily on autodifferentiation.
  3. How does the library ensure “best execution” for computation?
    • This is best explained by quoting the JAX quickstart:

      JAX uses XLA to compile and run your NumPy code on […] GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels […] Compilation and automatic differentiation can be composed arbitrarily […]

    • For more detail on JAX’s four-function API (grad, jit, vmap and pmap), see Alex Minnaar’s overview of how JAX works.
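The "gradient as a function transformation" idea from the quote above can be sketched in plain Python. To keep the example short and dependency-free, this toy grad() uses forward-mode differentiation with dual numbers, which is a different technique from the reverse-mode differentiation JAX's grad performs. The Dual class and grad() here are hypothetical illustrations, not JAX's API. Like JAX's default, the toy grad() differentiates with respect to the first argument only.

```python
class Dual:
    """A number value + deriv * eps with eps**2 == 0; deriv carries the derivative."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def lift(v):
        # Treat plain numbers as constants (derivative zero).
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (a + a'e)(b + b'e) = ab + (a'b + ab')e
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def grad(f):
    """Transform f into a new function computing df/d(first argument)."""
    def grad_f(x, *args):
        return f(Dual(x, 1.0), *args).deriv
    return grad_f


def f(x, y):
    return x * x * y + 3 * x  # a pure function: no state, no side effects


df_dx = grad(f)           # a function you can call just like f
print(f(2.0, 5.0))        # 26.0
print(df_dx(2.0, 5.0))    # 23.0, since d/dx (x^2 y + 3x) = 2xy + 3
```

The point is the shape of the interface rather than the differentiation technique: grad takes a function and returns a function, and no graph object is ever handed to the user.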

Theano

Note: the original Theano (maintained by MILA) has been discontinued, and the PyMC developers have forked the project: Theano-PyMC (soon to be renamed Aesara). I’ll discuss both the original and forked projects below.

  1. How is the computational graph represented and built?
    • Theano statically builds (and lazily evaluates) an explicit computational graph.
  2. What is the computational graph used for?
    • Theano is unique among tensor computation libraries in that it places more emphasis on reasoning about the computational graph itself. In other words, while Theano has strong support for autodifferentiation, running the computation and computing gradients isn’t the be-all and end-all: Theano has an entire module for optimizing the computational graph itself, and makes it fairly straightforward to compile the Theano graph to different computational backends (by default, Theano compiles to C or CUDA, but it can also target JAX).
    • Theano is often remembered as a library for deep learning research, but it’s so much more than that!
  3. How does the library ensure “best execution” for computation?
    • The original Theano used the GCC C compiler for CPU computation, and the NVCC CUDA compiler for GPU computation.
    • The Theano-PyMC fork project will use JAX as a backend, which can utilize CPUs, GPUs and TPUs as available.
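Because Theano keeps an explicit graph, the same graph can be handed to different backends. The following toy sketch shows that idea with one hypothetical graph and two hypothetical "backends": one turns the graph into nested Python closures, the other emits C source text (which a real backend would then pass to a C compiler). Theano's actual C and JAX backends are far more involved; none of these function names come from Theano.

```python
# A static graph for x * y + x: variables are strings, ops are (op, lhs, rhs).
GRAPH = ("add", ("mul", "x", "y"), "x")

PY_OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
C_OPS = {"add": "+", "mul": "*"}


def compile_python(node):
    """Backend 1: turn the graph into a Python callable taking an env dict."""
    if isinstance(node, str):
        return lambda env: env[node]
    op, lhs, rhs = node
    left, right, fn = compile_python(lhs), compile_python(rhs), PY_OPS[op]
    return lambda env: fn(left(env), right(env))


def to_c_expr(node):
    """Render a graph node as a fully parenthesized C expression."""
    if isinstance(node, str):
        return node
    op, lhs, rhs = node
    return f"({to_c_expr(lhs)} {C_OPS[op]} {to_c_expr(rhs)})"


def compile_c_source(graph, name, args):
    """Backend 2: emit C source for a function evaluating the graph."""
    params = ", ".join(f"double {a}" for a in args)
    return f"double {name}({params}) {{ return {to_c_expr(graph)}; }}"


run = compile_python(GRAPH)
print(run({"x": 2.0, "y": 3.0}))  # 8.0
print(compile_c_source(GRAPH, "f", ["x", "y"]))
# double f(double x, double y) { return ((x * y) + x); }
```

Neither backend needs to know how the graph was built; swapping backends only means writing a new translation of the same graph.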

An Observation on Static Graphs and Theano

Finally, a quick observation on static graphs and the niche that Theano fills that other tensor computation libraries do not. I had huge help from Thomas Wiecki and Brandon Willard with this section.

There’s been a consistent movement in most tensor computation libraries away from static graphs (or more precisely, statically built graphs): PyTorch and TensorFlow 2 both support dynamically generated graphs by default, and JAX forgoes an explicit computational graph entirely.

This movement is understandable - building the computational graph dynamically matches people’s programming intuition much better. When I write z = x + y, I don’t mean “I want to register a sum operation with two inputs, which is waiting for data to be injected” - I mean “I want to compute the sum of x and y”. The extra layer of indirection is not helpful to most users, who just want to run their tensor computation at some reasonable speed.

So let me speak in defence of statically built graphs.

Having an explicit representation of the computational graph is immensely useful for certain things, even if it makes the graph harder to work with. You can modify the graph (e.g. graph optimizations, simplifications and rewriting), and you can reason about and analyze the graph. Having the computation as an actual object helps immeasurably for tasks where you need to think about the computation itself, instead of just blindly running it.
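As a small illustration of both analyzing and modifying an explicit graph, the following sketch (using a hypothetical tuple-based representation, not any library's format) first counts the operations in a graph and then performs common subexpression elimination, flattening the graph into instructions so that a repeated subexpression such as exp(x) is computed only once.

```python
from collections import Counter


def count_ops(node):
    """Analysis: tally each operation in the graph, duplicates included."""
    counts = Counter()

    def walk(n):
        if isinstance(n, tuple):
            counts[n[0]] += 1
            for arg in n[1:]:
                walk(arg)

    walk(node)
    return counts


def to_instructions(node):
    """Modification: flatten the graph into named instructions, emitting each
    distinct subexpression only once (common subexpression elimination)."""
    instructions, seen = [], {}

    def emit(n):
        if not isinstance(n, tuple):
            return n  # a variable name
        if n in seen:
            return seen[n]  # reuse the earlier result instead of recomputing
        op, *args = n
        arg_names = [emit(a) for a in args]
        name = f"t{len(instructions)}"
        instructions.append((name, op, *arg_names))
        seen[n] = name
        return name

    emit(node)
    return instructions


expr = ("add", ("exp", "x"), ("mul", ("exp", "x"), "y"))  # exp(x) + exp(x) * y
print(dict(count_ops(expr)))  # {'add': 1, 'exp': 2, 'mul': 1}
for instr in to_instructions(expr):
    print(instr)
# ('t0', 'exp', 'x')
# ('t1', 'mul', 't0', 'y')
# ('t2', 'add', 't0', 't1')
```

Both functions only make sense because the whole computation is available as a data structure before it runs.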

On the other hand, with dynamically generated graphs, the computational graph is never actually defined anywhere: the computation is traced out on the fly and behind the scenes. You can no longer do anything interesting with the computational graph: for example, if the computation is slow, you can’t reason about what parts of the graph are slow. The end result is that you basically have to hope that the framework internals are doing the right things, which they might not!

This is the niche that Theano (or rather, Theano-PyMC/Aesara) fills that other contemporary tensor computation libraries do not: the promise is that if you take the time to specify your computation up front and all at once, Theano can optimize the living daylights out of your computation - whether by graph manipulation, efficient compilation or something else entirely - and that this is something you would only need to do once.


  1. Some readers will notice the conspicuous lack of TensorFlow from this list - its exclusion isn’t out of malice, merely a lack of the time needed to research it well enough to do it justice. Sorry.