Just-In-Time compilation for improved efficiency
As we learned in the last chapter, JIT compilation allows a piece of code that is expected to run many times to be executed more efficiently. This is particularly useful in machine learning, where functions such as a model's loss, or the gradient of that loss, must be evaluated many times during the loss minimization phase. We therefore expect that, by leveraging a JIT compiler, we can make our machine learning models train faster.
You might think that to do this, we would need to hook one of the JIT compilers we considered in the last chapter into JAX. However, JAX comes with its own JIT compiler, which requires minimal code to integrate into an existing program. We will see how to use it by modifying the training loop we built in the last section.
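Before modifying the training loop, here is a minimal sketch of how JAX's JIT compiler is typically applied. The loss function below is a hypothetical example (a mean squared error for a linear model) chosen purely for illustration; the key point is that wrapping a function with jax.jit compiles it on its first call, and later calls with arguments of the same shapes and dtypes reuse the compiled version:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss for illustration: mean squared error of a linear
# model with parameters w on inputs X and targets y.
def loss(w, X, y):
    preds = X @ w
    return jnp.mean((preds - y) ** 2)

# jax.jit compiles `loss` the first time it is called; subsequent calls
# with same-shaped arguments skip tracing and run the compiled code.
loss_jit = jax.jit(loss)

X = jnp.ones((4, 3))
y = jnp.zeros(4)
w = jnp.ones(3)

# Both versions compute the same value; the jitted one is faster when
# called repeatedly, as in a training loop.
print(float(loss(w, X, y)), float(loss_jit(w, X, y)))
```

Note that the compiled and uncompiled versions return the same result; the benefit of jax.jit appears only when the function is called many times, which is exactly the situation in a training loop.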
First, we reset the parameters of our models:
np.random.seed(0)
w = np.random.randn(3)
Now, the way we will integrate the JIT compiler into our program is to point...