Automatic differentiation and calculus using JAX
JAX is a linear algebra and automatic differentiation framework developed by Google for machine learning. It combines the capabilities of Autograd with the Accelerated Linear Algebra (XLA) optimizing compiler for linear algebra and ML. In particular, it allows us to easily construct complex functions, with automatic gradient computation, that can be run on Graphics Processing Units (GPUs) or Tensor Processing Units (TPUs). On top of all this, it is relatively simple to use. In this recipe, we will see how to use the JAX just-in-time (JIT) compiler, compute the gradient of a function, and make use of different computation devices.
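To give a flavor of how this works before we start, here is a minimal sketch (using an illustrative function of our own choosing, not the code from this recipe) showing how jax.grad builds the derivative of a function and how jax.jit compiles it with XLA:

import jax
import jax.numpy as jnp

# An illustrative scalar function of one variable (hypothetical example)
def f(x):
    return jnp.sin(x) * jnp.exp(-x**2)

df = jax.grad(f)       # automatic differentiation: df(x) computes f'(x)
fast_df = jax.jit(df)  # just-in-time compile the gradient with XLA

print(df(0.5), fast_df(0.5))  # both evaluate the derivative at x = 0.5
print(jax.devices())          # lists the available computation devices

Calling jax.grad returns a new Python function that evaluates the derivative, and wrapping that in jax.jit compiles it so that repeated evaluations run much faster on whichever device JAX has available.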
Getting ready
For this recipe, we need the JAX package installed. We will make use of the Matplotlib package, with the pyplot interface imported as plt, as usual. Since we're going to plot a function of two variables, we also need to import the mplot3d module from the mpl_toolkits package, as in the sketch below.
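Assuming JAX and Matplotlib are already installed, a minimal set of imports for this setup might look as follows:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # enables 3D axes via projection="3d"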