BUG: Using transform=pm.distributions.transforms.ordered together with non-null dims results in AssertionError #7554

Closed
tomicapretto opened this issue Oct 29, 2024 · 4 comments · Fixed by pymc-devs/pytensor#1057

@tomicapretto
Contributor

tomicapretto commented Oct 29, 2024

Describe the issue:

Someone opened an issue in the Bambi repo showing that trying to fit a certain model resulted in an AssertionError. I reproduced the model in PyMC and narrowed down the cause: when a value is passed to dims for a variable that also uses transform=pm.distributions.transforms.ordered, sampling fails with this error.
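For context, the ordered transform lets the sampler work on an unconstrained vector while the model sees an increasing one. The NumPy sketch below illustrates that mapping; it is written for this explanation and is not PyMC's implementation, and the function names are hypothetical. The first element is kept as-is and each later element is stored as the log of its gap to the previous one, so the inverse (backward) map exponentiates all but the first element and takes a cumulative sum. In PyTensor, slice assignments like the ones in `ordered_backward` are expressed with `set_subtensor`, whose gradient goes through `IncSubtensor.grad`, which is the frame where the first traceback below fails.

```python
import numpy as np


def ordered_forward(x):
    """Increasing vector (last axis) -> unconstrained vector.

    y[..., 0] = x[..., 0]; y[..., i] = log(x[..., i] - x[..., i - 1]).
    """
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    y[..., 0] = x[..., 0]
    y[..., 1:] = np.log(np.diff(x, axis=-1))
    return y


def ordered_backward(y):
    """Unconstrained vector -> increasing vector: exp the gaps, then cumsum."""
    y = np.asarray(y, dtype=float)
    steps = np.empty_like(y)
    steps[..., 0] = y[..., 0]          # slice assignment (set_subtensor in PyTensor)
    steps[..., 1:] = np.exp(y[..., 1:])  # strictly positive gaps
    return np.cumsum(steps, axis=-1)


def ordered_log_jac_det(y):
    """log|det J| of the backward map (J is lower triangular, diag = [1, exp(y1), ...])."""
    y = np.asarray(y, dtype=float)
    return np.sum(y[..., 1:], axis=-1)


# Any real-valued vector maps back to a strictly increasing one, and the round trip is exact.
z = np.array([-2.0, 1.0])
x = ordered_backward(z)
print(x)
print(np.allclose(ordered_forward(x), z))  # True
```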

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "threshold_dim": [0, 1],
    "to_predict_dim": [0, 1, 2],
    "__obs__": [0, 1, 2],
}

predictor = np.array([1, 0, 1])
observed = np.array([0, 1, 2])

with pm.Model(coords=coords) as model:
    b_predictor = pm.Normal("b_predictor")
    threshold = pm.Normal(
        "threshold",
        mu=[-2, 2],
        sigma=1,
        transform=pm.distributions.transforms.ordered,
        # dims="threshold_dim"  # Uncommenting this line triggers the AssertionError
    )

    eta = b_predictor * predictor
    eta_shifted = threshold - pt.shape_padright(eta)
    p = pm.math.sigmoid(eta_shifted)
    p = pt.concatenate(
        [
            pt.shape_padright(p[..., 0]),
            p[..., 1:] - p[..., :-1],
            pt.shape_padright(1 - p[..., -1]),
        ],
        axis=-1,
    )

    p = pm.Deterministic("p", p, dims=("__obs__", "to_predict_dim"))

    pm.Categorical("to_predict", p=p, observed=observed, dims="__obs__")

with model:
    idata = pm.sample()
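The eta_shifted / concatenate block above turns the ordered thresholds into per-observation category probabilities (a cumulative-logit construction). As a PyMC-independent sanity check, here is the same arithmetic in NumPy with a made-up value for `b_predictor`; `category_probs` is a hypothetical helper name. The middle columns are differences of adjacent sigmoids, so they are non-negative only when the thresholds are increasing (which is why the model needs the ordered transform), and each row telescopes to 1.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def category_probs(threshold, eta):
    """Mirror the model's construction of p.

    threshold: shape (K - 1,), assumed increasing.
    eta: shape (n,), linear predictor per observation.
    Returns an array of shape (n, K).
    """
    # (K-1,) - (n, 1) broadcasts to (n, K-1), like threshold - pt.shape_padright(eta)
    cum = sigmoid(threshold - eta[:, None])
    return np.concatenate(
        [
            cum[..., :1],                  # first category
            cum[..., 1:] - cum[..., :-1],  # middle categories
            1 - cum[..., -1:],             # last category
        ],
        axis=-1,
    )


predictor = np.array([1, 0, 1])
b_predictor = 0.5                  # made-up value for illustration
threshold = np.array([-2.0, 2.0])  # same as mu in the model
p = category_probs(threshold, b_predictor * predictor)
print(p.shape)  # (3, 3)
print(p.sum(axis=-1))  # each row sums to 1 (up to rounding)
```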

Error message:

When you uncomment the dims line marked above, you'll see the following error message:

AssertionError                            Traceback (most recent call last)
Cell In[7], line 2
      1 with model:
----> 2     idata = pm.sample()

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:718, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    715         auto_nuts_init = False
    717 initial_points = None
--> 718 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    720 if nuts_sampler != "pymc":
    721     if not isinstance(step, NUTS):

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:223, in assign_step_methods(model, step, methods, step_kwargs)
    221 if has_gradient:
    222     try:
--> 223         tg.grad(model_logp, var)  # type: ignore
    224     except (NotImplementedError, tg.NullTypeGradError):
    225         has_gradient = False

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:633, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    630     if hasattr(g.type, "dtype"):
    631         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 633 _rval: Sequence[Variable] = _populate_grad_dict(
    634     var_to_app_to_idx, grad_dict, _wrt, cost_name
    635 )
    637 rval: MutableSequence[Variable | None] = list(_rval)
    639 for i in range(len(_rval)):

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1422     # end if cache miss
   1423     return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
   1427 return rval

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in <listcomp>(.0)
   1422     # end if cache miss
   1423     return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
   1427 return rval

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1210, in _populate_grad_dict.<locals>.access_term_cache(node)
   1202         if o_shape != g_shape:
   1203             raise ValueError(
   1204                 "Got a gradient of shape "
   1205                 + str(o_shape)
   1206                 + " on an output of shape "
   1207                 + str(g_shape)
   1208             )
-> 1210 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1212 if input_grads is None:
   1213     raise TypeError(
   1214         f"{node.op}.grad returned NoneType, expected iterable."
   1215     )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/op.py:398, in Op.L_op(self, inputs, outputs, output_grads)
    371 def L_op(
    372     self,
    373     inputs: Sequence[Variable],
    374     outputs: Sequence[Variable],
    375     output_grads: Sequence[Variable],
    376 ) -> list[Variable]:
    377     r"""Construct a graph for the L-operator.
    378 
    379     The L-operator computes a row vector times the Jacobian.
   (...)
    396 
    397     """
--> 398     return self.grad(inputs, output_grads)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:1995, in IncSubtensor.grad(self, inputs, grads)
   1993         gx = g_output
   1994     gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
-> 1995     gy = _sum_grad_over_bcasted_dims(y, gy)
   1997 return [gx, gy] + [DisconnectedType()()] * len(idx_list)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:2031, in _sum_grad_over_bcasted_dims(x, gx)
   2029 x_dim_added = gx.ndim - x.ndim
   2030 x_broad = (True,) * x_dim_added + x.broadcastable
-> 2031 assert sum(gx.broadcastable) <= sum(x_broad)
   2032 axis_to_sum = []
   2033 for i in range(gx.ndim):

AssertionError: 

If you leave it commented out, sampling works.
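The failing assertion compares static broadcastable flags, i.e. which dimensions PyTensor knows at graph-construction time to have length 1. The three lines shown in the last frame of the traceback can be mimicked on plain tuples of flags; `bcast_check_passes` is a hypothetical name and this is an illustration, not PyTensor code. The check fails whenever the gradient claims more length-1 dimensions than the incremented input (padded on the left) allows. Which flags this particular model produces is not visible in the traceback, so treating it as a static-shape disagreement is an inference.

```python
def bcast_check_passes(x_broadcastable, gx_broadcastable):
    """Mimic `assert sum(gx.broadcastable) <= sum(x_broad)` on tuples of flags.

    x_broadcastable: flags of the input being incremented (y in IncSubtensor.grad).
    gx_broadcastable: flags of its gradient, which may have extra leading dims.
    """
    # Number of leading dimensions the gradient has beyond x
    x_dim_added = len(gx_broadcastable) - len(x_broadcastable)
    # Extra leading dimensions count as broadcastable for x
    x_broad = (True,) * x_dim_added + tuple(x_broadcastable)
    # The gradient may not claim more length-1 dims than x
    return sum(gx_broadcastable) <= sum(x_broad)


print(bcast_check_passes((False,), (False,)))      # True
print(bcast_check_passes((True,), (False, True)))  # True: extra leading dim is allowed
print(bcast_check_passes((False,), (True,)))       # False: gradient claims length 1, x does not
```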

PyMC version information:

PyMC 5.17.0
PyTensor 2.25.5


@ricardoV94
Member

This was probably fixed recently in pymc-devs/pytensor#1036. The fix should be included in the next PyTensor release, together with a bump of PyMC's PyTensor dependency.

@tomicapretto
Contributor Author

Thanks. I just tested it and got a new error, which is probably relevant to you.

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 960, in __call__
    self.vm()
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 128, in run
    self._start_loop()
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 180, in _start_loop
    point, stats = self._step_method.step(self._point)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 173, in step
    return super().step(point)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 101, in step
    apoint, stats = self.astep(q)
                    ^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/hmc/base_hmc.py", line 168, in astep
    start = self.integrator.compute_state(q0, p0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/hmc/integration.py", line 56, in compute_state
    logp, dlogp = self._logp_dlogp_func(q)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/model/core.py", line 359, in __call__
    cost, *grads = self._pytensor_function(*grad_vars)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 973, in __call__
    raise_with_op(
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/link/utils.py", line 524, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 960, in __call__
    self.vm()
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 81
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([ 0.26211354, -0.50627326]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
"""

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 81
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([ 0.26211354, -0.50627326]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The above exception was the direct cause of the following exception:

ParallelSamplingError                     Traceback (most recent call last)
Cell In[10], line 2
      1 with model:
----> 2     idata = pm.sample()

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:848, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    846 _print_step_hierarchy(step)
    847 try:
--> 848     _mp_sample(**sample_args, **parallel_args)
    849 except pickle.PickleError:
    850     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:1261, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1259 try:
   1260     with sampler:
-> 1261         for draw in sampler:
   1262             strace = traces[draw.chain]
   1263             strace.record(draw.point, draw.stats)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py:471, in ParallelSampler.__iter__(self)
    464 task = progress.add_task(
    465     self._desc.format(self),
    466     completed=self._completed_draws,
    467     total=self._total_draws,
    468 )
    470 while self._active:
--> 471     draw = ProcessAdapter.recv_draw(self._active)
    472     proc, is_last, draw, tuning, stats = draw
    473     self._completed_draws += 1

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py:338, in ProcessAdapter.recv_draw(processes, timeout)
    336     else:
    337         error = RuntimeError(f"Chain {proc.chain} failed.")
--> 338     raise error from old_error
    339 elif msg[0] == "writing_done":
    340     proc._readable = True

ParallelSamplingError: Chain 1 failed with: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 81
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([ 0.26211354, -0.50627326]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
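Both HINT lines refer to PyTensor configuration flags. One common way to set them is the PYTENSOR_FLAGS environment variable, which PyTensor reads when it is first imported; worker processes started by pm.sample inherit the environment. A minimal sketch, assuming the script has not imported pytensor or pymc yet; `merge_pytensor_flags` is a hypothetical helper that keeps any flags already set. Passing `cores=1` to `pm.sample` may also help, since the error is then raised in the main process instead of being wrapped in a ParallelSamplingError.

```python
import os


def merge_pytensor_flags(existing, **flags):
    """Merge key=value flags into a comma-separated PYTENSOR_FLAGS string.

    New values override existing ones with the same key.
    """
    merged = {}
    for item in existing.split(","):
        if "=" in item:
            key, value = item.split("=", 1)
            merged[key.strip()] = value.strip()
    merged.update({k: str(v) for k, v in flags.items()})
    return ",".join(f"{k}={v}" for k, v in merged.items())


# Must run before the first `import pytensor` / `import pymc` in this process.
os.environ["PYTENSOR_FLAGS"] = merge_pytensor_flags(
    os.environ.get("PYTENSOR_FLAGS", ""),
    optimizer="fast_compile",
    exception_verbosity="high",
)

# import pymc as pm
# ... rebuild the model from the issue, then:
# with model:
#     idata = pm.sample(cores=1)
```

The shell equivalent is PYTENSOR_FLAGS="optimizer=fast_compile,exception_verbosity=high" python repro.py, where repro.py stands for a script containing the reproducer.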

And if I also add shape=2 to the variable that already has dims, I get the following (very long) message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_reduce_join
ERROR (pytensor.graph.rewriting.basic): node: Sum{axis=1}(Join.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/math.py", line 1668, in local_reduce_join
    new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 374, in apply_local_dimshuffle_lift
    new = local_dimshuffle_lift.transform(fgraph, var.owner)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 432, in local_dimshuffle_lift
    and (len(fgraph.clients[inp]) == 1)
             ^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'clients'

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [b_predictor, threshold]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:-- / 0:00:00

And then the following error:

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 960, in __call__
    self.vm()
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 128, in run
    self._start_loop()
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 180, in _start_loop
    point, stats = self._step_method.step(self._point)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 173, in step
    return super().step(point)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 101, in step
    apoint, stats = self.astep(q)
                    ^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/hmc/base_hmc.py", line 168, in astep
    start = self.integrator.compute_state(q0, p0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/step_methods/hmc/integration.py", line 56, in compute_state
    logp, dlogp = self._logp_dlogp_func(q)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/model/core.py", line 359, in __call__
    cost, *grads = self._pytensor_function(*grad_vars)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 973, in __call__
    raise_with_op(
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/link/utils.py", line 524, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/tomas/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 960, in __call__
    self.vm()
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 70
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([-0.65585729, -1.55057568]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
"""

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
AssertionError: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 70
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([-0.65585729, -1.55057568]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The above exception was the direct cause of the following exception:

ParallelSamplingError                     Traceback (most recent call last)
Cell In[8], line 2
      1 with model:
----> 2     idata = pm.sample()

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:848, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    846 _print_step_hierarchy(step)
    847 try:
--> 848     _mp_sample(**sample_args, **parallel_args)
    849 except pickle.PickleError:
    850     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:1261, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1259 try:
   1260     with sampler:
-> 1261         for draw in sampler:
   1262             strace = traces[draw.chain]
   1263             strace.record(draw.point, draw.stats)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py:471, in ParallelSampler.__iter__(self)
    464 task = progress.add_task(
    465     self._desc.format(self),
    466     completed=self._completed_draws,
    467     total=self._total_draws,
    468 )
    470 while self._active:
--> 471     draw = ProcessAdapter.recv_draw(self._active)
    472     proc, is_last, draw, tuning, stats = draw
    473     self._completed_draws += 1

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/parallel.py:338, in ProcessAdapter.recv_draw(processes, timeout)
    336     else:
    337         error = RuntimeError(f"Chain {proc.chain} failed.")
--> 338     raise error from old_error
    339 elif msg[0] == "writing_done":
    340     proc._readable = True

ParallelSamplingError: Chain 0 failed with: SpecifyShape: dim 0 of input has shape 2, expected 1.
Apply node that caused the error: SpecifyShape(Subtensor{:stop:step}.0, 1)
Toposort index: 70
Inputs types: [TensorType(float64, shape=(2,)), TensorType(int8, shape=())]
Inputs shapes: [(2,), ()]
Inputs strides: [(-8,), ()]
Inputs values: [array([-0.65585729, -1.55057568]), array(1, dtype=int8)]
Outputs clients: [[Composite{(1.0 + (i0 * i1))}(SpecifyShape.0, Exp.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

@ricardoV94
Member

The rewrite bug is obvious; the other one I'm not sure about. I have to see why the SpecifyShape is introduced.

@tomicapretto
Contributor Author

This is the MRE (minimal reproducible example):

import pymc as pm

coords = {"threshold_dim": [0, 1]}

with pm.Model(coords=coords) as model:
    threshold = pm.Normal(
        "threshold",
        mu=[-2, 2],
        sigma=1,
        transform=pm.distributions.transforms.ordered,
        dims="threshold_dim",
    )

    idata = pm.sample()
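The MRE narrows the failure to the ordered transform applied to a variable that has `dims`. For reference, the NumPy sketch below shows what, to my understanding, the ordered transform computes: the first element is kept as is, the remaining unconstrained values are exponentiated into positive gaps, and a cumulative sum produces an increasing vector. The function names are hypothetical and are not PyMC API.

```python
import numpy as np


def ordered_backward(value):
    """Unconstrained -> increasing vector (along the last axis)."""
    out = np.empty_like(value, dtype=float)
    out[..., 0] = value[..., 0]               # first element unchanged
    out[..., 1:] = np.exp(value[..., 1:])     # strictly positive gaps
    return np.cumsum(out, axis=-1)


def ordered_forward(x):
    """Increasing vector -> unconstrained (inverse of ordered_backward)."""
    y = np.empty_like(x, dtype=float)
    y[..., 0] = x[..., 0]
    y[..., 1:] = np.log(x[..., 1:] - x[..., :-1])  # log of the gaps
    return y


def ordered_log_jac_det(value):
    """log|det J| of ordered_backward: the Jacobian is lower triangular
    with diagonal (1, exp(v1), ..., exp(vK)), so the sum is over v[1:]."""
    return np.sum(value[..., 1:], axis=-1)


thresholds = np.array([-2.0, 2.0])  # same values as mu in the MRE
unconstrained = ordered_forward(thresholds)
print(unconstrained)
print(ordered_backward(unconstrained))
```

Because every slice here acts on the last axis, the transform itself does not depend on whether the variable carries `dims`, which is consistent with the bug lying in how the compiled graph is shaped rather than in the transform's math.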
