Automatic vectorization for efficient kernels
You might remember from our discussions of NumPy that the library is efficient at applying numerical operations to all elements in an array, or to the elements along specific axes. By exploiting the fact that the same operation is applied to many elements, the library dispatches the work to optimized low-level code, making the computation much faster than performing it with an explicit iterative loop. This process is called vectorization.
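To make the contrast concrete, here is a minimal sketch (the array `x` and the operation are illustrative, not taken from the text) comparing an element-by-element Python loop with the equivalent vectorized NumPy expression:

```python
import numpy as np

x = np.arange(100_000, dtype=np.float64)

# Iterative loop: one Python-level operation per element.
out_loop = np.empty_like(x)
for i in range(x.shape[0]):
    out_loop[i] = 2.0 * x[i] + 1.0

# Vectorized: the same operation applied to the whole array at once,
# executed in NumPy's optimized low-level code.
out_vec = 2.0 * x + 1.0

assert np.allclose(out_loop, out_vec)
```

Both compute the same result; the vectorized form simply avoids the per-element Python overhead.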
When working with machine learning models, we would like to vectorize a specific function, rather than loop over an array or a matrix, to gain a performance speedup. Vectorization is typically not easy to do by hand: it might take clever tricks to rewrite the function we want to vectorize into another form that admits vectorization, as the sketch below illustrates.
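As an illustration of the kind of rewriting involved, consider a hypothetical per-example function `predict` that computes a single dot product; vectorizing it manually over a batch means recognizing that the batched computation can be re-expressed as a matrix-vector product:

```python
import numpy as np

def predict(w, x):
    # Per-example computation: a single dot product.
    return np.dot(w, x)

w = np.ones(3)
x_batch = np.arange(30.0).reshape(10, 3)  # a batch of 10 examples

# Manual vectorization: rewrite the per-example dot product as one
# matrix-vector product over the whole batch.
batched = x_batch @ w

# Equivalent to looping over the batch, one example at a time.
looped = np.stack([predict(w, x) for x in x_batch])
assert np.allclose(batched, looped)
```

For a dot product the rewrite is easy to spot; for more involved functions, finding such a reformulation can take real effort.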
JAX addresses this concern by providing a function transformation, jax.vmap, that automatically vectorizes a given function.
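A minimal sketch of what this looks like, reusing the hypothetical `predict` function from above (the `in_axes` argument tells `jax.vmap` to keep `w` fixed and map over the leading axis of `x_batch`):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Same per-example function as before, written for one input.
    return jnp.dot(w, x)

w = jnp.ones(3)
x_batch = jnp.arange(30.0).reshape(10, 3)

# jax.vmap transforms the per-example function into a batched one,
# with no manual rewriting of its body.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
out = batched_predict(w, x_batch)
print(out.shape)  # (10,)
```

The transformed function behaves as if we had looped over the batch, but JAX generates vectorized code under the hood.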