JAX is more than just "NumPy for the GPU"—it offers advanced features but also presents unique challenges. This hands-on tutorial provides a practical introduction to JAX through interactive exercises covering key concepts such as:
- jit compilation for performance optimization
- Native control flow using loop primitives
- Efficient function mapping with vmap
- Performance profiling techniques
- Random number generation with jax.random: design and usage
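The concepts listed above fit into a few lines of code. The following is a minimal sketch, not taken from the tutorial notebooks. It shows explicit jax.random keys, a jitted lax.fori_loop, vmap over a batch axis, and timing with block_until_ready(). The function names (sum_of_squares, dot, batched_dot) and array sizes are illustrative assumptions.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

# jax.random: randomness comes from explicit keys, not hidden global state.
# Splitting a key yields independent subkeys; reusing a key reproduces values.
key = jax.random.PRNGKey(0)
key, k_vec, k_batch = jax.random.split(key, 3)
v = jax.random.normal(k_vec, (1_000,))
batch_a = jax.random.normal(k_batch, (8, 3))
batch_b = jnp.ones((8, 3))


# jit + native control flow: lax.fori_loop is traced once into a compiled
# loop instead of a Python loop unrolled into 1,000 separate operations.
@jax.jit
def sum_of_squares(x):
    def body(i, acc):
        return acc + x[i] ** 2

    # The carry's initial value has the same dtype the body returns.
    return lax.fori_loop(0, x.shape[0], body, jnp.zeros((), dtype=x.dtype))


# vmap: write the function for a single example, then map it over a batch axis.
def dot(a, b):
    return jnp.dot(a, b)


batched_dot = jax.vmap(dot, in_axes=(0, 0))  # maps over axis 0 of both inputs
row_dots = batched_dot(batch_a, batch_b)  # shape (8,)

# Profiling: JAX dispatches work asynchronously, so block_until_ready() is
# needed before stopping the timer. The first call also pays compile time.
sum_of_squares(v).block_until_ready()  # warm-up: trace + compile
start = time.perf_counter()
result = sum_of_squares(v).block_until_ready()
elapsed = time.perf_counter() - start
print(f"sum of squares = {float(result):.3f} in {elapsed * 1e3:.3f} ms")
```

The warm-up call matters when profiling. Without it, the first measurement mixes tracing and XLA compilation into the runtime.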
Participants will then deepen their understanding by iteratively migrating a Gaussian Mixture Model from a pure NumPy implementation to an optimized JAX version, highlighting a real-world use case.
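The following sketch shows one step of such a migration. It is hypothetical and not the tutorial's actual GMM code. It covers only the E-step (computing responsibilities) of a one-dimensional GMM. It assumes the weights, means, and standard deviations are already given, and the names e_step_numpy, e_step_jax, and component_log_prob are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


# Hypothetical NumPy starting point: a Python loop over mixture components.
def e_step_numpy(x, weights, means, stds):
    log_probs = np.empty((x.shape[0], weights.shape[0]))
    for k in range(weights.shape[0]):
        log_probs[:, k] = (
            np.log(weights[k])
            - 0.5 * np.log(2 * np.pi * stds[k] ** 2)
            - 0.5 * ((x - means[k]) / stds[k]) ** 2
        )
    # Normalize in log space for numerical stability.
    log_norm = np.logaddexp.reduce(log_probs, axis=1, keepdims=True)
    return np.exp(log_probs - log_norm)


# JAX version: express one component's log-density, then vmap over components.
def component_log_prob(weight, mean, std, x):
    return jnp.log(weight) + norm.logpdf(x, loc=mean, scale=std)


@jax.jit
def e_step_jax(x, weights, means, stds):
    # in_axes=None broadcasts x to every component; result is (K, N) -> (N, K).
    log_probs = jax.vmap(component_log_prob, in_axes=(0, 0, 0, None))(
        weights, means, stds, x
    ).T
    return jnp.exp(log_probs - logsumexp(log_probs, axis=1, keepdims=True))


# Usage on synthetic data drawn from two Gaussians.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 0.5, 200), rng.normal(3.0, 1.0, 300)])
weights = np.array([0.4, 0.6])
means = np.array([-2.0, 3.0])
stds = np.array([0.5, 1.0])

resp_np = e_step_numpy(x, weights, means, stds)
resp_jax = e_step_jax(x, weights, means, stds)  # NumPy inputs are converted
print("max abs difference:", np.max(np.abs(resp_np - np.asarray(resp_jax))))
```

The NumPy version writes into log_probs in place. JAX arrays are immutable, so the JAX version replaces the loop with vmap rather than porting the assignments (the functional analogue would be .at[...].set()). With 64-bit mode disabled, which is JAX's default, the JAX results are float32 and agree with the float64 NumPy results only to within a small tolerance.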
This tutorial distills lessons the authors found invaluable during their own migration from NumPy to JAX, which achieved more than an order-of-magnitude speedup in real-world applications. Designed to give attendees a jumpstart on adopting JAX, this session, along with its comprehensive set of notebooks, aims to be a one-stop resource for anyone looking to use JAX for numerical computing and machine learning.