Hierarchical modeling with PyMC and brms

Author

Sheng Long

Updated

December 19, 2025

TLDR

To bridge the gap between R-based research and Python-based deployment, I compared an end-to-end Bayesian workflow using PyMC + ArviZ against the brms + tidybayes + ggplot2 stack. To make my examples concrete, I utilized the reedfrog and chimpanzee examples from McElreath’s Statistical Rethinking.

Certain things stood out to me:

  • brms abstracts away the “boilerplate” (e.g., dummy coding, index variables, non-centered parameterization) that one still needs to implement manually in PyMC.
  • Nevertheless, PyMC forces users to explicitly define the data-generating process, starting from the hyperpriors, so the code more closely resembles the underlying mathematical equations of the model. brms, on the other hand, folds a lot into its formula syntax, which can take a while to get used to, and does not make specifying priors easy. PyMC also makes specifying multilevel models easier.
  • arviz introduces a more semantic way of handling data (via xarray and InferenceData), but the R ecosystem remains the superior option for visualizing data in tidy format: the Grammar of Graphics makes creating custom prior and posterior plots intuitive.
    • arviz is also currently being refactored (per this article), and some of its plotting functions don’t produce quite what users familiar with bayesplot would expect. So in practice one also needs to look into arviz-plots in addition to base arviz. plotnine, while powerful and built on the Grammar of Graphics, does not support tidybayes geoms as of now.

All this is to say: for a researcher used to R’s high-level abstractions, the transition to Python currently feels very painful. The potential for production integration is definitely there[1], but the developer experience involves significantly more friction.

I also recommend reading [this post](https://discourse.mc-stan.org/t/how-to-translate-between-pymc-and-stan/38574) and this post if you are interested in the differences between PyMC and Stan (as well as how to translate between them).

[1] e.g., via pymc.ModelBuilder

Tadpole survival

We will use the tadpole survival dataset (reed frog tadpoles in ponds) from McElreath’s Statistical Rethinking textbook. The reedfrogs.csv file is obtained from here. One can follow along with the original textbook starting from Section 13.1, Example: Multilevel tadpoles.

First we load the necessary libraries and data:

  • R
  • Python
set.seed(51)
library(tidyverse)
library(ggplot2)
library(brms)
library(here)
library(bayesplot)
library(tidybayes)
theme_set(theme_minimal())
df <- read.csv2(here("posts", "pymc-hierarchical", "reedfrogs.csv")) %>%
  as_tibble(.)
df %>% head(5)
# A tibble: 5 × 5
  density pred  size   surv propsurv
    <int> <chr> <chr> <int> <chr>   
1      10 no    big       9 0.9     
2      10 no    big      10 1       
3      10 no    big       7 0.7     
4      10 no    big      10 1       
5      10 no    small     9 0.9     
reticulate::py_config()
python:         /Users/shenglong/Downloads/mika-long.github.io/.venv/bin/python
libpython:      /Users/shenglong/.local/share/uv/python/cpython-3.13.5-macos-aarch64-none/lib/libpython3.13.dylib
pythonhome:     /Users/shenglong/Downloads/mika-long.github.io/.venv:/Users/shenglong/Downloads/mika-long.github.io/.venv
virtualenv:     /Users/shenglong/Downloads/mika-long.github.io/.venv/bin/activate_this.py
version:        3.13.5 (main, Jun 12 2025, 12:22:43) [Clang 20.1.4 ]
numpy:          /Users/shenglong/Downloads/mika-long.github.io/.venv/lib/python3.13/site-packages/numpy
numpy_version:  2.3.5

NOTE: Python version was forced by VIRTUAL_ENV
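import os
import pandas as pd
# PYTENSOR_FLAGS is read when PyTensor is imported, so set it before importing pymc;
# an empty cxx disables C compilation and PyTensor falls back to Python implementations
os.environ["PYTENSOR_FLAGS"] = "cxx="
import pymc as pm
import pytensor
import arviz as az
import arviz_plots as azp
from pyprojroot.here import here
az.style.use("arviz-docgrid")
# reedfrogs.csv is semicolon-delimited
df = pd.read_csv(here("posts/pymc-hierarchical/reedfrogs.csv"), delimiter=";")
df.head(5)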
   density pred   size  surv  propsurv
0       10   no    big     9       0.9
1       10   no    big    10       1.0
2       10   no    big     7       0.7
3       10   no    big    10       1.0
4       10   no  small     9       0.9

We will focus on the number of surviving tadpoles (surv) out of the initial count (density).

No pooling

Tip

Note that for this particular example, each row is its own tank.

Math model:

\[ \begin{align} \texttt{surv}_i &\sim \text{Binomial}(\texttt{density}_i, p_i) \\ \text{logit}(p_i) &= \alpha_{\text{TANK}[i]} \\ \alpha_j &\sim \mathcal{N}(0, 1.5), j \in [48] \end{align} \]
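To make the math model concrete, here is a minimal NumPy sketch that simulates one dataset from the prior by hand: draw one intercept per tank, map it through the inverse logit, then draw binomial survivor counts. The tank sizes and the helper names `invlogit` and `simulate_no_pooling` are hypothetical and for illustration only; this is not how PyMC or brms generate prior draws internally.

```python
import numpy as np


def invlogit(x):
    """Inverse logit: maps a real number to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def simulate_no_pooling(density, rng, prior_sd=1.5):
    """Draw one simulated dataset from the no-pooling model's prior.

    alpha_j ~ Normal(0, prior_sd)      one intercept per tank
    logit(p_i) = alpha[TANK[i]]
    surv_i ~ Binomial(density_i, p_i)
    """
    density = np.asarray(density)
    tank_idx = np.arange(density.size)   # TANK[i]: each row is its own tank
    alpha = rng.normal(0.0, prior_sd, size=density.size)
    p = invlogit(alpha[tank_idx])        # link function
    surv = rng.binomial(density, p)      # binomial likelihood
    return alpha, p, surv


rng = np.random.default_rng(51)
density = np.array([10, 10, 25, 25, 35, 35])  # hypothetical tank sizes
alpha, p, surv = simulate_no_pooling(density, rng)
print(surv)
```

Repeating this many times is, conceptually, what a prior predictive check summarizes.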

The first step is to specify the model, specify the priors, and perform a prior predictive check:

  • R
  • Python

We use the bf function in brms to set up the model formula:

df <- df %>% mutate(tank = 1:nrow(.))

# specify formula
formula <- bf(surv | trials(density) ~ 0 + factor(tank))

We can use get_prior to check the default prior that brms has given the model:

get_prior(formula, family = binomial, data = df)
  prior class         coef group resp dpar nlpar lb ub tag       source
 (flat)     b                                                   default
 (flat)     b  factortank1                                 (vectorized)
 (flat)     b factortank10                                 (vectorized)
 (flat)     b factortank11                                 (vectorized)
 (flat)     b factortank12                                 (vectorized)
 (flat)     b factortank13                                 (vectorized)
 (flat)     b factortank14                                 (vectorized)
 (flat)     b factortank15                                 (vectorized)
 (flat)     b factortank16                                 (vectorized)
 (flat)     b factortank17                                 (vectorized)
 (flat)     b factortank18                                 (vectorized)
 (flat)     b factortank19                                 (vectorized)
 (flat)     b  factortank2                                 (vectorized)
 (flat)     b factortank20                                 (vectorized)
 (flat)     b factortank21                                 (vectorized)
 (flat)     b factortank22                                 (vectorized)
 (flat)     b factortank23                                 (vectorized)
 (flat)     b factortank24                                 (vectorized)
 (flat)     b factortank25                                 (vectorized)
 (flat)     b factortank26                                 (vectorized)
 (flat)     b factortank27                                 (vectorized)
 (flat)     b factortank28                                 (vectorized)
 (flat)     b factortank29                                 (vectorized)
 (flat)     b  factortank3                                 (vectorized)
 (flat)     b factortank30                                 (vectorized)
 (flat)     b factortank31                                 (vectorized)
 (flat)     b factortank32                                 (vectorized)
 (flat)     b factortank33                                 (vectorized)
 (flat)     b factortank34                                 (vectorized)
 (flat)     b factortank35                                 (vectorized)
 (flat)     b factortank36                                 (vectorized)
 (flat)     b factortank37                                 (vectorized)
 (flat)     b factortank38                                 (vectorized)
 (flat)     b factortank39                                 (vectorized)
 (flat)     b  factortank4                                 (vectorized)
 (flat)     b factortank40                                 (vectorized)
 (flat)     b factortank41                                 (vectorized)
 (flat)     b factortank42                                 (vectorized)
 (flat)     b factortank43                                 (vectorized)
 (flat)     b factortank44                                 (vectorized)
 (flat)     b factortank45                                 (vectorized)
 (flat)     b factortank46                                 (vectorized)
 (flat)     b factortank47                                 (vectorized)
 (flat)     b factortank48                                 (vectorized)
 (flat)     b  factortank5                                 (vectorized)
 (flat)     b  factortank6                                 (vectorized)
 (flat)     b  factortank7                                 (vectorized)
 (flat)     b  factortank8                                 (vectorized)
 (flat)     b  factortank9                                 (vectorized)

Next, let us specify the prior and run a prior predictive check:

m1_prior <- brm(
  formula,
  data = df,
  family = binomial,
  prior = prior(normal(0, 1.5), class = b),
  cores = parallel::detectCores(),
  file = here::here("posts", "pymc-hierarchical", "m1_prior.rds"),
  sample_prior = "only"
)
pp_check(m1_prior, ndraws = 100, type = "dens_overlay")
pp_check(m1_prior, ndraws = 100, type = "ecdf_overlay")

In the graphs above, \(y\) stands for the observed data and \(y_{rep}\) for replicated data. The default type for pp_check is dens_overlay, which uses a kernel density estimate (KDE) to smooth the data into a continuous density curve. ecdf_overlay instead plots the empirical cumulative distribution function (ECDF). We can use these plots to check whether the prior is “in the ballpark” of our domain knowledge before fitting the actual model. For more on when particular visualization types are appropriate for visual predictive checks in a Bayesian workflow, see this paper.
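For intuition about the ecdf_overlay panel, here is a minimal NumPy sketch of an empirical CDF evaluated on a shared grid, once for a set of "observed" counts and once for each of a few replicated datasets. The data are toy values and `ecdf` is a hypothetical helper; this illustrates the idea and is not how bayesplot implements it.

```python
import numpy as np


def ecdf(values, grid):
    """Fraction of `values` that are <= each point in `grid`."""
    values = np.sort(np.asarray(values))
    # side="right" counts how many sorted values are <= each grid point
    return np.searchsorted(values, grid, side="right") / values.size


y = np.array([9, 10, 7, 10, 9, 9, 10, 9])        # toy "observed" counts
rng = np.random.default_rng(51)
y_rep = rng.binomial(10, 0.8, size=(3, y.size))  # 3 toy replicated datasets

grid = np.arange(0, 11)                          # counts are integers 0..10
ecdf_y = ecdf(y, grid)
ecdf_rep = np.array([ecdf(row, grid) for row in y_rep])
print(ecdf_y)
```

Because survivor counts are discrete, each ECDF is a step function; no smoothing bandwidth is involved, unlike a KDE.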

# define the no-pooling model
with pm.Model() as model_1:
    # prior: one intercept per tank (each row is its own tank)
    alpha = pm.Normal("alpha", 0, 1.5, shape=len(df))
    # link: the inverse logit maps alpha to a survival probability
    p_survived = pm.Deterministic("p_survived", pm.math.invlogit(alpha))
    # likelihood
    survived = pm.Binomial("survived", n=df["density"].values, p=p_survived, observed=df["surv"].values)

prior_idata = pm.sample_prior_predictive(model=model_1, draws=100, random_seed=51)
prior_idata
arviz.InferenceData
    • <xarray.Dataset> Size: 78kB
      Dimensions:           (chain: 1, draw: 100, p_survived_dim_0: 48,
                             alpha_dim_0: 48)
      Coordinates:
        * chain             (chain) int64 8B 0
        * draw              (draw) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
        * p_survived_dim_0  (p_survived_dim_0) int64 384B 0 1 2 3 4 ... 43 44 45 46 47
        * alpha_dim_0       (alpha_dim_0) int64 384B 0 1 2 3 4 5 ... 42 43 44 45 46 47
      Data variables:
          p_survived        (chain, draw, p_survived_dim_0) float64 38kB 0.2359 ......
          alpha             (chain, draw, alpha_dim_0) float64 38kB -1.175 ... 2.189
      Attributes:
          created_at:                 2026-04-08T16:14:05.888657+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 40kB
      Dimensions:         (chain: 1, draw: 100, survived_dim_0: 48)
      Coordinates:
        * chain           (chain) int64 8B 0
        * draw            (draw) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (chain, draw, survived_dim_0) int64 38kB 3 1 2 ... 16 12 33
      Attributes:
          created_at:                 2026-04-08T16:14:05.889645+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 768B
      Dimensions:         (survived_dim_0: 48)
      Coordinates:
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (survived_dim_0) int64 384B 9 10 7 10 9 9 ... 14 22 12 31 17
      Attributes:
          created_at:                 2026-04-08T16:14:05.890202+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

az.plot_ppc(prior_idata, group="prior", num_pp_samples=100, observed=True, kind="kde", random_seed=51, mean=False)
azp.plot_ppc_dist(prior_idata, group="prior_predictive", num_samples=100)
azp.plot_ppc_rootogram(prior_idata, group="prior_predictive")

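The first plot, made with az.plot_ppc, is a bit problematic: it does not look smooth the way a KDE plot normally would. This is likely because survived is integer-valued, and for discrete data ArviZ draws step histograms rather than KDEs.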

The second plot, made with azp.plot_ppc_dist, raises a user warning that it “detects at least one discrete variable” and suggests using a plot variant designed for discrete data instead. There is no option to pass in “observed” to highlight the observed data.

The third plot, made with azp.plot_ppc_rootogram, does not raise any errors.
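A rootogram compares, for each possible count, how often that count appears in the observed data with how often the model expects it, on a square-root scale so that small frequencies stay visible. The NumPy sketch below computes those two quantities, taking the expected frequency as the average tally over replicated datasets. This is one common construction under that assumption, not necessarily what azp.plot_ppc_rootogram does internally; `rootogram_data` is a hypothetical helper and the data are toy values.

```python
import numpy as np


def rootogram_data(y_obs, y_rep, max_count):
    """Square-root observed and expected frequencies for counts 0..max_count.

    y_obs: 1-D array of observed non-negative integer counts
    y_rep: 2-D array (n_draws, n_obs) of replicated counts
    """
    n_bins = max_count + 1
    counts = np.arange(n_bins)
    # how often each count value occurs in the observed data
    obs_freq = np.bincount(y_obs, minlength=n_bins)[:n_bins]
    # the same tally for every replicated dataset, then averaged over draws
    rep_freq = np.stack([np.bincount(row, minlength=n_bins)[:n_bins] for row in y_rep])
    exp_freq = rep_freq.mean(axis=0)
    return counts, np.sqrt(obs_freq), np.sqrt(exp_freq)


y_obs = np.array([0, 1, 1, 2, 2, 2])
y_rep = np.array([[0, 1, 2, 2, 2, 3],
                  [1, 1, 1, 2, 2, 2]])
counts, sqrt_obs, sqrt_exp = rootogram_data(y_obs, y_rep, max_count=3)
print(sqrt_obs ** 2, sqrt_exp ** 2)
```

Plotting `sqrt_obs` as bars against the `sqrt_exp` curve makes count values that the model over- or under-predicts stand out.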

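We can also flatten the InferenceData into a pandas DataFrame, although the result is a wide table rather than a tidy one: each array element gets its own column (here one per element of alpha, p_survived, and survived, plus chain and draw):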

prior_idata.to_dataframe().head(5)
   chain  ...  (prior_predictive, survived[9], 9)
0      0  ...                                   3
1      0  ...                                   5
2      0  ...                                   5
3      0  ...                                   1
4      0  ...                                   5

[5 rows x 146 columns]
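To get a long ("tidy") table closer to what tidybayes works with, one option is to reshape a single variable rather than the whole InferenceData; with xarray this is roughly `prior_idata.prior_predictive["survived"].to_dataframe().reset_index()`. The sketch below shows the same reshaping with plain NumPy and pandas on a toy array shaped like the prior predictive draws (chain, draw, tank). The toy array and the column names are illustrative assumptions.

```python
import numpy as np
import pandas as pd

# Toy draws shaped like prior_predictive["survived"]: (chain, draw, tank)
rng = np.random.default_rng(51)
n_chain, n_draw, n_tank = 1, 100, 48
survived = rng.binomial(10, 0.5, size=(n_chain, n_draw, n_tank))

# One row per (chain, draw, tank); from_product varies the last level fastest,
# which matches the C order used by ravel()
index = pd.MultiIndex.from_product(
    [range(n_chain), range(n_draw), range(n_tank)],
    names=["chain", "draw", "tank"],
)
long_df = pd.DataFrame({"survived": survived.ravel()}, index=index).reset_index()
print(long_df.head())
```

Note that the tank index here starts at 0, whereas the tank column created in R starts at 1.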

The next step is to fit the actual model[2]:

[2] Note that the following corresponds to R code 13.2 in Statistical Rethinking.

  • R
  • Python
m1 <- brm(
  formula,
  family = binomial,
  data = df,
  prior = prior(normal(0, 1.5), class = b),
  cores = parallel::detectCores(),
  file = here("posts", "pymc-hierarchical", "m1.rds") # prevents re-running the fitting process when quarto compiles
)
with model_1: 
    model_1_trace = pm.sample(random_seed=51, progressbar=False, nuts_sampler="numpyro")
    model_1_trace = pm.compute_log_likelihood(model_1_trace)
    # posterior predictive check 
    pm.sample_posterior_predictive(model_1_trace, extend_inferencedata=True, random_seed=51)
arviz.InferenceData
    • <xarray.Dataset> Size: 3MB
      Dimensions:           (chain: 4, draw: 1000, alpha_dim_0: 48,
                             p_survived_dim_0: 48)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * alpha_dim_0       (alpha_dim_0) int64 384B 0 1 2 3 4 5 ... 42 43 44 45 46 47
        * p_survived_dim_0  (p_survived_dim_0) int64 384B 0 1 2 3 4 ... 43 44 45 46 47
      Data variables:
          alpha             (chain, draw, alpha_dim_0) float64 2MB 3.214 ... -0.4144
          p_survived        (chain, draw, p_survived_dim_0) float64 2MB 0.9614 ... ...
      Attributes:
          created_at:                 2026-04-08T16:14:08.343292+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              0.679714
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:         (chain: 4, draw: 1000, survived_dim_0: 48)
      Coordinates:
        * chain           (chain) int64 32B 0 1 2 3
        * draw            (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (chain, draw, survived_dim_0) int64 2MB 10 10 8 ... 9 32 12
      Attributes:
          created_at:                 2026-04-08T16:14:11.511568+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:         (chain: 4, draw: 1000, survived_dim_0: 48)
      Coordinates:
        * chain           (chain) int64 32B 0 1 2 3
        * draw            (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (chain, draw, survived_dim_0) float64 2MB -1.306 ... -2.563
      Attributes:
          created_at:                 2026-04-08T16:14:11.461503+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 1.0 0.6541 ... 0.99 0.6298
          step_size        (chain, draw) float64 32kB 0.5056 0.5056 ... 0.4744 0.4744
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 211.7 221.5 ... 203.2 207.6
          n_steps          (chain, draw) int64 32kB 7 7 7 7 7 7 7 7 ... 7 7 7 7 7 7 7
          tree_depth       (chain, draw) int64 32kB 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 3 3
          lp               (chain, draw) float64 32kB 186.7 197.9 ... 182.4 189.0
      Attributes:
          created_at:     2026-04-08T16:14:08.345370+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 768B
      Dimensions:         (survived_dim_0: 48)
      Coordinates:
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (survived_dim_0) int64 384B 9 10 7 10 9 9 ... 14 22 12 31 17
      Attributes:
          created_at:                 2026-04-08T16:14:08.345822+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              0.679714
          tuning_steps:               1000

We can check the fitted model’s summary statistics:

  • R
  • Python
summary(m1)
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 0 + factor(tank) 
   Data: df (Number of observations: 48) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
factortank1      1.71      0.78     0.30     3.38 1.00     4841     2773
factortank2      2.41      0.91     0.81     4.38 1.00     5466     2552
factortank3      0.75      0.64    -0.45     2.01 1.00     6429     2893
factortank4      2.42      0.90     0.80     4.37 1.00     5534     2801
factortank5      1.73      0.76     0.40     3.35 1.00     6133     2798
factortank6      1.72      0.77     0.36     3.34 1.00     5855     2831
factortank7      2.42      0.91     0.84     4.34 1.00     4898     2787
factortank8      1.71      0.79     0.26     3.43 1.00     6583     2902
factortank9     -0.36      0.62    -1.61     0.88 1.00     5810     3174
factortank10     1.70      0.76     0.31     3.31 1.00     5849     3072
factortank11     0.75      0.62    -0.45     2.01 1.00     5113     2680
factortank12     0.37      0.60    -0.82     1.57 1.00     5141     2880
factortank13     0.75      0.64    -0.49     2.07 1.00     5751     3041
factortank14     0.00      0.60    -1.19     1.22 1.00     5740     2880
factortank15     1.71      0.78     0.28     3.35 1.00     5165     2245
factortank16     1.72      0.77     0.31     3.30 1.00     6319     2450
factortank17     2.56      0.69     1.36     4.00 1.00     5755     3144
factortank18     2.13      0.59     1.09     3.40 1.00     5396     2874
factortank19     1.80      0.55     0.80     2.95 1.00     5799     2588
factortank20     3.12      0.81     1.70     4.86 1.00     5290     2642
factortank21     2.12      0.61     1.04     3.44 1.00     5753     2748
factortank22     2.13      0.60     1.04     3.41 1.00     4926     2737
factortank23     2.13      0.59     1.11     3.40 1.00     5311     2911
factortank24     1.54      0.50     0.62     2.58 1.00     6062     2666
factortank25    -1.10      0.44    -1.98    -0.28 1.00     5571     3009
factortank26     0.08      0.39    -0.69     0.84 1.00     5045     2869
factortank27    -1.54      0.50    -2.56    -0.60 1.00     5488     2826
factortank28    -0.55      0.40    -1.36     0.21 1.00     6429     2979
factortank29     0.08      0.40    -0.69     0.85 1.00     6081     2796
factortank30     1.31      0.47     0.44     2.29 1.00     5937     2799
factortank31    -0.73      0.42    -1.54     0.06 1.00     5521     2995
factortank32    -0.39      0.40    -1.19     0.39 1.00     6622     3221
factortank33     2.83      0.65     1.70     4.20 1.00     5594     2751
factortank34     2.46      0.59     1.36     3.72 1.00     6757     2890
factortank35     2.46      0.57     1.44     3.70 1.00     5713     2964
factortank36     1.91      0.50     0.98     2.93 1.00     6101     2918
factortank37     1.91      0.48     1.04     2.90 1.00     5749     3220
factortank38     3.37      0.80     1.98     5.09 1.00     5270     2635
factortank39     2.47      0.58     1.45     3.72 1.00     6428     2991
factortank40     2.17      0.53     1.23     3.28 1.00     4563     2679
factortank41    -1.91      0.49    -2.93    -1.02 1.00     6737     2697
factortank42    -0.63      0.36    -1.35     0.06 1.00     6029     2859
factortank43    -0.51      0.35    -1.21     0.16 1.00     6284     2828
factortank44    -0.39      0.33    -1.04     0.25 1.00     6155     3038
factortank45     0.51      0.35    -0.16     1.23 1.00     5068     2723
factortank46    -0.63      0.35    -1.33     0.05 1.00     6318     2773
factortank47     1.92      0.50     1.02     2.97 1.00     6032     2925
factortank48    -0.06      0.34    -0.72     0.58 1.00     6955     2825

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
az.summary(model_1_trace)
                 mean     sd  hdi_3%  ...  ess_bulk  ess_tail  r_hat
alpha[0]        1.695  0.759   0.349  ...    5499.0    2681.0    1.0
alpha[1]        2.410  0.903   0.767  ...    5365.0    2646.0    1.0
alpha[2]        0.760  0.638  -0.446  ...    5838.0    2914.0    1.0
alpha[3]        2.410  0.882   0.865  ...    4956.0    2677.0    1.0
alpha[4]        1.704  0.775   0.293  ...    5890.0    2768.0    1.0
...               ...    ...     ...  ...       ...       ...    ...
p_survived[43]  0.404  0.079   0.258  ...    5575.0    3215.0    1.0
p_survived[44]  0.623  0.077   0.473  ...    5024.0    2826.0    1.0
p_survived[45]  0.350  0.075   0.214  ...    5269.0    2997.0    1.0
p_survived[46]  0.861  0.056   0.761  ...    5321.0    2415.0    1.0
p_survived[47]  0.487  0.080   0.335  ...    5712.0    2993.0    1.0

[96 rows x 9 columns]
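Both summaries report Rhat as their convergence diagnostic. To see what it measures, here is a minimal numpy sketch of the classic split-Rhat, assuming draws arranged as (chain, draw) like the posterior group above. It splits each chain in half and compares between-chain variance with within-chain variance. Both brms and ArviZ actually report a rank-normalized refinement of split-Rhat, so this simplified version (and the helper name split_rhat, which is my own) is for intuition only.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic (non-rank-normalized) split-Rhat for one parameter.

    draws: array of shape (n_chains, n_draws).
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into two halves -> 2 * n_chains sub-chains
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    # W: average within-chain variance; B: between-chain variance (scaled by n)
    W = chains.var(axis=1, ddof=1).mean()
    B = n * chain_means.var(ddof=1)
    # Pooled estimate of the marginal posterior variance
    var_hat = (n - 1) / n * W + B / n
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(51)
well_mixed = rng.normal(size=(4, 1000))                      # all 4 chains agree
stuck = well_mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # chain 4 is offset
print(round(split_rhat(well_mixed), 3))
print(round(split_rhat(stuck), 3))
```

The well-mixed chains give a value very close to 1, while the offset chain pushes Rhat well above the usual 1.01 warning threshold.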

We can also plot the posterior distributions alongside the trace plots:

  • R
  • Python
plot(m1, variable = variables(m1)[1:5])

# Keep only the first 5 tank intercepts, mirroring variables(m1)[1:5] in R
az.plot_trace(model_1_trace, combined=True, var_names="alpha", coords={"alpha_dim_0": list(range(5))})
array([[<Axes: title={'center': 'alpha'}>,
        <Axes: title={'center': 'alpha'}>]], dtype=object)

Or a forest plot:

  • R
  • Python
m1 %>% as_draws_df(.) %>% mcmc_intervals(regex_pars = "b_.*")

az.plot_forest(model_1_trace, kind='forestplot', combined=True, var_names="alpha", hdi_prob=0.95)
array([<Axes: title={'center': '95.0% HDI'}>], dtype=object)

Note that mcmc_intervals plots the “central (quantile-based) posterior interval estimates from MCMC draws”, while plot_forest displays credible intervals where “the central points are the estimated posterior median, the thick lines are the central quantiles, and the thin lines represent the (100 × hdi_prob)% highest density interval”. The former shows equal-tailed intervals (ETI), the latter highest density intervals (HDI). The widths differ too: by default mcmc_intervals draws a 50% inner and a 90% outer interval, whereas the call above asks for a 95% HDI (similarly, az.summary reports 94% HDIs by default, while summary(m1) reports 95% ETIs). The difference between ETI and HDI is negligible for symmetric, unimodal distributions (e.g., a Gaussian), but it can be large when the posterior is skewed. In practice, for example when reporting posterior distributions in academic journals, ETIs tend to be the more common choice, largely because they are conceptually easier to explain.
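To make the ETI/HDI difference concrete, here is a minimal numpy sketch. The eti() helper simply takes quantiles; hdi() is my own simplified version of the usual "shortest interval over the sorted draws" approach and, like az.hdi by default, assumes a unimodal distribution. The exponential draws are an arbitrary stand-in for a skewed posterior, not output from either model.

```python
import numpy as np


def eti(draws: np.ndarray, prob: float = 0.95) -> tuple[float, float]:
    """Equal-tailed interval: cut (1 - prob) / 2 of the mass from each tail."""
    lo, hi = np.quantile(draws, [(1 - prob) / 2, (1 + prob) / 2])
    return float(lo), float(hi)


def hdi(draws: np.ndarray, prob: float = 0.95) -> tuple[float, float]:
    """Shortest interval containing `prob` of the sorted draws (unimodal case)."""
    x = np.sort(draws)
    n_in = int(np.floor(prob * len(x)))
    # Width of every candidate interval spanning n_in consecutive sorted draws
    widths = x[n_in:] - x[: len(x) - n_in]
    start = int(np.argmin(widths))
    return float(x[start]), float(x[start + n_in])


rng = np.random.default_rng(1)
symmetric = rng.normal(0, 1, size=100_000)
skewed = rng.exponential(1, size=100_000)  # strongly right-skewed

for name, d in [("symmetric", symmetric), ("skewed", skewed)]:
    print(name, "ETI:", np.round(eti(d), 3), "HDI:", np.round(hdi(d), 3))
```

For the normal draws the two intervals nearly coincide; for the exponential draws the HDI starts near zero and is noticeably narrower than the ETI, which has to leave 2.5% of the mass below its lower bound.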

We can also perform a posterior predictive check:

  • R
  • Python
pp_check(m1, ndraws = 100)
pp_check(m1, ndraws = 100, type = "ecdf_overlay")

# az.plot_ppc(model_1_trace, num_pp_samples=100)   
azp.plot_ppc_dist(model_1_trace, group="posterior_predictive", num_samples=100)
<arviz_plots.plot_collection.PlotCollection object at 0x43c8c3890>
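Under the hood, a posterior predictive check draws one replicated dataset per posterior draw and compares it with the observed data. Below is a minimal numpy sketch of that idea. The observed counts are the survived values from the observed_data group above; the tank sizes (16 tanks each of 10, 25, and 35 tadpoles) are my assumption about the reedfrogs layout, consistent with those counts; and a Beta(surv + 1, density − surv + 1) stand-in (the conjugate posterior of an unpooled model with uniform priors) replaces the actual model_1 draws, so the numbers are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(51)

# Observed survivors per tank (the observed_data group above)
observed = np.array([
    9, 10, 7, 10, 9, 9, 10, 9, 4, 9, 7, 6, 7, 5, 9, 9,
    24, 23, 22, 25, 23, 23, 23, 21, 6, 13, 4, 9, 13, 20, 8, 10,
    34, 33, 33, 31, 31, 35, 33, 32, 4, 12, 13, 14, 22, 12, 31, 17,
])
# Assumption: tank sizes follow the reedfrogs layout (16 tanks each of 10, 25, 35)
density = np.repeat([10, 25, 35], 16)

# Stand-in for posterior draws of p (NOT the model_1 posterior): the conjugate
# Beta(surv + 1, N - surv + 1) posterior of an unpooled model with uniform priors
n_draws = 4000
p_draws = rng.beta(observed + 1, density - observed + 1, size=(n_draws, 48))

# 1. One replicated dataset per posterior draw: y_rep[s, i] ~ Binomial(N_i, p[s, i])
y_rep = rng.binomial(density, p_draws)  # shape (n_draws, 48)

# 2. Compare each observed count with its predictive distribution,
#    here via a central 90% predictive interval per tank
lo, hi = np.quantile(y_rep, [0.05, 0.95], axis=0)
covered = (observed >= lo) & (observed <= hi)
print(f"{covered.mean():.0%} of tanks fall inside their 90% predictive interval")
```

pp_check() and plot_ppc_dist() visualize the same y_rep-versus-observed comparison, but as overlaid distributions rather than per-tank intervals.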

Note on plotting

We can also go beyond the default bayesplot plots and build our own with ggplot2 and tidybayes. To do that, we first need the posterior draws in a tidy format. Observe that as_tibble() returns one column per parameter but does not tell us which chain and draw each row comes from:

m1 %>% as_tibble(.)
# A tibble: 4,000 × 50
   b_factortank1 b_factortank2 b_factortank3 b_factortank4 b_factortank5
           <dbl>         <dbl>         <dbl>         <dbl>         <dbl>
 1          1.20          1.80        1.77            2.66         2.55 
 2          1.94          2.68       -0.273           1.79         1.38 
 3          1.93          1.97        2.31            2.62         2.15 
 4          2.66          1.71        1.21            1.04         2.24 
 5          2.12          4.06        0.738           3.80         1.15 
 6          1.36          2.74       -0.0728          4.15         1.33 
 7          1.20          3.36       -0.475           1.50         2.59 
 8          2.94          2.08        1.78            1.54         0.891
 9          1.34          2.26       -0.556           3.60         1.69 
10          2.05          3.51        2.08            1.10         1.76 
# ℹ 3,990 more rows
# ℹ 45 more variables: b_factortank6 <dbl>, b_factortank7 <dbl>,
#   b_factortank8 <dbl>, b_factortank9 <dbl>, b_factortank10 <dbl>,
#   b_factortank11 <dbl>, b_factortank12 <dbl>, b_factortank13 <dbl>,
#   b_factortank14 <dbl>, b_factortank15 <dbl>, b_factortank16 <dbl>,
#   b_factortank17 <dbl>, b_factortank18 <dbl>, b_factortank19 <dbl>,
#   b_factortank20 <dbl>, b_factortank21 <dbl>, b_factortank22 <dbl>, …

spread_draws, on the other hand, keeps this information in the .chain, .iteration, and .draw columns:

m1 %>%
  spread_draws(`b_factortank.*`, regex = TRUE) %>%
  # keep the three index columns and the first 5 tanks for now
  select(1:8) %>%
  pivot_longer(4:8) %>%
  ggplot(aes(y = name, x = value)) +
  stat_halfeye()
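On the Python side, the same long ("tidy") layout can be built from any (chain, draw, parameter) array. Here is a minimal numpy-only sketch, assuming draws shaped like model_1_trace.posterior["alpha"] (4 chains × 1000 draws × 48 tanks); random numbers stand in for the real posterior so the block runs on its own, and the column names are my own choice.

```python
import numpy as np

rng = np.random.default_rng(51)
n_chains, n_draws, n_tanks = 4, 1000, 48

# Stand-in for model_1_trace.posterior["alpha"].values (random numbers, not real draws)
alpha = rng.normal(1.0, 1.0, size=(n_chains, n_draws, n_tanks))

# Keep the first 5 tanks, like select(1:8) + pivot_longer(4:8) in the R code
k = 5
subset = alpha[:, :, :k]

# Index grids with the same shape as `subset`, one per dimension
chain, draw, tank = np.meshgrid(
    np.arange(n_chains), np.arange(n_draws), np.arange(k), indexing="ij"
)

# Flatten into parallel columns: one row per (chain, draw, tank) combination
names = np.array([f"alpha[{j}]" for j in range(k)])
tidy = {
    "chain": chain.ravel(),
    "draw": draw.ravel(),
    "name": names[tank.ravel()],
    "value": subset.ravel(),
}
print(len(tidy["value"]))  # 4 * 1000 * 5 = 20000 rows
```

If pandas is available, pd.DataFrame(tidy) turns this into a data frame suitable for plotnine; on the real InferenceData, xarray's DataArray.to_dataframe() produces a comparable long table indexed by chain, draw, and alpha_dim_0.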

With Pooling

Math model: \[ \begin{align} \texttt{surv}_i &\sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) &= \alpha_{\text{TANK}[i]} \\ \alpha_j &\sim \mathcal{N}(\bar{\alpha}, \sigma), \quad j = 1, \ldots, 48 \\ \bar{\alpha} &\sim \mathcal{N}(0, 1.5) \\ \sigma &\sim \text{Exponential}(1) \end{align} \]
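The model above is written in centered form, which is also how the PyMC code below specifies it. brms instead generates Stan code in the non-centered form, \(\alpha_j = \bar{\alpha} + \sigma z_j\) with \(z_j \sim \mathcal{N}(0, 1)\), which defines the same distribution but often samples more easily when \(\sigma\) is small. A quick Monte Carlo sketch with numpy (drawing from the priors above, not from the fitted model, and not meant as a prior predictive check) shows that the two forms imply the same prior for a tank's intercept:

```python
import numpy as np

rng = np.random.default_rng(13)
n = 200_000

# Hyperpriors from the model above
alpha_bar = rng.normal(0.0, 1.5, size=n)
sigma = rng.exponential(1.0, size=n)

# Centered: alpha_j ~ Normal(alpha_bar, sigma)
alpha_centered = rng.normal(alpha_bar, sigma)

# Non-centered: z_j ~ Normal(0, 1), alpha_j = alpha_bar + sigma * z_j
z = rng.normal(0.0, 1.0, size=n)
alpha_noncentered = alpha_bar + sigma * z

# Same implied prior for one tank's intercept (up to Monte Carlo error)
qs = [0.05, 0.25, 0.5, 0.75, 0.95]
print(np.round(np.quantile(alpha_centered, qs), 2))
print(np.round(np.quantile(alpha_noncentered, qs), 2))
```

Both samples should have variance close to \(1.5^2 + \mathbb{E}[\sigma^2] = 2.25 + 2 = 4.25\), since for an Exponential(1) prior \(\mathbb{E}[\sigma^2] = \text{Var}(\sigma) + \mathbb{E}[\sigma]^2 = 1 + 1\).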

Here we skip the prior predictive check and go straight to building the model³:

3 Note that the following corresponds to R code 13.3 in Statistical Rethinking

  • R
  • Python
m2 <- brm(
  data = df,
  family = binomial,
  formula = bf(surv | trials(density) ~ 1 + (1 | tank)),
  prior = c(
    prior(normal(0, 1.5), class = Intercept),
    prior(exponential(1), class = sd)
  ),
  cores = parallel::detectCores(),
  seed = 51,
  file = here("posts", "pymc-hierarchical", "m2.rds")
)

To see which default priors brms would have used had we not specified our own, we can call get_prior():

get_prior(
  bf(surv | trials(density) ~ 1 + (1 | tank)),
  family = binomial,
  data = df
)
                prior     class      coef group resp dpar nlpar lb ub tag
 student_t(3, 0, 2.5) Intercept                                          
 student_t(3, 0, 2.5)        sd                                  0       
 student_t(3, 0, 2.5)        sd            tank                  0       
 student_t(3, 0, 2.5)        sd Intercept  tank                  0       
       source
      default
      default
 (vectorized)
 (vectorized)
with pm.Model() as model_2:
    # hyperprior 
    sigma = pm.Exponential("sigma", 1)
    alpha_bar = pm.Normal("alpha_bar", 0, 1.5)
    # prior 
    alpha = pm.Normal("alpha", mu=alpha_bar, sigma=sigma, shape=len(df))
    # link 
    p_survived = pm.Deterministic("p_survived", pm.math.invlogit(alpha))
    # likelihood 
    survived = pm.Binomial("survived", n=df.density, p=p_survived, observed=df.surv)

    model_2_trace = pm.sample(progressbar=False, nuts_sampler="numpyro")
    model_2_trace = pm.compute_log_likelihood(model_2_trace)
    # posterior predictive check 
    pm.sample_posterior_predictive(model_2_trace, extend_inferencedata=True)
arviz.InferenceData
    • <xarray.Dataset> Size: 3MB
      Dimensions:           (chain: 4, draw: 1000, alpha_dim_0: 48,
                             p_survived_dim_0: 48)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * alpha_dim_0       (alpha_dim_0) int64 384B 0 1 2 3 4 5 ... 42 43 44 45 46 47
        * p_survived_dim_0  (p_survived_dim_0) int64 384B 0 1 2 3 4 ... 43 44 45 46 47
      Data variables:
          alpha_bar         (chain, draw) float64 32kB 1.011 1.56 ... 1.401 1.614
          alpha             (chain, draw, alpha_dim_0) float64 2MB 1.167 ... -0.03433
          sigma             (chain, draw) float64 32kB 1.6 1.72 1.562 ... 1.575 1.562
          p_survived        (chain, draw, p_survived_dim_0) float64 2MB 0.7626 ... ...
      Attributes:
          created_at:                 2026-04-08T16:14:16.101184+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              0.742881
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:         (chain: 4, draw: 1000, survived_dim_0: 48)
      Coordinates:
        * chain           (chain) int64 32B 0 1 2 3
        * draw            (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (chain, draw, survived_dim_0) int64 2MB 7 9 7 9 ... 5 32 14
      Attributes:
          created_at:                 2026-04-08T16:14:19.310521+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:         (chain: 4, draw: 1000, survived_dim_0: 48)
      Coordinates:
        * chain           (chain) int64 32B 0 1 2 3
        * draw            (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (chain, draw, survived_dim_0) float64 2MB -1.575 ... -2.012
      Attributes:
          created_at:                 2026-04-08T16:14:19.252184+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.9992 0.791 ... 0.9348 0.8228
          step_size        (chain, draw) float64 32kB 0.3962 0.3962 ... 0.4121 0.4121
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 202.7 209.1 ... 212.4 227.9
          n_steps          (chain, draw) int64 32kB 15 15 15 7 15 7 ... 7 15 7 15 7 31
          tree_depth       (chain, draw) int64 32kB 4 4 4 3 4 3 3 3 ... 3 3 4 3 4 3 5
          lp               (chain, draw) float64 32kB 178.4 179.8 ... 190.5 188.6
      Attributes:
          created_at:     2026-04-08T16:14:16.103610+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 768B
      Dimensions:         (survived_dim_0: 48)
      Coordinates:
        * survived_dim_0  (survived_dim_0) int64 384B 0 1 2 3 4 5 ... 43 44 45 46 47
      Data variables:
          survived        (survived_dim_0) int64 384B 9 10 7 10 9 9 ... 14 22 12 31 17
      Attributes:
          created_at:                 2026-04-08T16:14:16.104093+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              0.742881
          tuning_steps:               1000
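The `InferenceData` above includes a pointwise log-likelihood array (`survived`, 4 chains × 1000 draws × 48 tanks), and that matrix is what `loo` in R and `az.loo`/`az.compare` in Python work from. PSIS-LOO itself adds Pareto smoothing, so as a simpler illustration of how such a matrix turns into an estimate of out-of-sample fit, here is a minimal NumPy sketch of WAIC, a related but different criterion. The helper name `waic_from_loglik` is hypothetical and not part of ArviZ.

```python
import numpy as np


def waic_from_loglik(log_lik):
    """Hypothetical helper: WAIC (not PSIS-LOO) from a pointwise log-likelihood array.

    log_lik has shape (chain, draw, obs) or (samples, obs). Returns elpd_waic
    (log scale, higher is better), the effective number of parameters p_waic,
    and the standard error of elpd_waic.
    """
    ll = np.asarray(log_lik, dtype=float)
    ll = ll.reshape(-1, ll.shape[-1])  # pool chains and draws -> (S, N)

    # log pointwise predictive density: log of the per-observation mean likelihood
    # over draws, shifted by the max for numerical stability (a manual logsumexp)
    m = ll.max(axis=0)
    lppd_i = m + np.log(np.mean(np.exp(ll - m), axis=0))

    # penalty: posterior variance of each observation's log-likelihood
    p_waic_i = ll.var(axis=0, ddof=1)

    elpd_i = lppd_i - p_waic_i
    se = np.sqrt(elpd_i.size * elpd_i.var(ddof=1))
    return {
        "elpd_waic": float(elpd_i.sum()),
        "p_waic": float(p_waic_i.sum()),
        "se": float(se),
    }
```

Assuming `model_2_trace` carries this log-likelihood group (which `az.compare` below requires anyway), a call such as `waic_from_loglik(model_2_trace.log_likelihood["survived"].values)` can be sanity-checked against `az.waic(model_2_trace)`. Given the high Pareto k values that `loo` reports for these models, WAIC and LOO may disagree somewhat here.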

We can check the summary as well as the traces:

  • R
  • Python
summary(m2)
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 1 + (1 | tank) 
   Data: df (Number of observations: 48) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~tank (Number of levels: 48) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     1.61      0.21     1.25     2.08 1.00      901     1890

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     1.34      0.26     0.86     1.86 1.00      652     1361

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
plot(m2)

az.summary(model_2_trace)
                 mean     sd  hdi_3%  ...  ess_bulk  ess_tail  r_hat
alpha_bar       1.348  0.257   0.865  ...    3778.0    3160.0    1.0
alpha[0]        2.134  0.868   0.539  ...    5265.0    2732.0    1.0
alpha[1]        3.048  1.105   1.111  ...    4124.0    2419.0    1.0
alpha[2]        0.995  0.658  -0.175  ...    5426.0    2658.0    1.0
alpha[3]        3.078  1.113   1.077  ...    4642.0    2827.0    1.0
...               ...    ...     ...  ...       ...       ...    ...
p_survived[43]  0.420  0.082   0.274  ...    5161.0    2791.0    1.0
p_survived[44]  0.638  0.078   0.483  ...    5089.0    2864.0    1.0
p_survived[45]  0.366  0.078   0.222  ...    5215.0    2933.0    1.0
p_survived[46]  0.877  0.053   0.778  ...    5128.0    2584.0    1.0
p_survived[47]  0.501  0.082   0.351  ...    4374.0    2872.0    1.0

[98 rows x 9 columns]
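Both summaries report the population-level intercept on the logit scale (`Intercept` = 1.34 in brms, `alpha_bar` = 1.348 in PyMC). Applying the inverse logit puts it back on the probability scale. Because the inverse logit is monotone, the endpoints of an equal-tailed interval such as brms's `l-95% CI`/`u-95% CI` can be transformed directly. A posterior mean cannot: the mean of `inv_logit(draws)` is generally not `inv_logit(mean)`, so the first number below is best read as "survival probability at the mean log-odds". A minimal sketch using the brms numbers:

```python
import numpy as np


def inv_logit(x):
    """Inverse logit (logistic sigmoid): maps log-odds to probabilities."""
    return 1.0 / (1.0 + np.exp(-np.asarray(x, dtype=float)))


# Intercept estimate and 95% interval copied from summary(m2) above (logit scale)
intercept = {"estimate": 1.34, "lower": 0.86, "upper": 1.86}

# Transform each value to the probability scale
prob = {k: float(inv_logit(v)) for k, v in intercept.items()}
print({k: round(v, 2) for k, v in prob.items()})
# {'estimate': 0.79, 'lower': 0.7, 'upper': 0.87}
```

So a typical tank has a survival probability of roughly 0.7 to 0.87. For PyMC's `hdi_3%`/`hdi_97%` columns this shortcut is only approximate, because a highest-density interval is not preserved under nonlinear transformations; transforming the draws first and then summarizing avoids the issue.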

Model comparison

Now that we have two models for the tadpole dataset, we can compare them using …

  • R
  • Python

… the loo package in R:

loo(m1, m2)
Output of model 'm1':

Computed from 4000 by 48 log-likelihood matrix.

         Estimate  SE
elpd_loo   -121.5 2.4
p_loo        39.8 2.0
looic       243.0 4.9
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.4, 1.5]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)      7    14.6%   186     
   (0.7, 1]   (bad)      34    70.8%   <NA>    
   (1, Inf)   (very bad)  7    14.6%   <NA>    
See help('pareto-k-diagnostic') for details.

Output of model 'm2':

Computed from 4000 by 48 log-likelihood matrix.

         Estimate  SE
elpd_loo   -110.6 4.1
p_loo        31.5 1.5
looic       221.2 8.3
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume MCMC draws (r_eff in [0.4, 1.5]).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)      6    12.5%   197     
   (0.7, 1]   (bad)      39    81.2%   <NA>    
   (1, Inf)   (very bad)  3     6.2%   <NA>    
See help('pareto-k-diagnostic') for details.

Model comparisons:
   elpd_diff se_diff
m2   0.0       0.0  
m1 -10.9       2.9  
az.compare({'unpooled': model_1_trace, 'pooled': model_2_trace})
          rank    elpd_loo      p_loo  ...       dse  warning  scale
pooled       0 -111.924512  32.775933  ...  0.000000     True    log
unpooled     1 -120.078198  38.335198  ...  2.897545     True    log

[2 rows x 9 columns]
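Both tools report ELPD on the log scale (higher is better), and the `looic` column in the `loo` output is simply −2 × `elpd_loo`. A common rule of thumb is to compare the elpd difference with its standard error. The sketch below does that arithmetic with values copied from the two outputs above. Note that both tools flag these estimates as unreliable (many "bad" Pareto k values in `loo`, `warning = True` in `az.compare`), so the numbers are best treated as approximate.

```python
# Values copied from the loo() and az.compare() outputs above
r_elpd = {"m1": -121.5, "m2": -110.6}  # m1 = no pooling, m2 = pooling
r_se_diff = 2.9
py_elpd = {"unpooled": -120.078198, "pooled": -111.924512}
py_dse = 2.897545

# loo's LOOIC is elpd_loo on the deviance scale: looic = -2 * elpd_loo (lower is better)
looic = {name: -2 * elpd for name, elpd in r_elpd.items()}

# elpd difference (pooled minus unpooled; positive favors pooling)
r_diff = r_elpd["m2"] - r_elpd["m1"]
py_diff = py_elpd["pooled"] - py_elpd["unpooled"]

# rough rule of thumb: how many standard errors of the difference is the gap?
print("LOOIC:", {k: round(v, 1) for k, v in looic.items()})
print(f"R:      elpd diff = {r_diff:.1f}, diff / se = {r_diff / r_se_diff:.2f}")
print(f"Python: elpd diff = {py_diff:.1f}, diff / se = {py_diff / py_dse:.2f}")
```

The pooled model comes out ahead in both ecosystems, by about 3.8 standard errors in R and about 2.8 in Python. The gap between the two implementations is plausibly due to Monte Carlo variation combined with the unstable importance weights that the Pareto k diagnostics warn about.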

Multilevel Chimpanzees

Next, we turn to the chimpanzee dataset, a somewhat more complex example that shows off the PyMC syntax. The chimpanzee data is obtained from here, and the example follows section 13.3.1 (Multilevel chimpanzees) of Statistical Rethinking. We first create a new column, treatment, that combines prosoc_left and condition.

Here’s the math model: \[ \begin{align} \texttt{pulled\_left}_i &\sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) &= \alpha_{\texttt{actor}[i]} + \gamma_{\texttt{block}[i]} + \beta_{\texttt{treatment}[i]} \\ \beta_j &\sim \mathcal{N}(0, 0.5), j \in [4] \\ \alpha_j &\sim \mathcal{N}(\bar{\alpha}, \sigma_\alpha), j \in [7] \\ \gamma_j &\sim \mathcal{N}(0, \sigma_\gamma), j \in [6] \\ \bar{\alpha} &\sim \mathcal{N}(0, 1.5) \\ \sigma_\alpha &\sim \text{Exponential}(1) \\ \sigma_\gamma &\sim \text{Exponential}(1) \end{align} \]
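The model above is written in centered form: each \(\alpha_j\) and \(\gamma_j\) is drawn directly from a normal whose scale is itself a parameter, and the PyMC code below follows that form. The TLDR notes that brms handles this reparameterization for you while PyMC leaves it to the user. A standard non-centered rewrite of the two varying effects, which implies the same prior on \(\alpha\) and \(\gamma\), replaces them with standardized offsets:

\[ \begin{align} z^{\alpha}_j &\sim \mathcal{N}(0, 1), j \in [7] \\ z^{\gamma}_j &\sim \mathcal{N}(0, 1), j \in [6] \\ \alpha_j &= \bar{\alpha} + \sigma_\alpha z^{\alpha}_j \\ \gamma_j &= \sigma_\gamma z^{\gamma}_j \end{align} \]

The likelihood, \(\beta_j\), and hyperpriors stay unchanged. This works because if \(z \sim \mathcal{N}(0, 1)\), then \(\bar{\alpha} + \sigma_\alpha z \sim \mathcal{N}(\bar{\alpha}, \sigma_\alpha)\). In PyMC, this would amount to sampling the \(z\) terms with `pm.Normal(..., 0, 1, dims=...)` and building \(\alpha\) and \(\gamma\) with `pm.Deterministic`. It matters most when a group-level scale is small or poorly identified, as \(\sigma_\gamma\) turns out to be here, because that is the situation in which centered hierarchies tend to produce divergent transitions.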

  • R
  • Python
df <- read.csv2(
  here("posts", "pymc-hierarchical", "chimpanzees.csv"),
  sep = ";"
) %>%
  as_tibble(.) %>%
  mutate(
    treatment = paste0(as.character(condition), "_", as.character(prosoc_left))
  )

df %>% head(5)
# A tibble: 5 × 9
  actor recipient condition block trial prosoc_left chose_prosoc pulled_left
  <int>     <int>     <int> <int> <int>       <int>        <int>       <int>
1     1        NA         0     1     2           0            1           0
2     1        NA         0     1     4           0            0           1
3     1        NA         0     1     6           1            0           0
4     1        NA         0     1     8           0            1           0
5     1        NA         0     1    10           1            1           1
# ℹ 1 more variable: treatment <chr>
# read in the data 
df = pd.read_csv(here("posts/pymc-hierarchical/chimpanzees.csv"), delimiter=";")
# create a new treatment column 
df['treatment'] = df['condition'].astype(str) + "_" + df['prosoc_left'].astype(str)

df.head(5)
   actor  recipient  condition  ...  chose_prosoc  pulled_left  treatment
0      1        NaN          0  ...             1            0        0_0
1      1        NaN          0  ...             0            1        0_0
2      1        NaN          0  ...             0            0        0_1
3      1        NaN          0  ...             1            0        0_0
4      1        NaN          0  ...             1            1        0_1

[5 rows x 9 columns]
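Both versions encode `treatment` as a string label such as `"0_1"`. The Python model below turns these labels into integer indices with `pd.factorize`, which by default numbers categories in order of first appearance. Passing `sort=True` orders the labels lexicographically instead (`"0_0" < "0_1" < "1_0" < "1_1"`). For these labels, that gives the same 0-based index as the arithmetic encoding `prosoc_left + 2 * condition`. This assumes the book's convention `treatment = 1 + prosoc_left + 2 * condition`, shifted to start at 0. Here is a small check on a toy frame whose rows are made up for illustration:

```python
import numpy as np
import pandas as pd

# Toy rows (made up for illustration), deliberately not in sorted order
toy = pd.DataFrame({"condition": [1, 0, 0, 1, 0], "prosoc_left": [1, 1, 0, 0, 0]})
toy["treatment"] = toy["condition"].astype(str) + "_" + toy["prosoc_left"].astype(str)

# Default: codes follow the order in which labels first appear
codes_appearance, labels_appearance = pd.factorize(toy["treatment"])

# sort=True: codes follow the lexicographic order of the labels
codes_sorted, labels_sorted = pd.factorize(toy["treatment"], sort=True)

# Arithmetic encoding: 0-based version of 1 + prosoc_left + 2 * condition
codes_arith = (toy["prosoc_left"] + 2 * toy["condition"]).to_numpy()

print(list(labels_appearance), codes_appearance)
print(list(labels_sorted), codes_sorted)
print(np.array_equal(codes_sorted, codes_arith))
```

Either ordering is fine inside the model, as long as the coordinate labels passed to `pm.Model(coords=...)` come from the same `pd.factorize` call that produced the integer codes.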

Define and fit the model:

  • R
  • Python
m3 <- brm(
  data = df,
  family = binomial,
  bf(
    pulled_left | trials(1) ~ a + b,
    a ~ 1 + (1 | actor) + (1 | block),
    b ~ 0 + treatment,
    nl = TRUE
  ),
  prior = c(
    prior(normal(0, 0.5), nlpar = b),
    prior(normal(0, 1.5), class = b, coef = Intercept, nlpar = a),
    prior(exponential(1), class = sd, group = actor, nlpar = a),
    prior(exponential(1), class = sd, group = block, nlpar = a)
  ),
  cores = parallel::detectCores(),
  seed = 51,
  file = here("posts", "pymc-hierarchical", "m3.rds")
)

Note that we are using brms's non-linear formula syntax here (see footnote 4), and, as before, we can use get_prior to see which priors can be set:

4 See this vignette for more on fitting non-linear models with brms.

get_prior(
  bf(
    pulled_left | trials(1) ~ a + b,
    a ~ 1 + (1 | actor) + (1 | block),
    b ~ 0 + treatment,
    nl = TRUE
  ),
  family = binomial,
  data = df
)
                prior class         coef group resp dpar nlpar lb ub tag
               (flat)     b                                  a          
               (flat)     b    Intercept                     a          
 student_t(3, 0, 2.5)    sd                                  a  0       
 student_t(3, 0, 2.5)    sd              actor               a  0       
 student_t(3, 0, 2.5)    sd    Intercept actor               a  0       
 student_t(3, 0, 2.5)    sd              block               a  0       
 student_t(3, 0, 2.5)    sd    Intercept block               a  0       
               (flat)     b                                  b          
               (flat)     b treatment0_0                     b          
               (flat)     b treatment0_1                     b          
               (flat)     b treatment1_0                     b          
               (flat)     b treatment1_1                     b          
       source
      default
 (vectorized)
      default
 (vectorized)
 (vectorized)
 (vectorized)
 (vectorized)
      default
 (vectorized)
 (vectorized)
 (vectorized)
 (vectorized)
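# integer index per row (0-based) plus the matching labels, in order of first appearance
actor_idx, actor_codes = pd.factorize(df['actor'])
block_idx, block_codes = pd.factorize(df['block'])
treat_idx, treat_codes = pd.factorize(df['treatment'])
# build coords from the factorize labels so each label lines up with its index by construction
coords = {
    'actor': actor_codes,
    'block': block_codes,
    'treatment': treat_codes,
}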
# define and fit model 
with pm.Model(coords=coords) as chimp_model:
    # hyperpriors 
    a_bar = pm.Normal("a_bar", 0, 1.5)
    sigma_a = pm.Exponential("sigma_a", 1)
    sigma_gamma = pm.Exponential("sigma_gamma", 1)
    # prior 
    alpha = pm.Normal("alpha", mu=a_bar, sigma=sigma_a, dims="actor")
    gamma = pm.Normal("gamma", mu=0, sigma=sigma_gamma, dims="block")
    beta = pm.Normal("beta", mu=0, sigma=0.5, dims="treatment")
    # link 
    logits = alpha[actor_idx] + gamma[block_idx] + beta[treat_idx] 
    # likelihood 
    obs = pm.Bernoulli("obs", logit_p=logits, observed=df['pulled_left'])
with chimp_model: 
    chimp_model_trace = pm.sample(draws=1000, tune=1000, progressbar=False, nuts_sampler="numpyro")
    chimp_model_trace = pm.compute_log_likelihood(chimp_model_trace, progressbar=False)
    # posterior predictive check 
    pm.sample_posterior_predictive(chimp_model_trace, extend_inferencedata=True, progressbar=False)
arviz.InferenceData
    • <xarray.Dataset> Size: 648kB
      Dimensions:      (chain: 4, draw: 1000, actor: 7, block: 6, treatment: 4)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * actor        (actor) int64 56B 1 2 3 4 5 6 7
        * block        (block) int64 48B 1 2 3 4 5 6
        * treatment    (treatment) <U3 48B '0_0' '0_1' '1_0' '1_1'
      Data variables:
          a_bar        (chain, draw) float64 32kB 1.29 -0.1633 1.699 ... 0.5741 0.7172
          alpha        (chain, draw, actor) float64 224kB -0.1703 3.876 ... 1.432
          gamma        (chain, draw, block) float64 192kB -0.2541 -0.2389 ... -0.05601
          beta         (chain, draw, treatment) float64 128kB -0.4631 ... 0.09056
          sigma_a      (chain, draw) float64 32kB 1.658 1.307 2.001 ... 1.373 1.557
          sigma_gamma  (chain, draw) float64 32kB 0.1499 0.2763 ... 0.1269 0.145
      Attributes:
          created_at:                 2026-04-08T16:14:22.581311+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              1.342947
          tuning_steps:               1000

    • <xarray.Dataset> Size: 16MB
      Dimensions:    (chain: 4, draw: 1000, obs_dim_0: 504)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (chain, draw, obs_dim_0) int64 16MB 0 0 1 0 0 1 0 ... 1 1 0 0 1 1
      Attributes:
          created_at:                 2026-04-08T16:14:45.526219+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 16MB
      Dimensions:    (chain: 4, draw: 1000, obs_dim_0: 504)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 16MB -0.3448 -1.232 ... -0.3943
      Attributes:
          created_at:                 2026-04-08T16:14:42.114081+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.9782 0.9409 ... 0.9527 0.9762
          step_size        (chain, draw) float64 32kB 0.31 0.31 0.31 ... 0.1721 0.1721
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 291.5 289.0 ... 282.3 284.4
          n_steps          (chain, draw) int64 32kB 15 7 15 7 23 31 ... 31 31 15 15 15
          tree_depth       (chain, draw) int64 32kB 4 3 4 3 5 5 3 5 ... 5 4 5 5 4 4 4
          lp               (chain, draw) float64 32kB 282.0 278.7 ... 276.2 276.4
      Attributes:
          created_at:     2026-04-08T16:14:22.584212+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 8kB
      Dimensions:    (obs_dim_0: 504)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (obs_dim_0) int64 4kB 0 1 0 0 1 1 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1
      Attributes:
          created_at:                 2026-04-08T16:14:22.584667+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              1.342947
          tuning_steps:               1000

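Sampling the model above may produce warnings about divergences. That is expected, because we have not implemented a non-centered parameterization: the varying intercepts are drawn directly from distributions whose scale depends on the hyperparameters, which gives the posterior a funnel shape that NUTS struggles to explore. The fix is to sample standardized offsets and rescale them inside the linear predictor. This non-centered version describes the same model: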

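with pm.Model(coords=coords) as chimp_model_2:
    # hyperpriors
    a_bar = pm.Normal("a_bar", 0, 1.5)              # population-mean intercept
    sigma_a = pm.Exponential("sigma_a", 1)          # spread of actor intercepts
    sigma_gamma = pm.Exponential("sigma_gamma", 1)  # spread of block effects
    # priors: standardized offsets (non-centered parameterization)
    z = pm.Normal("z", mu=0, sigma=1, dims="actor")
    x = pm.Normal("x", mu=0, sigma=1, dims="block")
    beta = pm.Normal("beta", mu=0, sigma=0.5, dims="treatment")
    # link: a_bar + z * sigma_a ~ Normal(a_bar, sigma_a) and x * sigma_gamma ~ Normal(0, sigma_gamma),
    # so the model is unchanged; only the geometry the sampler explores is different
    logits = a_bar + z[actor_idx] * sigma_a + x[block_idx] * sigma_gamma + beta[treat_idx]
    # likelihood
    obs = pm.Bernoulli("obs", logit_p=logits, observed=df["pulled_left"])

with chimp_model_2:
    # NUTS via the NumPyro backend (requires numpyro to be installed)
    chimp_model_trace = pm.sample(draws=1000, tune=1000, progressbar=False, nuts_sampler="numpyro")
    # add a log_likelihood group, which LOO/WAIC model comparison needs
    chimp_model_trace = pm.compute_log_likelihood(chimp_model_trace, progressbar=False)
    # posterior predictive check: adds a posterior_predictive group in place
    pm.sample_posterior_predictive(chimp_model_trace, extend_inferencedata=True, progressbar=False)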
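Why does rescaling standardized offsets help? The short numpy sketch below is not part of the original workflow and runs no sampler. It only draws from the priors used above (`a_bar ~ Normal(0, 1.5)`, `sigma_a ~ Exponential(1)`) to illustrate two points: the centered and non-centered forms imply the same distribution for an actor intercept, but in the centered form the spread of the intercept around `a_bar` grows with `sigma_a`. That dependence is the funnel geometry usually blamed for divergences. Variable names mirror the model above but are otherwise illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000

# Prior draws for the hyperparameters, matching the model above
a_bar = rng.normal(0.0, 1.5, size=n)
sigma_a = rng.exponential(1.0, size=n)

# Centered: draw the actor intercept directly from Normal(a_bar, sigma_a)
a_centered = rng.normal(a_bar, sigma_a)

# Non-centered: draw a standard-normal offset z, then shift and scale it
z = rng.normal(0.0, 1.0, size=n)
a_noncentered = a_bar + z * sigma_a

# 1) Both forms imply the same marginal distribution for the intercept
q_centered = np.quantile(a_centered, [0.1, 0.5, 0.9])
q_noncentered = np.quantile(a_noncentered, [0.1, 0.5, 0.9])
print(q_centered.round(2), q_noncentered.round(2))

# 2) Geometry: in the centered form, |a - a_bar| scales with sigma_a (the funnel);
#    the non-centered offset z is independent of sigma_a
log_sigma = np.log(sigma_a)
corr_centered = np.corrcoef(log_sigma, np.log(np.abs(a_centered - a_bar)))[0, 1]
corr_noncentered = np.corrcoef(log_sigma, np.log(np.abs(z)))[0, 1]
print(round(corr_centered, 2), round(corr_noncentered, 2))
```

The first correlation should come out clearly positive (about 0.76 in theory, since log σ and log|ε| are independent with variances π²/6 and π²/8), while the second should be close to zero: the sampler moves through `z`, which carries no information about `sigma_a`, instead of squeezing into the narrow neck of the funnel.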
arviz.InferenceData
    • <xarray.Dataset> Size: 648kB
      Dimensions:      (chain: 4, draw: 1000, actor: 7, block: 6, treatment: 4)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * actor        (actor) int64 56B 1 2 3 4 5 6 7
        * block        (block) int64 48B 1 2 3 4 5 6
        * treatment    (treatment) <U3 48B '0_0' '0_1' '1_0' '1_1'
      Data variables:
          a_bar        (chain, draw) float64 32kB -0.02088 -0.2046 ... 1.65 1.794
          z            (chain, draw, actor) float64 224kB -0.1753 3.071 ... 0.05749
          x            (chain, draw, block) float64 192kB -0.1712 -0.3026 ... 0.5271
          beta         (chain, draw, treatment) float64 128kB -0.1789 ... 0.1652
          sigma_a      (chain, draw) float64 32kB 1.889 1.955 1.474 ... 1.745 1.579
          sigma_gamma  (chain, draw) float64 32kB 0.09725 0.09974 ... 0.196 0.151
      Attributes:
          created_at:                 2026-04-08T16:14:47.852688+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              1.814978
          tuning_steps:               1000

    • <xarray.Dataset> Size: 16MB
      Dimensions:    (chain: 4, draw: 1000, obs_dim_0: 504)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (chain, draw, obs_dim_0) int64 16MB 0 0 1 1 0 1 1 ... 1 0 1 0 1 1
      Attributes:
          created_at:                 2026-04-08T16:15:26.748320+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 16MB
      Dimensions:    (chain: 4, draw: 1000, obs_dim_0: 504)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 16MB -0.4564 -1.004 ... -0.1414
      Attributes:
          created_at:                 2026-04-08T16:15:15.943723+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.8885 0.8705 ... 0.8175 0.8218
          step_size        (chain, draw) float64 32kB 0.1313 0.1313 ... 0.1144 0.1144
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 293.8 296.9 ... 298.6 298.5
          n_steps          (chain, draw) int64 32kB 31 31 31 15 31 ... 31 31 31 31 31
          tree_depth       (chain, draw) int64 32kB 5 5 5 4 5 4 5 5 ... 5 5 5 5 5 5 5
          lp               (chain, draw) float64 32kB 284.2 288.8 ... 289.0 289.5
      Attributes:
          created_at:     2026-04-08T16:14:47.855882+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 8kB
      Dimensions:    (obs_dim_0: 504)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 4kB 0 1 2 3 4 5 6 ... 498 499 500 501 502 503
      Data variables:
          obs        (obs_dim_0) int64 4kB 0 1 0 0 1 1 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1
      Attributes:
          created_at:                 2026-04-08T16:14:47.856385+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              1.814978
          tuning_steps:               1000

Examine the fitted model:

  • R
  • Python
summary(m3)
 Family: binomial 
  Links: mu = logit 
Formula: pulled_left | trials(1) ~ a + b 
         a ~ 1 + (1 | actor) + (1 | block)
         b ~ 0 + treatment
   Data: df (Number of observations: 504) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~actor (Number of levels: 7) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     2.03      0.65     1.10     3.65 1.00     1117     1805

~block (Number of levels: 6) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     0.20      0.17     0.01     0.63 1.00     1619     1633

Regression Coefficients:
               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
a_Intercept        0.56      0.73    -0.87     2.04 1.00      985     1332
b_treatment0_0    -0.13      0.31    -0.73     0.45 1.00     2335     2802
b_treatment0_1     0.39      0.30    -0.20     1.00 1.00     2312     2567
b_treatment1_0    -0.47      0.30    -1.06     0.11 1.00     2236     2647
b_treatment1_1     0.28      0.30    -0.31     0.86 1.00     2331     3077

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
plot(m3)
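Both summaries report an R-hat for every parameter. Below is a minimal sketch of the classic split-R-hat from Gelman et al.'s *Bayesian Data Analysis* (3rd ed.). It is for illustration only and is not what either package computes. brms (through the posterior package) and ArviZ both default to the rank-normalized split-R-hat of Vehtari et al. (2021), so this sketch will give slightly different numbers from the `Rhat` and `r_hat` columns. The idea is to split each chain in half and compare the variance between halves with the variance within them. Values near 1 mean the halves agree. The function name `split_rhat` is a hypothetical helper.

```python
import numpy as np

def split_rhat(draws):
    """Classic (non-rank-normalized) split-R-hat for one scalar parameter.

    draws: array of shape (n_chains, n_draws) of post-warmup draws.
    """
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into two halves (the middle draw is dropped if n_draws is odd)
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    w = split.var(axis=1, ddof=1).mean()        # mean within-half variance
    b = n * chain_means.var(ddof=1)             # between-half variance
    var_plus = (n - 1) / n * w + b / n          # pooled estimate of posterior variance
    return float(np.sqrt(var_plus / w))

# Simulated chains, not the fitted model: four well-mixed chains vs. one offset chain
rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])[:, None]
print(round(split_rhat(mixed), 3), round(split_rhat(stuck), 3))
```

To apply it to the fitted trace, pass a `(chain, draw)` array such as `chimp_model_trace.posterior["sigma_a"].values`.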

az.summary(chimp_model_trace)
              mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
a_bar        0.647  0.719  -0.778    1.894  ...    0.013    1026.0    1808.0    1.0
z[1]        -0.531  0.389  -1.232    0.193  ...    0.007     961.0    1730.0    1.0
z[2]         2.107  0.658   0.863    3.354  ...    0.011    1805.0    2122.0    1.0
z[3]        -0.694  0.407  -1.440    0.048  ...    0.008     958.0    1663.0    1.0
z[4]        -0.697  0.408  -1.438    0.066  ...    0.007     965.0    1770.0    1.0
z[5]        -0.532  0.384  -1.228    0.204  ...    0.007     943.0    1711.0    1.0
z[6]        -0.027  0.363  -0.759    0.608  ...    0.007    1015.0    1766.0    1.0
z[7]         0.797  0.449  -0.025    1.637  ...    0.008    1209.0    1888.0    1.0
x[1]        -0.661  0.901  -2.397    1.013  ...    0.017    3491.0    2390.0    1.0
x[2]         0.155  0.856  -1.486    1.801  ...    0.015    4463.0    2970.0    1.0
x[3]         0.205  0.861  -1.421    1.851  ...    0.014    3690.0    2900.0    1.0
x[4]         0.030  0.830  -1.537    1.583  ...    0.014    3834.0    2445.0    1.0
x[5]        -0.135  0.859  -1.641    1.630  ...    0.015    4312.0    2921.0    1.0
x[6]         0.432  0.873  -1.362    1.958  ...    0.014    3104.0    2870.0    1.0
beta[0_0]   -0.143  0.296  -0.755    0.374  ...    0.004    2073.0    2794.0    1.0
beta[0_1]    0.382  0.301  -0.196    0.939  ...    0.004    2181.0    2777.0    1.0
beta[1_0]   -0.494  0.300  -1.069    0.061  ...    0.004    2128.0    2707.0    1.0
beta[1_1]    0.269  0.297  -0.289    0.807  ...    0.005    2088.0    2549.0    1.0
sigma_a      2.017  0.636   0.969    3.180  ...    0.016    1134.0    1933.0    1.0
sigma_gamma  0.202  0.176   0.000    0.513  ...    0.005    1359.0    1408.0    1.0

[20 rows x 9 columns]
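The Python summary lists `z[1]` to `z[7]` and `x[1]` to `x[6]` where one might expect actor and block intercepts. The parameter names suggest that the PyMC model (defined outside this section) follows the non-centered form of McElreath's chimpanzee model. In that form, the actor effect is `a_bar + sigma_a * z[actor]` and the block effect is `sigma_gamma * x[block]`, with `z` and `x` standard normal. That reading is an assumption, but the numbers are consistent with it. `sigma_a` (2.02) lines up with brms's actor-level `sd(a_Intercept)` (2.03), and `sigma_gamma` (0.20) lines up with the block-level one (0.20). brms handles this reparameterization internally.

The sketch below maps non-centered draws back to actor intercepts on the logit scale. The draws are simulated stand-ins with the trace's shapes, and `centered_actor_intercepts` is a hypothetical helper.

```python
import numpy as np

def centered_actor_intercepts(a_bar, sigma_a, z):
    """Assumed non-centered form: alpha[actor] = a_bar + sigma_a * z[actor].

    a_bar, sigma_a: shape (chain, draw); z: shape (chain, draw, n_actors).
    Returns actor intercepts on the logit scale, shape (chain, draw, n_actors).
    """
    a_bar, sigma_a, z = np.asarray(a_bar), np.asarray(sigma_a), np.asarray(z)
    # Trailing axis lets the per-draw scalars broadcast across actors
    return a_bar[..., None] + sigma_a[..., None] * z

# Simulated stand-ins (NOT the fitted posterior): 4 chains x 1000 draws, 7 actors
rng = np.random.default_rng(1)
a_bar = rng.normal(0.6, 0.7, size=(4, 1000))
sigma_a = np.abs(rng.normal(2.0, 0.6, size=(4, 1000)))
z = rng.normal(0.0, 1.0, size=(4, 1000, 7))

alpha = centered_actor_intercepts(a_bar, sigma_a, z)
# Inverse logit of the actor intercepts alone (treatment and block effects left out)
p_left = 1.0 / (1.0 + np.exp(-alpha))
print(alpha.shape, p_left.mean(axis=(0, 1)).round(2))
```

With the real trace, xarray broadcasts by dimension name. Setting `post = chimp_model_trace.posterior` and computing `post["a_bar"] + post["sigma_a"] * post["z"]` should give the same actor intercepts without manual axis handling.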

Check the forest plots, i.e., posterior coefficient plots:

  • R
  • Python
m3 %>% mcmc_intervals(regex_pars = "^b_.*")
m3 %>% mcmc_areas(regex_pars = "^b_.*")

az.plot_forest(chimp_model_trace, kind='forestplot', combined=True, hdi_prob=0.95)
az.plot_forest(chimp_model_trace, kind='ridgeplot', combined=True, hdi_prob=0.95)
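The R and Python intervals are not directly comparable:

- brms's `l-95% CI`/`u-95% CI` are equal-tailed quantile intervals.
- `mcmc_intervals` draws quantile intervals, with defaults of 50% (inner) and 90% (outer).
- `az.summary` reports a 94% highest-density interval (`hdi_3%`, `hdi_97%`).
- `plot_forest` here uses a 95% HDI.

For a skewed posterior such as `sigma_gamma`, the two definitions diverge: its HDI starts at 0.000, while brms's interval starts at 0.01. Below is a minimal numpy sketch of both definitions on simulated draws. `hdi` and `equal_tailed` are hypothetical helpers, not ArviZ's API.

```python
import numpy as np

def hdi(samples, prob=0.95):
    """Narrowest window of sorted draws spanning floor(prob * n) index steps."""
    x = np.sort(np.asarray(samples).ravel())
    n = x.size
    k = int(np.floor(prob * n))
    widths = x[k:] - x[: n - k]      # width of every candidate window
    i = int(np.argmin(widths))       # the narrowest window is the HDI
    return float(x[i]), float(x[i + k])

def equal_tailed(samples, prob=0.95):
    """Quantile interval with (1 - prob) / 2 of the mass in each tail."""
    tail = (1.0 - prob) / 2.0
    lo, hi = np.quantile(np.asarray(samples).ravel(), [tail, 1.0 - tail])
    return float(lo), float(hi)

# Simulated right-skewed draws, loosely mimicking a small group-level sd
rng = np.random.default_rng(2)
sd_draws = np.abs(rng.normal(0.0, 0.25, size=4000))
print("95% HDI:", hdi(sd_draws))
print("95% ETI:", equal_tailed(sd_draws))
```

To mirror `regex_pars = "^b_.*"` on the Python side, pass `var_names=["beta"]` to `az.plot_forest` so that only the treatment coefficients are plotted.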

Perform a posterior predictive check:

  • R
  • Python
pp_check(m3, ndraws = 100)
pp_check(m3, ndraws = 100, type = "ecdf_overlay")
pp_check(m3, ndraws = 100, type = "pit_ecdf")

azp.plot_ppc_dist(chimp_model_trace, group="posterior_predictive", num_samples=100)
<string>:1: UserWarning: Detected at least one discrete variable.
Consider using plot_ppc variants specific for discrete data, such as plot_ppc_pava or plot_ppc_rootogram.
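The warning above notes that overlaid densities are a poor fit for a 0/1 outcome like `pulled_left`. A simple alternative is a test-statistic check. It computes the share of 1s in each posterior predictive draw, compares those shares with the observed share, and reports the tail probability. On the R side, `pp_check(m3, type = "stat", stat = "mean")` plays the same role. The sketch below uses simulated arrays with the trace's shapes, and `ppc_share_of_ones` is a hypothetical helper. For real data, the inputs would be `chimp_model_trace.observed_data["obs"].values` and, assuming the predictive variable is also named `obs`, `chimp_model_trace.posterior_predictive["obs"].values`.

```python
import numpy as np

def ppc_share_of_ones(y_obs, y_rep):
    """Posterior predictive check for a binary outcome via the share of 1s.

    y_obs: observed outcomes, shape (n_obs,)
    y_rep: posterior predictive draws, shape (chain, draw, n_obs)
    Returns (observed share, replicated share per draw, P(T_rep >= T_obs)).
    """
    y_obs, y_rep = np.asarray(y_obs), np.asarray(y_rep)
    t_obs = float(y_obs.mean())
    # Collapse chain and draw into one axis, then take the share of 1s per draw
    t_rep = y_rep.reshape(-1, y_rep.shape[-1]).mean(axis=1)
    p_value = float(np.mean(t_rep >= t_obs))
    return t_obs, t_rep, p_value

# Simulated stand-ins: 504 trials, 4 chains x 1000 predictive draws
rng = np.random.default_rng(3)
y_obs = rng.binomial(1, 0.58, size=504)
y_rep_ok = rng.binomial(1, 0.58, size=(4, 1000, 504))   # same process as y_obs
y_rep_bad = rng.binomial(1, 0.30, size=(4, 1000, 504))  # deliberately misspecified

t_obs, t_rep, p_ok = ppc_share_of_ones(y_obs, y_rep_ok)
_, _, p_bad = ppc_share_of_ones(y_obs, y_rep_bad)
print(t_obs, np.quantile(t_rep, [0.05, 0.95]), p_ok, p_bad)
```

A tail probability near 0 or 1 flags a statistic the model fails to reproduce. Computing the same statistic within each treatment gives a sharper check than the overall share.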

© 2024 Sheng Long

 
