Topic: Implementing JAX-based automatic differentiation in Stingray
The project studied modern statistical modelling techniques to improve the accuracy, speed, and robustness of likelihood evaluation in a software package called Stingray. This report describes experiments with a combination of different optimizers available through scipy.optimize. Another emphasis is to investigate gradient calculation using JAX and compare it with scipy.optimize.
The proposed milestone was to investigate where JAX could improve the overall performance of modelling in Stingray. However, the model is currently still at the sandbox stage. Stingray is astrophysical spectral-timing software: a Python library built to perform time series analysis and related tasks on astronomical light curves. JAX is a Python library designed for high-performance numerical computing. Its API for numerical functions is based on NumPy, a collection of functions used in scientific computing. Both Python and NumPy are widely used and familiar, making JAX simple, flexible, and easy to adopt. JAX can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives. Such modern differentiation packages deploy a broad range of computational techniques to improve applicability, run time, and memory management.
JAX utilizes the grad function transformation to convert a function into a function that returns the original function’s gradient, just like Autograd. Beyond that, JAX offers a function transformation jit for just-in-time compilation of existing functions and vmap and pmap for vectorization and parallelization, respectively.
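The transformations named above can be illustrated with a minimal sketch. The function `f` below is a toy example chosen for illustration and is not part of Stingray; only `jax.grad`, `jax.jit`, and `jax.vmap` are actual JAX calls.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy scalar function (hypothetical): f(x) = x^3 + 2x
    return x ** 3 + 2.0 * x


df = jax.grad(f)               # first derivative: 3x^2 + 2
d2f = jax.grad(jax.grad(f))    # derivative of the derivative: 6x
fast_df = jax.jit(df)          # just-in-time compiled version of df
batched_df = jax.vmap(df)      # df applied element-wise over an array

first = float(df(2.0))
second = float(d2f(2.0))
compiled = float(fast_df(2.0))
batch = batched_df(jnp.array([0.0, 1.0, 2.0]))
```

Here `grad` is composed with itself to obtain a second derivative, and `vmap` evaluates the scalar derivative over a whole array without an explicit Python loop.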
The power-law and Lorentzian functions are the ones most commonly used to describe periodograms in astronomy. In practice, we use the sum of these components to design a realistic model. For the analysis here we consider a quasi-periodic oscillation and a constant, and try to make the algorithm fail by (i) reducing the amplitude, (ii) starting the optimization process with parameters very far away from the true parameters, and (iii) trying different optimizers to probe different sensitive aspects of the current likelihood calculation. The ongoing milestone is to try alternatives to scipy.optimize, but this requires a series of tests.
The tests above can be viewed in the notebook on GitHub: https://github.com/rashmiraj137/GSoC-Project
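As a rough sketch of such a model, the components can be written as plain functions of frequency. The function names, the Lorentzian parameterisation (peak amplitude, centroid `x0`, full width at half maximum `fwhm`), and the parameter values below are assumptions made for illustration, not the definitions used in the notebook.

```python
import numpy as np


def lorentzian(freq, amplitude, x0, fwhm):
    # Peak value `amplitude` at freq == x0; `fwhm` is the full width at half maximum.
    half_width = fwhm / 2.0
    return amplitude * half_width**2 / ((freq - x0) ** 2 + half_width**2)


def power_law(freq, amplitude, alpha):
    # amplitude * freq^(-alpha); freq must be positive.
    return amplitude * freq ** (-alpha)


def qpo_plus_constant(freq, amplitude, x0, fwhm, const):
    # Quasi-periodic oscillation (one Lorentzian) on top of a flat noise level.
    return lorentzian(freq, amplitude, x0, fwhm) + const


# Hypothetical frequency grid and parameter values.
freq = np.linspace(0.01, 10.0, 1000)
model_power = qpo_plus_constant(freq, amplitude=5.0, x0=2.0, fwhm=0.5, const=2.0)
```

Reducing `amplitude` relative to `const` makes the oscillation harder to distinguish from the noise level, which is the first of the stress tests listed above.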
During the experiment, it was observed that the algorithm is sensitive to the input parameters, i.e., it fails for a couple of combinations, such as when the amplitude is set far away from the true value, specifically below an absolute value of 1. In general, if we set the parameters very far away from the true value, it fails to approximate the likelihood function. In the notebook (link), we demonstrate the room for improvement in the current algorithm by choosing a different set of parameters. The current data fit for the evaluation of the likelihood uses the scipy.optimize.minimize function. However, there are numerous other ways to do this. SciPy optimize provides functions for minimizing (or maximizing) objective functions, possibly subject to constraints. It includes solvers for nonlinear problems (with support for both local and global optimization algorithms), linear programming, constrained and nonlinear least-squares, root finding, and curve fitting. The problem with the current minimization algorithm is that it can converge to a local minimum instead of the global one, i.e., it is not very robust. Recently, machine learning has driven notable development in such optimization tools. The strategy was to find alternatives that could accelerate the code and make it more robust.
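A fit of this kind can be sketched as follows. This is a minimal example under stated assumptions, not the code from the notebook: it assumes a Whittle-type likelihood for an unaveraged periodogram, simulates the data as model times exponential noise, and fits the logarithms of the positive parameters so that the model stays positive. All names and values are hypothetical; only `scipy.optimize.minimize` with `method="Powell"` comes from the text.

```python
import numpy as np
from scipy.optimize import minimize


def qpo_model(freq, log_amp, x0, log_fwhm, log_const):
    # Lorentzian plus constant; positive parameters are passed as logarithms.
    half_width = np.exp(log_fwhm) / 2.0
    lor = np.exp(log_amp) * half_width**2 / ((freq - x0) ** 2 + half_width**2)
    return lor + np.exp(log_const)


def neg_log_likelihood(params, freq, power):
    # Negative Whittle log-likelihood (assumed form): sum(log(model) + power / model).
    model = qpo_model(freq, *params)
    return np.sum(np.log(model) + power / model)


# Simulated periodogram: model power multiplied by unit-mean exponential noise.
rng = np.random.default_rng(0)
freq = np.linspace(0.01, 10.0, 1000)
true_params = np.array([np.log(5.0), 2.0, np.log(0.5), np.log(2.0)])
power = qpo_model(freq, *true_params) * rng.exponential(1.0, size=freq.size)

# Start away from the true parameters to probe sensitivity to the initial guess.
start = np.array([np.log(1.0), 3.0, np.log(1.0), np.log(1.0)])
nll_start = neg_log_likelihood(start, freq, power)
res = minimize(neg_log_likelihood, start, args=(freq, power), method="Powell")
```

Moving `start` further from `true_params` and comparing `res.x` with the true values is one way to reproduce the kind of sensitivity test described above.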
For this experiment, we choose a couple of optimizers and compare their robustness to Powell (the optimizer currently used). We visualize the fit for a couple of optimizers, such as:
- minimize(method="trust-constr")
This notebook (link) has results for each method, and it was observed that Nelder-Mead is more robust than the other optimizers. Other optimizers, such as dogleg and trust-ncg, might be good as well, but they require the Jacobian (jac) and Hessian (hess) to be calculated.
JAX now has its own version of scipy.optimize.minimize, but it has a couple of bugs and is not as robust as scipy.optimize.minimize. Finding an alternative to scipy.optimize.minimize that does not fail even if the start parameters are far away from the true parameters was a goal for this project, but unfortunately JAX did not help enough with that. There may still be an algorithm superior to scipy.optimize.minimize that could be useful.
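One possible way to obtain the Jacobian and Hessian for such methods is to let JAX derive them and pass them to SciPy. The sketch below uses the Rosenbrock function as a stand-in objective (an assumption, not the Stingray likelihood) and enables 64-bit precision in JAX so that the values match what SciPy expects.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# SciPy works in double precision; JAX defaults to single precision.
jax.config.update("jax_enable_x64", True)


def objective(p):
    # Rosenbrock function, used here only as a stand-in objective.
    return jnp.sum(100.0 * (p[1:] - p[:-1] ** 2) ** 2 + (1.0 - p[:-1]) ** 2)


grad_fn = jax.jit(jax.grad(objective))     # gradient via automatic differentiation
hess_fn = jax.jit(jax.hessian(objective))  # Hessian via automatic differentiation

res = minimize(
    lambda p: float(objective(p)),
    np.array([-1.2, 1.0]),
    method="trust-ncg",
    # Convert JAX arrays to writable NumPy arrays before handing them to SciPy.
    jac=lambda p: np.array(grad_fn(p), dtype=float),
    hess=lambda p: np.array(hess_fn(p), dtype=float),
)
```

The same pattern would apply to a likelihood written with `jax.numpy`, provided the objective is differentiable with respect to all fitted parameters.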