How to speed up machine learning operations with Jax? – Analytics India Magazine

  • Lauren
  • March 19, 2022

Machine learning algorithms rely on a large number of mathematical operations, and as models become more accurate those operations usually become more complex. A simple example is the random forest versus the decision tree: the random forest is more accurate in most cases, but its mathematics is more involved and it takes more time than a single decision tree. Robust modelling therefore requires a way to complete large numerical computations quickly and reliably. Jax is a library that can help us speed up such operations. In this article, we will discuss the Jax library in detail. The major points to be covered are listed below.

Table of contents

  • What is Jax?
  • Jax vs NumPy
  • Basic operations
  • Multiplying matrices
  • Using jit()
  • Using grad()

Let’s begin with understanding what Jax is.

What is Jax?

Jax is a library that can be thought of as NumPy for CPU, GPU, and TPU. It also provides efficient automatic differentiation, which makes it well suited to machine learning projects and research.

We can also think of Jax as a library for high-performance numerical computation. It is developed by Google, and it is used in machine learning projects at Google and DeepMind. Using the line of code below, we can install the library in our environment.

!pip install jax

Since I am using Google Colab to discuss Jax, it is already installed in the Google Colaboratory environment. We can also update Jax using the following line of code.

!pip install --upgrade jax jaxlib

After the installation, we are ready to use Jax for numerical and mathematical operations.
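A quick way to confirm the installation is to ask Jax which version is installed and which devices it can see. The snippet below is a minimal sketch and is not part of the original walkthrough; the values printed depend on the environment (for example, whether a GPU or TPU runtime is attached in Colab).

```python
import jax

# Report the installed version and the hardware Jax has detected.
print(jax.__version__)
print(jax.default_backend())  # platform name of the default backend, e.g. "cpu"
print(jax.devices())          # list of devices available to Jax
```

If only a CPU device is listed, the examples in this article still run, just without accelerator speed-ups.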

Jax vs NumPy 

In the introduction above, we saw that Jax can be used for numerical operations and calculations. Code written with Jax is both expressive and high-performing. One of its most important features is that the jax.numpy module uses a syntax very similar to NumPy's, for example, the code below.

import jax.numpy as jnp
arr = jnp.zeros(10)

It gives an array of ten zeros. When we use NumPy for this, we can write the following.

import numpy as np
arr = np.zeros(10)

The difference lies in the type of array that is returned: the array generated by Jax is a device array, while the array generated by NumPy is an ordinary NumPy array. In terms of performance, device arrays are evaluated asynchronously: Jax keeps their values on the accelerator and copies them to the host only when they are actually needed, for example when we print them or convert them to a NumPy array.

Here we can see a basic difference between NumPy arrays and Jax arrays. A device array supports the same operations as an ordinary array, so it can be used wherever an ordinary array can, but its values are transferred to the host only when required, which saves the machine from unnecessary transfers and waiting.

Another point on which we can compare Jax and NumPy is speed. Let’s check this.
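The type difference described above can be checked directly. This is a small illustrative sketch; the variable names are made up for the example, and the exact class name printed for the Jax array depends on the installed Jax version, so no specific name is assumed here.

```python
import numpy as np
import jax.numpy as jnp

np_arr = np.zeros(10)
jax_arr = jnp.zeros(10)

print(type(np_arr))   # an ordinary numpy.ndarray
print(type(jax_arr))  # Jax's own array type (name varies by version)

# Converting to NumPy copies the values from the device to host memory.
host_copy = np.asarray(jax_arr)
print(type(host_copy), host_copy.shape)
```

Converting back with np.asarray is one of the points at which the values actually have to leave the accelerator.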

Defining array 

Using NumPy:

x = np.random.rand(2000,2000)

Using Jax:

y = jnp.array(x)

Let’s check the time the machine takes to compute the inner product of the arrays. The inner product of the Jax array,

%timeit -n 1 -r 1 jnp.dot(y, y).block_until_ready()

The inner product of the NumPy array,

%timeit -n 1 -r 1 np.dot(x, x)

Here we can see that NumPy takes longer than Jax for this operation. Note that we call block_until_ready on the Jax result: Jax dispatches work asynchronously, so this call makes sure the computation has actually finished and the measured time is meaningful.
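The %timeit magic only works in notebooks. Outside a notebook, the same comparison can be sketched with time.perf_counter. The helper time_once below is a hypothetical name introduced for this example, the matrices are smaller than in the article so the script finishes quickly, and the measured times will vary from machine to machine.

```python
import time

import numpy as np
import jax.numpy as jnp

x = np.random.rand(500, 500).astype(np.float32)
y = jnp.array(x)

def time_once(fn):
    """Run fn once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

# Warm-up call so one-time compilation is not included in the timing.
jnp.dot(y, y).block_until_ready()

np_result, np_seconds = time_once(lambda: np.dot(x, x))
# block_until_ready waits for the asynchronous computation to finish.
jax_result, jax_seconds = time_once(lambda: jnp.dot(y, y).block_until_ready())

print(f"NumPy: {np_seconds:.6f} s")
print(f"Jax:   {jax_seconds:.6f} s")
```

Without block_until_ready, the timer would mostly measure how long it takes to dispatch the work, not how long the work itself takes.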

Basic operations 

Above, we saw the difference between NumPy and Jax and found that we can speed up mathematical operations using the Jax library. Below, we will look at how some basic operations can be performed using Jax.

Multiplying matrices 

Let’s define and multiply the matrices.

from jax import random

x = random.normal(random.PRNGKey(0), (3000,3000), dtype=jnp.float32)
%timeit y = jnp.dot(x, x.T).block_until_ready()


Here we can see the time the multiplication takes. Using the code above, we defined a 3000 x 3000 matrix and multiplied it by its transpose.

Let’s try different matrices for multiplication.

x = random.normal(random.PRNGKey(0), (3000,3000), dtype=jnp.float32)
# A different key is used for y so the two matrices are independent samples.
y = random.normal(random.PRNGKey(1), (3000,2000), dtype=jnp.float32)
%timeit z = jnp.dot(x, y).block_until_ready()


Here we can see that the multiplication is already reasonably fast, and on a GPU it will run even more efficiently. Besides hardware acceleration, Jax provides other ways to speed up code, such as jit. Let’s see an example of that.

Using jit()

jit is a transformation, usable as a decorator, that can boost the speed of an operation. In the examples above, Jax executes one operation at a time, with block_until_ready waiting for each result. In machine learning, computations are usually long sequences of operations, and for such a sequence we can use jit to compile the whole function with XLA so that it runs as a single optimized unit.

Using Jax without jit

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.PRNGKey(0)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()


Now, let’s use Jax with jit. 

from jax import jit
selu_jit = jit(selu)
%timeit selu(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()


Here we can see the difference between running the function in plain Jax and running the version compiled with jit.
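Since jit can also be applied as a decorator, the same function can be written as follows. This is a sketch rather than part of the original benchmark; selu_decorated and selu_plain are names chosen for this example, and a smaller input is used so that it runs quickly.

```python
import jax.numpy as jnp
from jax import jit, random

@jit
def selu_decorated(x, alpha=1.67, lmbda=1.05):
    # Same computation as selu, compiled with XLA on first use.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def selu_plain(x, alpha=1.67, lmbda=1.05):
    # Uncompiled reference version for comparison.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(random.PRNGKey(0), (1000,))

first = selu_decorated(x).block_until_ready()   # traces and compiles
second = selu_decorated(x).block_until_ready()  # reuses the compiled code

print(jnp.allclose(first, selu_plain(x), rtol=1e-4, atol=1e-5))
```

The first call includes the compilation cost, so when timing a jitted function it is usually fairer to call it once before measuring.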

Using grad()

Machine learning and mathematical computing are not just about calculating values; most of the time we also need to transform functions, for example by differentiating them, and then work with the results. Jax includes an automatic differentiation system, descended from the Autograd library, that can differentiate native Python and NumPy-style code. Let’s compute a gradient using the grad function from Jax.

from jax import grad
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))


In the output, we can see that the gradient at the first element (x = 0) is 0.25; the gradients at x = 1 and x = 2 are approximately 0.197 and 0.105.
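One way to build confidence in the result is to compare it with the derivative worked out by hand. For the logistic function s(x) = 1 / (1 + exp(-x)), the derivative is s(x) * (1 - s(x)). The sketch below makes that comparison; logistic_derivative is a helper name introduced only for this check.

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def logistic_derivative(x):
    # Hand-derived derivative of the logistic function: s(x) * (1 - s(x)).
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(3.)
auto = grad(sum_logistic)(x_small)     # automatic differentiation
manual = logistic_derivative(x_small)  # analytic formula

print(auto)
print(jnp.allclose(auto, manual, atol=1e-6))
```

Because sum_logistic sums over the elements, its gradient with respect to each input element is simply the logistic derivative at that element.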

Jax also has many other features, such as vmap for automatic vectorization, that make it easier to perform complex mathematical operations efficiently.
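As a brief illustration of vmap, the sketch below takes a function written for a single row and maps it over a batch of rows without an explicit Python loop. The function dot_with_weights and the sample data are hypothetical and chosen only to keep the example small.

```python
import jax.numpy as jnp
from jax import vmap

def dot_with_weights(row, weights):
    # Written for a single 1-D row.
    return jnp.dot(row, weights)

batch = jnp.arange(12.).reshape(4, 3)   # four rows of three values
weights = jnp.array([1.0, 0.5, 0.25])

# Map over axis 0 of batch; weights is shared across all rows (None).
batched = vmap(dot_with_weights, in_axes=(0, None))
per_row = batched(batch, weights)

print(per_row)
```

The result has one value per row and matches what a matrix-vector product of batch and weights would give.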

Final words

In this article, we discussed what Jax is and what makes it different from NumPy. We also saw how it lets us perform mathematical operations quickly and efficiently, along with some examples of operations that can improve performance.