Topic: Implementing JAX-based automatic differentiation in Stingray
This project studied modern statistical modelling techniques to improve the accuracy, speed, and robustness of likelihood-function fitting in a software package called Stingray. This report presents experiments that combine different optimizers for fitting with scipy.optimize. A second focus is gradient calculation with JAX and how it compares with scipy.optimize.
Introduction:
The proposed milestone was to investigate where JAX could improve the overall performance of modelling in Stingray. The current model, however, is still at the sandbox stage. Stingray is astrophysical spectral-timing software: a Python library built to perform time series analysis and related tasks on astronomical light curves. JAX is a Python library designed for high-performance numerical computing. Its API for numerical functions is based on NumPy, a collection of functions used in scientific computing. Both Python and NumPy are widely used and familiar, which makes JAX simple, flexible, and easy to adopt. JAX can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives. Modern differentiation packages of this kind deploy a broad range of computational techniques to improve applicability, run time, and memory management.
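As a rough illustration of differentiating through ordinary Python control flow, the sketch below applies jax.grad to a function written with a loop and to one written with an if, and then takes a derivative of a derivative. The function names (cube_by_loop, piecewise) are made up for this example and are not part of Stingray.

```python
import jax

def cube_by_loop(x):
    # Plain Python loop: multiplies by x three times, so the result is x**3.
    result = 1.0
    for _ in range(3):
        result = result * x
    return result

def piecewise(x):
    # Plain Python branch on the input value. This works under jax.grad
    # because the function is not being jit-compiled here.
    if x > 0:
        return x ** 2
    return -x

# First derivative of x**3 is 3 * x**2; second derivative is 6 * x.
d_cube = jax.grad(cube_by_loop)
d2_cube = jax.grad(d_cube)

first = d_cube(3.0)
second = d2_cube(3.0)

# The gradient follows whichever branch the input actually takes.
slope_pos = jax.grad(piecewise)(2.0)   # branch x**2, slope 2 * x
slope_neg = jax.grad(piecewise)(-2.0)  # branch -x, slope -1

print(first, second, slope_pos, slope_neg)
```

Note that jax.grad expects floating-point inputs, which is why the arguments are written as 3.0 and 2.0 instead of integers.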
Like Autograd, JAX provides the grad function transformation, which converts a function into one that returns the original function’s gradient. Beyond that, JAX offers the jit transformation for just-in-time compilation of existing functions, and vmap and pmap for vectorization and parallelization, respectively.
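A minimal sketch of how these transformations compose is shown below. The objective is a hypothetical Gaussian negative log-likelihood with unit variance, chosen only for illustration; it is not the likelihood used in Stingray. pmap is left out because it needs several devices to be meaningful.

```python
import jax
import jax.numpy as jnp

def neg_log_likelihood(mu, data):
    # Hypothetical objective: Gaussian negative log-likelihood with unit
    # variance and constant terms dropped.
    return 0.5 * jnp.sum((data - mu) ** 2)

data = jnp.array([1.0, 2.0, 3.0, 6.0])

# grad: returns a function computing d(neg_log_likelihood)/d(mu).
# By default jax.grad differentiates with respect to the first argument.
grad_nll = jax.grad(neg_log_likelihood)

# jit: compiles the gradient function on first call and reuses it afterwards.
fast_grad_nll = jax.jit(grad_nll)

# vmap: maps the gradient over a batch of mu values while sharing the data.
batched_grad = jax.vmap(grad_nll, in_axes=(0, None))

mus = jnp.array([0.0, 3.0, 5.0])
g_single = fast_grad_nll(3.0, data)  # gradient at the sample mean
g_batch = batched_grad(mus, data)    # one gradient per entry of mus

print(g_single, g_batch)
```

For this objective the gradient is the sum of (mu - data), so it vanishes at the sample mean, which gives a quick sanity check on the result.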