JAX: Accelerated Machine Learning Research via Composable Function Transformations in Python

Published: October 28, 2024
on the channel: Anyscale

(Matt Johnson, Google Brain)

This talk is about JAX, a system for high-performance machine learning research and numerical computing. JAX offers the familiarity of Python+NumPy together with hardware acceleration, and it combines these with user-wielded function transformations, including automatic differentiation, automatic vectorized batching, end-to-end compilation (via XLA), parallelization over multiple accelerators, and more. Composing these transformations is the key to JAX's power and simplicity. Researchers use it for a wide range of advanced applications, from large-scale neural net training to probabilistic programming to scientific applications in physics and biology.
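
To make the composition idea concrete, here is a minimal sketch (not from the talk itself); the least-squares loss function, parameter shapes, and batch sizes are illustrative assumptions:

    import jax
    import jax.numpy as jnp

    # A scalar loss for a single example: least squares on a linear model.
    def loss(params, x, y):
        w, b = params
        pred = jnp.dot(x, w) + b
        return (pred - y) ** 2

    # Compose the transformations the talk describes:
    #   grad -> gradient of the per-example loss w.r.t. params
    #   vmap -> vectorize that gradient over a batch of examples
    #   jit  -> compile the whole pipeline end-to-end via XLA
    per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

    params = (jnp.ones(3), 0.0)  # illustrative weights and bias
    xs = jnp.ones((8, 3))        # batch of 8 inputs
    ys = jnp.zeros(8)            # batch of 8 targets
    grads_w, grads_b = per_example_grads(params, xs, ys)
    print(grads_w.shape, grads_b.shape)  # (8, 3) (8,)

Each transformation takes a function and returns a function, so they stack freely; this composability is what the abstract identifies as the source of JAX's power and simplicity.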