From TensorFlow and PyTorch to Jax and Julia: Engineering Trade-Offs in Automatic Differentiation
To understand the differences between automatic differentiation (AD) libraries, it helps to look at the engineering trade-offs their designers made. I would not say any of these libraries is better than the others; each makes engineering decisions based on the domains and use cases it was meant to serve. Following the evolution of those trade-offs shows how each new library shifted them.
Early TensorFlow used a graph-building system: users had to express their computation in a separate graph language embedded in the host language, setting up TensorFlow variables and operations ahead of time. AD was then performed on this pre-built graph. Only a limited set of control flow constructs was supported, because control flow had to be representable statically. An 'ifelse' function is semantically quite different from an "if" then "else" block of code: 'ifelse' is equivalent to always evaluating both branches and then selecting one result, so there is only a single code path (semantically, at least; later compiler optimizations may, and usually do, prune the unused branch). This static sublanguage was then lowered to an intermediate representation (IR) called XLA, which invested heavily in linear algebra optimizations. Because there was no real control flow at this level, AD could be done with simple graph-walking algorithms. XLA can be extremely efficient because it sees the entire computation at once, but that efficiency came with real downsides in flexibility and convenience (a minimal graph-mode sketch follows this paragraph).
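As a concrete illustration, here is a minimal sketch of the graph-building style, written against the TensorFlow 1.x graph API (exposed as 'tensorflow.compat.v1' in newer releases); the function and values are made up for illustration:

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style graph mode
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=())
# tf.cond is an 'ifelse'-style graph node: both branches are built into the
# graph up front, and the result is selected when the graph is executed
y = tf.cond(x > 0, lambda: x * x, lambda: -x)
dy = tf.gradients(y, x)[0]  # AD runs on the already-constructed graph

with tf.Session() as sess:
    print(sess.run(dy, feed_dict={x: 2.0}))  # 4.0
```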
Thus, you can think of this as a form of source code transformation, since all of the autodiff is done on an IR for a language that is not the host language. The catch is that the user had to do the translation into that new language for the AD system themselves, which was… a little inconvenient.
When PyTorch came along, it solved the flexibility and convenience issues by using a tape instead. You run the forward pass, and the code for the autodiff is generated on the fly: PyTorch records the operations it sees during each forward pass and then walks that recording in reverse for backpropagation. The tape is built through operator overloading on the Tensor type that PyTorch requires you to use. The behavior is easy to describe: calling f(2.0) might take the first branch of an if statement and run a while loop 5 times, and the AD pass then takes that recorded list of operations and backpropagates through the five loop iterations and back through the first branch (see the sketch after this paragraph). Notice that in this form the AD never sees dynamic control flow: all of it is handled by Python, and none of it ends up on the tape, so the AD itself never has to deal with it. That makes it easy to handle a lot of strange language constructs. The trade-off is that the AD is per value: since there is no guarantee the same backwards pass will ever be seen again, there is very little opportunity to optimize it.
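Here is a minimal sketch of that tape behavior (the function is made up for illustration): the branch and the loop run as ordinary Python, and only the tensor operations that actually execute are recorded.

```python
import torch

def f(x):
    # the branch and the loop run in plain Python; only the tensor
    # operations that actually execute are recorded on the tape
    if x > 0:
        y = x
    else:
        y = -x
    i = 0
    while i < 5:
        y = y * x
        i += 1
    return y

x = torch.tensor(2.0, requires_grad=True)
f(x).backward()   # backprop walks the recorded operations in reverse
print(x.grad)     # derivative of x**6 at x = 2 is 6 * 2**5 = 192
```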
Doesn't that lack of optimization ruin PyTorch's efficiency? Yes and no. Most machine learning algorithms are so dominated by expensive kernels, such as matrix multiplication ('A*x'), 'conv', and so on, that the cost per operation is high enough to hide the overhead of this approach. A great deal of effort can then go into optimizing the roughly 2,000 operators that PyTorch provides, and indeed most people in ML perceive PyTorch as fast because its kernels (fast 'conv' calls, for example) are fast, even though the AD itself carries a lot of overhead. That said, there are situations where the AD and Python interpreter overhead is not washed out: when arrays are small or when many scalar operations are involved, as in the Julia vs. PyTorch Neural ODE benchmarks on scientific model discovery workflows, where Julia shows roughly a 100x speedup even before the AD is taken into account, i.e. in the ODE and SDE solvers themselves. That gap is mostly language and AD overhead, because the kernels in those workloads are small. This is why the PyTorch team has been developing things like 'torch.jit', a separate sublanguage that can be compiled and optimized in ways better suited to these kinds of problems (a minimal sketch follows this paragraph). Whether that is a good long-term strategy has been debated, but it does not change the bottom line: PyTorch has done well because it made good choices for its use case.
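For concreteness, here is a minimal sketch of moving a small function into the compiled sublanguage with 'torch.jit.script'; the function is made up for illustration:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    # TorchScript compiles this restricted subset of Python ahead of time,
    # removing per-operation Python interpreter overhead from the loop
    y = x
    for i in range(5):
        y = y * x
    return y

print(f(torch.tensor(2.0)))  # 64.0
```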
So then TensorFlow Eager (2.0) came along and added dynamic control flow support in a way that looks a lot like PyTorch, in a somewhat desperate attempt to win everyone back. Unfortunately, it doesn't play well with all of the XLA tooling, because XLA can no longer see the whole graph of all possible operations for all input values and optimize it accordingly, so it never hit the TensorFlow speeds everyone was expecting.
The next generation of tools tried to extend these ideas or combine the best parts of both approaches. Enter Jax. Jax uses nonstandard interpretation: it interprets your code to build a copy of it in its own IR, performs AD on that IR, and then lowers it to TensorFlow's XLA for optimization. The interpretation is done through a form of operator overloading: special tracer objects walk through your code and build out the expressions they encounter (the "tracing" step). But what about dynamic control flow? Like PyTorch, a trace only sees the path actually taken through the Python code, not every possible path, so how can the full program be captured? Jax's answer is to disallow truly dynamic behavior at trace time: rather than Python control flow, you use Jax primitives such as lax.cond and lax.while_loop, which are ordinary function calls the tracer can capture (a sketch follows this paragraph). The result is essentially a more natural graph builder for TensorFlow: because everything still ends up in XLA, you get the same efficiency there, but the code looks and feels much more natural. "Jax: The Sharp Bits" documents these restrictions and the primitives that exist to work around them. The reason this works is that "most" ML algorithms are not very dynamic: a recurrent neural network knows how many layers it has and does not iterate until some tolerance is reached, so "most" algorithms fit comfortably inside this sublanguage, and Jax can naturally make a lot of code run fast.
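Here is a minimal sketch of those primitives in use (the function is made up for illustration; lax.fori_loop with static bounds is used because plain lax.while_loop does not support reverse-mode differentiation):

```python
import jax
from jax import lax

def f(x):
    # lax.cond replaces a Python `if`: both branches are captured in the trace
    y = lax.cond(x > 0, lambda v: v, lambda v: -v, x)
    # lax.fori_loop replaces the Python loop; the trip count here is static
    y = lax.fori_loop(0, 5, lambda i, acc: acc * x, y)
    return y

print(jax.grad(f)(2.0))  # derivative of x**6 at x = 2 is 192
```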
What if we kept the AD dynamic?
Is it possible to keep the full dynamism of the host language inside the AD system? It is possible, but it is hard. This is what many of the Julia AD tools, along with Swift for TensorFlow, have attempted via source-to-source transformations. The catch is that source code is written for humans, so these tools work not on the surface syntax but on lowered representations, compiler IRs that strip away much of the "cruft" of syntax and present a much smaller surface area to support. The heart of Zygote.jl's approach was the observation that by acting on the SSA IR, it could directly support control flow such as while loops without unrolling them into a list of operations (as PyTorch and TensorFlow Eager do) and without restricting users to a sublanguage of control flow (as Jax does). The source transformation turns while loops and other dynamic constructs into static (source-level) generated code, augmented with extra statements such as stacks that record information from the forward pass (for example, which branch was taken and how many iterations ran). The generated backwards pass then reads from those stacks; a sketch follows this paragraph. The generated code itself is fixed, but its behavior is still dynamic (the stack tells the reverse pass how many times to walk back through the loop). Because a single generated adjoint covers every branch and iteration count, rather than one recorded path, this style of AD regains TensorFlow-like reuse in a world where dynamic control flow has not been erased.
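Here is a hand-written sketch of the idea (plain Python, not actual Zygote output): the forward pass pushes per-iteration values onto a stack, and the generated reverse pass pops them back off.

```python
def f_forward(x):
    y, stack = x, []
    while y < 100.0:        # dynamic trip count, decided by the data
        stack.append(y)     # record what the reverse pass will need
        y = y * y
    return y, stack

def f_reverse(dy, stack):
    # the stack length tells the reverse pass how many iterations to undo;
    # the reverse of the loop body y -> y*y contributes a factor of 2*y
    while stack:
        y = stack.pop()
        dy = 2.0 * y * dy
    return dy

y, stack = f_forward(1.5)
print(f_reverse(1.0, stack))  # dy/dx of the whole loop, evaluated at x = 1.5
```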
If this sounds like the best of both worlds, why isn't everyone doing it? There are two parts to the answer. First, accepting that your AD must handle the full dynamic semantics of an entire programming language makes the job much harder than it would otherwise be. The whole point of the AD approaches in TensorFlow, PyTorch, and Jax is to eliminate these difficulties before the AD ever runs, leaving a much smaller surface area of language features to support. Python is especially brutal in how much dynamism it allows, which is why the people pursuing this direction reached for languages that are friendlier to compilers, such as Julia and Swift. But most of the ML community lives in Python, and that raises the barrier to adoption.
Even then, there is still a lot of work to do. In Julia, it was found that Zygote acts on too high-level an IR, i.e. before compiler optimizations have run, which means AD is performed on unoptimized code only for much of that work to be deleted later; it would be better to go even lower. That realization gave rise to the Diffractor.jl project. And because some optimizations only happen at the LLVM level, Julia developers also started building an AD system that works directly on LLVM IR, called Enzyme. That project was led by members of the Julia Lab, but since it operates at the LLVM level it can be used with any LLVM-targeting language, such as C/C++ (via Clang) or Rust. There is a trade-off in source-to-source methods as you go lower and lower in the IRs, which I will discuss in a separate post, but in short: Enzyme can act after compiler optimizations, which means much of the higher-level information may already have been stripped away (at least until dialects like MLIR are ready). It may not be able to perform all of the linear algebra simplifications that XLA does, for example fusing many matrix-vector multiplications into a single matrix-matrix multiplication (illustrated in the sketch after this paragraph), because the relevant function calls may already have been inlined and erased, and history has shown that optimizing the remaining loopy code back up to BLAS speed is hard, though not impossible. Similarly, calls to a nonlinear solver may already have been lowered away, so optimized adjoints that outperform direct differentiation of the code, such as those used for Deep Equilibrium Models (DEQs), may be harder to apply. On the other hand, the lowest level is extremely good at differentiating scalar code and handling mutation. Diffractor, by working on Julia's typed IR, can apply higher-level rules quickly and consistently (for example keeping BLAS calls intact and fusing them) and could in theory perform XLA-style transformations. But writing those kinds of analyses on a fully dynamic compute IR is difficult enough that it has not been done yet. Many tools are being built to help with this, but the fact remains that it is more work on a full language IR than on a sublanguage graph like XLA, because it requires things like escape analysis and shape propagation. A compiler could in principle prove that a function is semi-static in the XLA sense and recover the same optimizations as Jax or TensorFlow, but that does not happen today, and it is not easy to do. Julia's AD systems are likely to use a mix of the Enzyme and Diffractor approaches going forward; the trade-off is that generality comes at the cost of complexity.
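To make the fusion point concrete, here is a small NumPy sketch (not XLA output) of the rewrite being described: a batch of matrix-vector products collapsed into a single matrix-matrix product.

```python
import numpy as np

A = np.random.rand(64, 64)
xs = [np.random.rand(64) for _ in range(32)]

# what inlined, lowered code looks like: many separate matrix-vector calls
ys_loop = [A @ x for x in xs]

# the higher-level rewrite: one matrix-matrix multiplication on stacked inputs
ys_fused = A @ np.stack(xs, axis=1)

assert np.allclose(np.stack(ys_loop, axis=1), ys_fused)
```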
The second part of the answer is that most ML code does not actually use much dynamism. Recurrent neural networks, transformers, convolutions, and so on use only simple, highly structured forms of it. That is the trade-off most people overlook: why support generality your users do not need? The number of layers does not depend on the values coming out of those layers, so full dynamism in ML workflows is mostly supported for convenience, not out of necessity. Genuinely dynamic algorithms can usually be handled another way: write the operation as a function and define the adjoint derivative for that function directly. For example, adaptive ODE solvers must inspect computed values to decide how many steps to take, yet Jax still supports ODEs: you cannot differentiate through an ODE solver's internal code with Jax, but you can call an ODE solver that ships with a defined adjoint (a sketch follows this paragraph). This does mean some algorithms are not supported, and certain performance/stability trade-offs become harder to make because the adjoint definition is separate from the solver definition, but those issues are rare enough in mainstream ML use cases that few people have run into them.
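As a sketch of that pattern, here is a gradient taken through jax.experimental.ode.odeint, which pairs an adaptive solver with a predefined continuous adjoint (the ODE and values here are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t, theta):
    return -theta * y          # dy/dt = -theta * y

def loss(theta):
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = odeint(rhs, jnp.array([1.0]), ts, theta)  # adaptive solve
    return ys[-1, 0]           # y(1) = exp(-theta)

# the gradient flows through the solver's registered adjoint,
# not through its internal stepping code
print(jax.grad(loss)(2.0))     # approximately -exp(-2)
```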
So that is where we are today. Will the ML algorithms of the future need more dynamic structures, and will it become important for AD systems to optimize such code well? In scientific machine learning (SciML) I know the answer, because this story is the story of my field. Climate models mutate huge buffers in place because allocating new ones would be a major performance hit, and stiff equations simply cannot be handled with the straightforward adjoints available in PyTorch and Jax: the gradients just blow up to Inf. Whether this type of machine learning becomes widespread is still an open question, but hopefully this shows that none of the choices described here were "better" or "worse"; each was made in the context of the specific field it was meant to serve.