PyData Vermont 2024

The state of Bayesian workflows in JAX
07-30, 14:00–14:30 (US/Eastern), Filmhouse

This talk provides an overview of Bayesian workflows in the JAX ecosystem, and is aimed at both newcomers and experienced practitioners. We will walk through the Bayesian workflow, from model specification and prior selection, to inference choices such as MCMC, optimization, and variational inference (VI), to posterior analysis and validation. We will discuss practical ways to use popular libraries to build modular, efficient workflows.


This talk is aimed both at newcomers to Bayesian methods and JAX, and at experienced data scientists, researchers, and other practitioners. Familiarity with basic mathematics and NumPy-style programming is recommended.

  • Introduction (5 minutes): A worked example giving an overview of the Bayesian workflow, emphasizing the interoperability of libraries and some JAX transformations. We build on this example throughout the rest of the talk.
  • Model specification and priors (5 minutes): We implement a linear model in pure JAX, then show the same model in PyMC, NumPyro, and TFP. We show how a modeling language can help when iterating on a model (see the model-specification sketch after this list).
  • Inference methods (10 minutes): We introduce the bayeux library for moving from any of our model specifications to running inference. We talk about optimization, showing methods like Adam and (L)BFGS from libraries like optax, jaxopt, and optimistix. We then show how to run Markov chain Monte Carlo (MCMC) with libraries like NumPyro, blackjax, and TFP, and conclude with variational inference (VI) as implemented in TFP (see the inference sketch after this list).
  • Posterior analysis and validation (5 minutes): We show how to check our inference for optimization and VI (by inspecting loss curves with matplotlib) and for MCMC (with diagnostics and plots from ArviZ); see the diagnostics sketch after this list.
  • Conclusion, takeaways (5 minutes): We discuss when to use Bayesian methods and the advantages they can bring. We also discuss the big picture of how JAX helped us here, and when other frameworks like PyTorch, TensorFlow, or plain NumPy might be more appropriate.
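
To give a flavor of the model-specification step, here is a minimal sketch of a linear regression written first as a plain JAX log-density function and then as a NumPyro model. The priors, variable names, and model structure here are illustrative assumptions, not the exact example used in the talk.

    import jax.numpy as jnp
    from jax.scipy import stats
    import numpyro
    import numpyro.distributions as dist

    # Plain-JAX log density for y ~ Normal(intercept + slope * x, 1),
    # with standard normal priors on both parameters (illustrative choice).
    def log_density(params, x, y):
        intercept, slope = params
        log_prior = stats.norm.logpdf(intercept) + stats.norm.logpdf(slope)
        log_likelihood = stats.norm.logpdf(y, loc=intercept + slope * x).sum()
        return log_prior + log_likelihood

    # The same model written in NumPyro, where samplers and diagnostics
    # can reuse the model definition directly.
    def linear_model(x, y=None):
        intercept = numpyro.sample("intercept", dist.Normal(0.0, 1.0))
        slope = numpyro.sample("slope", dist.Normal(0.0, 1.0))
        numpyro.sample("y", dist.Normal(intercept + slope * x, 1.0), obs=y)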
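
For the inference step, here is a hedged sketch that reuses the definitions above: a MAP-style fit of the plain-JAX log density with optax's Adam, and NUTS sampling of the NumPyro model. The talk itself routes model specifications through bayeux; this sketch calls optax and NumPyro directly, and the synthetic data, learning rate, and step counts are placeholder choices.

    import jax
    import jax.numpy as jnp
    import optax
    from numpyro.infer import MCMC, NUTS

    # Synthetic data, purely for illustration; log_density and linear_model
    # are the functions from the model-specification sketch above.
    key = jax.random.PRNGKey(0)
    x = jnp.linspace(-1.0, 1.0, 50)
    y = 0.5 + 2.0 * x + 0.1 * jax.random.normal(key, x.shape)

    # Optimization: minimize the negative log density with Adam.
    loss = lambda params: -log_density(params, x, y)
    params = jnp.zeros(2)
    optimizer = optax.adam(learning_rate=0.1)
    opt_state = optimizer.init(params)
    for _ in range(200):
        grads = jax.grad(loss)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    # MCMC: run NUTS on the NumPyro model.
    mcmc = MCMC(NUTS(linear_model), num_warmup=500, num_samples=1000)
    mcmc.run(jax.random.PRNGKey(1), x, y=y)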
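
For the posterior-analysis step, a minimal sketch of converting the MCMC result above into ArviZ and running standard diagnostics; the specific checks and plots shown in the talk may differ.

    import arviz as az

    # Convert the NumPyro MCMC results (from the sketch above) into
    # ArviZ's InferenceData container.
    idata = az.from_numpyro(mcmc)

    # Tabular diagnostics: posterior means, r_hat, and effective sample sizes.
    print(az.summary(idata, var_names=["intercept", "slope"]))

    # Visual checks: trace plots for each chain and parameter.
    az.plot_trace(idata)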

Colin Carroll is a software engineer at Google Research. In this role he focuses on Bayesian computation and research, and contributes to a number of open source libraries, including bayeux, TensorFlow Probability, PyMC, and ArviZ. He received his PhD in mathematics from Rice University, where he researched geometric measure theory.