What I Wish Someone Had Told Me About Tensor Computation Libraries
I get confused by tensor computation libraries (or computational graph libraries, or symbolic algebra libraries, or whatever they’re marketing themselves as these days).
I was first introduced to PyTorch and TensorFlow and, having no other reference, thought they were prototypical examples of tensor computation libraries. Then I learnt about Theano — an older and less popular project, but different from PyTorch and TensorFlow and better in some meaningful ways. This was followed by JAX, which seemed to be basically NumPy with more bells and whistles (although I couldn’t articulate what exactly they were). Then came the announcement by the PyMC developers that Theano would have a new JAX backend.
Anyways, this confusion prompted a lot of research and eventually, this blog post.
Similar to my previous post on the anatomy of probabilistic programming frameworks, I’ll first discuss tensor computation libraries in general — what they are and how they can differ from one another. Then I’ll discuss some libraries in detail, and finally offer an observation on the future of Theano in the context of contemporary tensor computation libraries.
Dissecting Tensor Computation Libraries
First, a characterization: what do tensor computation libraries even do?
- They provide ways of specifying and building computational graphs,
- They run the computation itself (duh), but also run “related” computations that either (a) use the computational graph, or (b) operate directly on the computational graph itself,
- The most salient example of the former is computing gradients via autodifferentiation,
- A good example of the latter is optimizing the computation itself: think symbolic simplifications (e.g. xy/x = y) or modifications for numerical stability (e.g. log(1 + x) for small values of x; see the sketch after this list),
- And they provide “best execution” for the computation: whether it’s changing the execution by JIT (just-in-time) compiling it, by utilizing special hardware (GPUs/TPUs), by vectorizing the computation, or in any other way.
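To make the second point concrete, here is a minimal sketch (assuming NumPy and PyTorch are installed; the functions and values are just for illustration) of the two kinds of “related” computation: using the graph to compute a gradient via autodifferentiation, and the sort of numerical-stability rewrite (log(1 + x) into log1p(x)) that a graph-optimizing library can apply automatically.

```python
import numpy as np
import torch

# (a) Using the computational graph: reverse-mode autodiff in PyTorch.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x           # PyTorch records these operations in a graph
y.backward()                 # walk the graph backwards to get dy/dx
print(x.grad)                # tensor(7.), since dy/dx = 2x + 3 at x = 2

# (b) Operating on the graph itself: an optimizing library can rewrite
# log(1 + x) into log1p(x), which is far more accurate for small x.
small = 1e-15
print(np.log(1 + small))     # ~1.11e-15: precision is lost in the sum 1 + small
print(np.log1p(small))       # ~1.00e-15: numerically stable
```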
“Tensor Computation Library” — Maybe Not The Best Name
As an aside: I realize that the name “tensor computation library” is too broad, and that the characterization above precludes some libraries that might also justifiably be called “tensor computation libraries”. Better names might be “graph computation library” (although that might get mixed up with libraries like networkx) or “computational graph management library” or even “symbolic tensor algebra library”.
So for the avoidance of doubt, here is a list of libraries that this blog post is not about:
- NumPy and SciPy
- These libraries don’t have a concept of a computational graph — they’re more like a toolbox of functions, called from Python and executed in C or Fortran.
- However, this might be a controversial distinction — as we’ll see later, JAX doesn’t build an explicit computational graph either, and I definitely want to include JAX as a “tensor computation library”… ¯\_(ツ)_/¯
- Numba and Cython
- These libraries provide best execution for code (and in fact some tensor computation libraries, such as Theano, make good use of them), but like NumPy and SciPy, they do not actually manage the computational graph itself.
- Keras, Trax, Flax and PyTorch-Lightning
- These libraries are high-level wrappers around tensor computation libraries — they basically provide abstractions and a user-facing API to utilize tensor computation libraries in a friendlier way.
(Some) Differences Between Tensor Computation Libraries
Anyways, back to tensor computation libraries.
All three aforementioned goals are ambitious undertakings with sophisticated solutions, so it shouldn’t be surprising to learn that decisions in pursuit of one goal can have implications for (or even incur a trade-off with!) other goals. Here’s a list of common differences along all three axes:
Tensor computation libraries can differ in how they represent the computational graph, and how it is built.
- Static or dynamic graphs: do we first define the graph completely and then inject data to run (a.k.a. define-and-run), or is the graph defined on-the-fly via the actual forward computation (a.k.a. define-by-run)? (See the sketch after this list.)
- TensorFlow 1.x was (in)famous for its static graphs, which made users feel like they were “working with their computational graph through a keyhole”, especially when compared to PyTorch’s dynamic graphs.
- Lazy or eager execution: do we evaluate variables as soon as they are defined, or only when a dependent variable is evaluated? Usually, tensor computation libraries either choose to support dynamic graphs with eager execution, or static graphs with lazy execution, although some libraries (TensorFlow 2.0, for example) support both modes.
- Interestingly, some tensor computation libraries (e.g. Thinc) don’t even construct an explicit computational graph: they represent it as chained higher-order functions.
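To illustrate the difference, here is a minimal sketch of the same addition written in both styles, assuming Theano (define-and-run) and PyTorch (define-by-run) are installed:

```python
# Define-and-run (static graph, lazy execution), Theano-style:
import theano
import theano.tensor as tt

a = tt.dscalar("a")
b = tt.dscalar("b")
c = a + b                       # builds a symbolic graph; nothing is computed yet
f = theano.function([a, b], c)  # the graph is optimized and compiled here
print(f(1.0, 2.0))              # data is injected only now -> 3.0

# Define-by-run (dynamic graph, eager execution), PyTorch-style:
import torch

x = torch.tensor(1.0)
y = torch.tensor(2.0)
z = x + y                       # evaluated immediately
print(z)                        # tensor(3.)
```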
Tensor computation libraries can also differ in what they want to use the computational graph for — for example, are we aiming to do things that basically amount to running the computational graph in a “different mode”, or are we aiming to modify the computational graph itself?
- Almost all tensor computation libraries support autodifferentiation in some capacity (either forward-mode, backward-mode, or both).
- Obviously, how you represent the computational graph and what you want to use it for are very related questions! For example, if you want to be able to represent arbitrary computation as a graph, you’ll have to handle control flow like if-else statements or for-loops — this leads to common gotchas with using Python for-loops in JAX, or needing to use torch.nn.ModuleList in for-loops with PyTorch (see the sketch after this list).
- Some tensor computation libraries (e.g. Theano and its fork, Theano-PyMC) aim to optimize the computational graph itself, for which an explicit graph is necessary.
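As a concrete illustration of the PyTorch gotcha, here is a minimal sketch (the layer count and sizes are arbitrary) of why torch.nn.ModuleList is needed when building layers in a for-loop:

```python
import torch
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self, n_layers: int, dim: int):
        super().__init__()
        # A plain Python list of layers would be invisible to PyTorch's
        # parameter tracking; nn.ModuleList registers each layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:     # ordinary Python control flow is fine here,
            x = torch.relu(layer(x))  # because the graph is traced at run time
        return x

model = Stack(n_layers=3, dim=8)
print(sum(p.numel() for p in model.parameters()))  # 216 = 3 * (8*8 + 8)
```

A plain Python list would silently hide those 216 parameters from the optimizer, which is exactly the kind of surprise that comes from mixing arbitrary Python control flow with graph construction.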
Finally, tensor computation libraries can also differ in how they execute code.
- All tensor computation libraries run on CPU, but the strength of GPU and TPU support is a major differentiator among tensor computation libraries.
- Another differentiator is how tensor computation libraries compile code to be executed on hardware. For example, do they use JIT compilation or not? Do they use “vanilla” C or CUDA compilers, or the XLA compiler for machine-learning specific code?
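As a small example of the JIT-compilation route, here is a minimal sketch (assuming JAX is installed; the function itself is just a stand-in) of just-in-time compiling a NumPy-style function into an XLA-optimized kernel:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

selu_jit = jax.jit(selu)                 # traced and compiled to an XLA kernel on first call
x = jnp.arange(1_000_000, dtype=jnp.float32)
selu_jit(x).block_until_ready()          # subsequent calls reuse the compiled kernel
```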
A Zoo of Tensor Computation Libraries
Having outlined the basic similarities and differences of tensor computation libraries, I think it’ll be helpful to go through several of the popular libraries as examples. I’ve tried to link to the relevant documentation where possible.¹
PyTorch
- How is the computational graph represented and built?
- PyTorch dynamically builds (and eagerly evaluates) an explicit computational graph. For more detail on how this is done, check out the PyTorch docs on autograd mechanics.
- For more on how PyTorch builds computational graphs, see jdhao’s introductory blog post on computational graphs in PyTorch.
- What is the computational graph used for?
- To quote the PyTorch docs, “PyTorch is an optimized tensor library for deep learning using GPUs and CPUs” — as such, the main focus is on autodifferentiation.
- How does the library ensure “best execution” for computation?
- PyTorch has native GPU support via CUDA.
- PyTorch also has support for TPU through projects like PyTorch/XLA and PyTorch-Lightning.
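To tie the three questions together, here is a minimal sketch (the layer and batch sizes are arbitrary) of PyTorch’s eagerly built graph, autodifferentiation and CUDA support:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(10, 1).to(device)  # parameters move to the GPU if one is available
x = torch.randn(32, 10, device=device)
loss = model(x).pow(2).mean()              # the graph is built eagerly as this line runs
loss.backward()                            # autodiff fills in .grad for every parameter
print(model.weight.grad.shape)             # torch.Size([1, 10])
```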
JAX
- How is the computational graph represented and built?
- Instead of building an explicit computational graph to compute gradients, JAX simply supplies a grad() function that returns the gradient function of any supplied function. As such, there is technically no concept of a computational graph — only pure (i.e. stateless and side-effect-free) functions and their gradients.
- Sabrina Mielke summarizes the situation very well:
PyTorch builds up a graph as you compute the forward pass, and one call to backward() on some “result” node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX on the other hand makes you express your computation as a Python function, and by transforming it with grad() gives you a gradient function that you can evaluate like your computation function — but instead of the output it gives you the gradient of the output with respect to (by default) the first parameter that your function took as input.
- What is the computational graph used for?
- According to the JAX quickstart, JAX bills itself as “NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research”. Hence, its focus is heavily on autodifferentiation.
- How does the library ensure “best execution” for computation?
- This is best explained by quoting the JAX quickstart:
JAX uses XLA to compile and run your NumPy code on […] GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels […] Compilation and automatic differentiation can be composed arbitrarily […]
- For more detail on JAX’s four-function API (grad, jit, vmap and pmap), see Alex Minaar’s overview of how JAX works.
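Here is a minimal sketch (assuming JAX is installed) of the first two of those functions, grad and vmap, in the purely functional style described above:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)        # a pure function of its inputs

grad_f = jax.grad(f)              # a gradient *function*, not a graph object
print(grad_f(jnp.array([1.0, 2.0, 3.0])))    # [2. 4. 6.], i.e. 2x

batched_grad = jax.vmap(grad_f)   # vectorize over a leading batch dimension
print(batched_grad(jnp.ones((4, 3))).shape)  # (4, 3)
```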
Theano
Note: the original Theano (maintained by MILA) has been discontinued, and the PyMC developers have forked the project: Theano-PyMC (soon to be renamed Aesara). I’ll discuss both the original and forked projects below.
- How is the computational graph represented and built?
- Theano statically builds (and lazily evaluates) an explicit computational graph.
- What is the computational graph used for?
- Theano is unique among tensor computation libraries in that it places more emphasis on reasoning about the computational graph itself. In other words, while Theano has strong support for autodifferentiation, running the computation and computing gradients isn’t the be-all and end-all: Theano has an entire module for optimizing the computational graph itself, and makes it fairly straightforward to compile the Theano graph to different computational backends (by default, Theano compiles to C or CUDA, but it’s straightforward to compile to JAX).
- Theano is often remembered as a library for deep learning research, but it’s so much more than that!
- How does the library ensure “best execution” for computation?
- The original Theano used the GCC C compiler for CPU computation, and the NVCC CUDA compiler for GPU computation.
- The Theano-PyMC fork will use JAX as a backend, which can utilize CPUs, GPUs and TPUs as available.
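As a minimal sketch of this emphasis on the graph itself (assuming the original Theano or the Theano-PyMC fork is installed under the theano name), note that both the gradient and the stability rewrites are things that happen to an explicit graph object before any data is supplied:

```python
import theano
import theano.tensor as tt

x = tt.dvector("x")
cost = tt.log(1 + x).sum()           # Theano's rewrites can stabilize this to log1p(x)
g = tt.grad(cost, x)                 # the gradient is itself a symbolic graph

# Graph optimization and compilation (to C by default) happen here, exactly once.
f = theano.function([x], [cost, g])
print(f([1e-15, 0.5]))               # [cost, gradient 1 / (1 + x)]
```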
An Observation on Static Graphs and Theano
Finally, a quick observation on static graphs and the niche that Theano fills that other tensor computation libraries do not. I had huge help from Thomas Wiecki and Brandon Willard with this section.
There’s been a consistent movement in most tensor computation libraries away from static graphs (or more precisely, statically built graphs): PyTorch and TensorFlow 2 both support dynamically generated graphs by default, and JAX forgoes an explicit computational graph entirely.
This movement is understandable — building the computational graph dynamically matches people’s programming intuition much better. When I write z = x + y, I don’t mean “I want to register a sum operation with two inputs, which is waiting for data to be injected” — I mean “I want to compute the sum of x and y”. The extra layer of indirection is not helpful to most users, who just want to run their tensor computation at some reasonable speed.
So let me speak in defence of statically built graphs.
Having an explicit representation of the computational graph is immensely useful for certain things, even if it makes the graph harder to work with. You can modify the graph (e.g. graph optimizations, simplifications and rewriting), and you can reason about and analyze the graph. Having the computation as an actual object helps immeasurably for tasks where you need to think about the computation itself, instead of just blindly running it.
On the other hand, with dynamically generated graphs, the computational graph is never actually defined anywhere: the computation is traced out on the fly and behind the scenes. You can no longer do anything interesting with the computational graph: for example, if the computation is slow, you can’t reason about which parts of the graph are slow. The end result is that you basically have to hope that the framework internals are doing the right things, which they might not!
This is the niche that Theano (or rather, Theano-PyMC/Aesara) fills that other contemporary tensor computation libraries do not: the promise is that if you take the time to specify your computation up front and all at once, Theano can optimize the living daylights out of your computation — whether by graph manipulation, efficient compilation or something else entirely — and that this is something you would only need to do once.
Some Follow-Ups, A Week Later
2020-12-22
The blog post trended on Hacker News and got some discussion. It’s stupefying how the most upvoted comments are either unrelated or self-promotional, but I suppose that’s to be expected with the Internet.
However, one nugget of gold in the junk pit is this comment by Albert Zeyer and the response by the PyMC developer spearheading the Aesara project, Brandon Willard. I had two takeaways from this exchange:
- Theano is messy, either in a code hygiene sense, or in an API design sense.
- For example, the graph optimization/rewriting process can require entire graphs to be copied at multiple points along the way. This obliterates performance and was almost entirely due to some design oddities.
- The JAX backend arose as a proof-of-concept of how extensible Theano is, both in terms of “hackability” and how much mileage we can get out of the design choices behind Theano (e.g. static graphs). The JAX backend isn’t the focus of the fork, but it’s easily the difference that will stand out most at the user level. The focus of the Aesara project is resolving the design shortcomings of Theano.
On the one hand, I’m glad that I finally understand the real focus of the Aesara fork — I feel like I have a much greater appreciation of what Aesara really is, and its place in the ecosystem of tensor computation libraries.
On the other hand, I’m discomfited by the implication that meaningful contributions to Aesara must involve deep expertise in computational graphs and graph optimizations - neither of which I have experience in (and I suspect such expertise is rare even in the open source community). Moreover, meaningful contributions to Aesara will probably require deep familiarity with Theano’s design and its shortcomings. This isn’t to discourage me (or anyone else!) from contributing to Aesara, but it’s good to acknowledge the bottomless pit of technical expertise that lies behind the user-facing Bayesian modelling.
¹ Some readers will notice the conspicuous lack of TensorFlow from this list - its exclusion isn’t out of malice, merely a lack of time and effort to do the necessary research to do it justice. Sorry.