A crash course in machine learning
To fully appreciate what JAX offers, let's first review the principal components of a typical workflow for training a machine learning model. If you are already familiar with the basics, feel free to skip to the next section, where we begin discussing JAX.
In machine learning, we set out to predict an unknown target value for a data point from its observable features. The goal is to design a predictive model that processes those features and outputs an estimate of the target value. For example, an image recognition model analyzes the pixel values of an image to predict which object the image depicts, while a model processing weather data could predict the probability of rain tomorrow by accounting for temperature, wind, and humidity.
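To make the weather example concrete, here is a minimal sketch in plain Python of a model that maps features to a predicted probability. The feature names, the weights, the bias, and the choice of a logistic model are all assumptions made for illustration; a real model would learn its parameters from data rather than have them set by hand.

```python
import math

# Hypothetical parameters, picked by hand for illustration only.
# In practice these values would be learned from historical weather data.
WEIGHTS = {"temperature_c": -0.05, "wind_kmh": 0.02, "humidity_pct": 0.04}
BIAS = -2.0


def predict_rain_probability(features):
    """Map observable features to an estimated probability of rain."""
    # Weighted sum of the features plus a bias term.
    score = BIAS + sum(WEIGHTS[name] * features[name] for name in WEIGHTS)
    # The logistic function squashes the score into the interval (0, 1).
    return 1.0 / (1.0 + math.exp(-score))


# Observable features of one data point (hypothetical values).
tomorrow = {"temperature_c": 18.0, "wind_kmh": 10.0, "humidity_pct": 85.0}
probability = predict_rain_probability(tomorrow)
print(f"Estimated probability of rain: {probability:.2f}")
```

The point of the sketch is only the shape of the computation: features go in, and an estimate of the target value comes out.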
In general, a machine learning model can be viewed as a mathematical function that takes in the...