Chapter 6: Automatic Differentiation and Accelerated Linear Algebra for Machine Learning
With the recent explosion of data and data-generating systems, machine learning has grown into an exciting field, both in research and in industry. However, implementing a machine learning model can prove a difficult endeavor. Specifically, common tasks in machine learning, such as computing the derivative of a loss function, running gradient descent to find the optimal combination of model parameters, or applying the kernel method to nonlinear data, demand clever implementations if predictive models are to be efficient.
In this chapter, we will discuss the JAX library, a premier high-performance machine learning tool in Python. We will explore some of its most powerful features, such as automatic differentiation, JIT compilation, and automatic vectorization. These features streamline the machine learning tasks mentioned previously, making training a predictive model as straightforward as possible.
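As a small taste of what is to come, the following is a minimal sketch of these three features in action. The `loss` function, the linear model it scores, and the array shapes are illustrative assumptions, not code from later chapters; only the `jax.grad`, `jax.jit`, and `jax.vmap` transformations themselves are part of the JAX API.

```python
import jax
import jax.numpy as jnp

# Illustrative squared-error loss for a linear model (assumed example, not from the book).
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation: jax.grad builds the gradient function for us.
grad_loss = jax.grad(loss)

# JIT compilation: jax.jit compiles the gradient computation with XLA.
fast_grad = jax.jit(grad_loss)

# Automatic vectorization: jax.vmap maps a per-example loss over a batch axis.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))

key = jax.random.PRNGKey(0)
w = jnp.zeros(3)                       # model parameters
x = jax.random.normal(key, (10, 3))    # 10 examples with 3 features each
y = jnp.ones(10)                       # targets

print(fast_grad(w, x, y))              # gradient of the loss with respect to w
print(per_example_loss(w, x, y))       # one loss value per example
```

Each of these transformations is covered in depth later in the chapter; the point here is simply that differentiating, compiling, and batching a function are each one line of code in JAX.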