Getting JAX up and running
As briefly mentioned, JAX is a combination of different tools for developing accelerated, high-performance computations with a focus on machine learning applications. Recall from the last chapter that the NumPy library offers optimized implementations of numerical operations such as finding the min/max or taking the sum or the average along an axis. We can think of JAX as the NumPy equivalent for machine learning, in which common machine learning tasks can be carried out in highly optimized code. These, as we will see, include automatic differentiation, accelerated linear algebra using a Just-In-Time (JIT) compiler, and efficient vectorization and parallelization of code, among other things.
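To make the NumPy comparison concrete, here is a minimal sketch that assumes JAX is installed and uses its NumPy-like module, jax.numpy, to compute the same kinds of reductions along an axis. The array x and the result names are made up for illustration.

```python
import jax.numpy as jnp

# A small 2 x 3 array; the values are arbitrary and chosen for illustration.
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# Reductions along an axis, written the same way as in NumPy.
col_min = jnp.min(x, axis=0)    # minimum of each column
col_max = jnp.max(x, axis=0)    # maximum of each column
row_sum = jnp.sum(x, axis=1)    # sum of each row
row_mean = jnp.mean(x, axis=1)  # average of each row

print(col_min, col_max, row_sum, row_mean)
```

The calls mirror their NumPy counterparts, which is why much NumPy-style code carries over to JAX with few changes.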
JAX offers these functionalities through what are known as functional transformations. In the simplest sense, a functional transformation in JAX takes a function, typically one that we write ourselves, and converts it to an optimized version where different functionalities are facilitated...
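As a rough illustration of this idea, the sketch below applies three of JAX's transformations, jax.grad, jax.jit, and jax.vmap, to a small function of our own. The function sum_of_squares and the input values are hypothetical and serve only as an example. Each transformation takes the function as input and returns a new function.

```python
import jax
import jax.numpy as jnp

# A function we write ourselves (hypothetical example): f(w) = sum of w_i^2.
def sum_of_squares(w):
    return jnp.sum(w ** 2)

# Each transformation takes a function and returns a new function.
grad_fn = jax.grad(sum_of_squares)     # computes the gradient, here 2 * w
fast_fn = jax.jit(sum_of_squares)      # Just-In-Time compiled version
batched_fn = jax.vmap(sum_of_squares)  # applies the function to each row of a batch

w = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([w, 2.0 * w])        # two inputs stacked into a 2 x 3 batch

g = grad_fn(w)         # gradient with respect to w
v = fast_fn(w)         # same value as sum_of_squares(w), computed by compiled code
b = batched_fn(batch)  # one output per row of the batch

print(g, v, b)
```

Because each transformation returns an ordinary function, they can also be combined, for example jax.jit(jax.grad(sum_of_squares)).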