Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update effect_handlers.ipynb (#3296)
```python print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5))) ``` prevents the following error message: ```python --------------------------------------------------------------------------- ValueError Traceback (most recent call last) File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter) 195 try: --> 196 log_p = site["fn"].log_prob( 197 site["value"], *site["args"], **site["kwargs"] 198 ) 199 except ValueError as e: File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value) 78 if self._validate_args: ---> 79 self._validate_sample(value) 80 # compute the variance File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value) 270 if not isinstance(value, torch.Tensor): --> 271 raise ValueError('The value argument to log_prob must be a Tensor') 273 event_dim_start = len(value.size()) - len(self._event_shape) ValueError: The value argument to log_prob must be a Tensor The above exception was the direct cause of the following exception: ValueError Traceback (most recent call last) Cell In[5], line 9 6 return _log_joint 8 scale_log_joint = make_log_joint(scale) ----> 9 print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)) Cell In[5], line 5, in make_log_joint.<locals>._log_joint(cond_data, *args, **kwargs) 3 conditioned_model = poutine.condition(model, data=cond_data) 4 trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs) ----> 5 return trace.log_prob_sum() File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:202, in Trace.log_prob_sum(self, site_filter) 200 _, exc_value, traceback = sys.exc_info() 201 shapes = self.format_shapes(last_site=site["name"]) --> 202 raise ValueError( 203 "Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format( 204 name, exc_value, shapes 205 ) 206 ).with_traceback(traceback) from e 207 log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum() 208 site["log_prob_sum"] = log_p File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter) 194 else: 195 try: --> 196 log_p = site["fn"].log_prob( 197 site["value"], *site["args"], **site["kwargs"] 198 ) 199 except ValueError as e: 200 _, exc_value, traceback = sys.exc_info() File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value) 77 def log_prob(self, value): 78 if self._validate_args: ---> 79 self._validate_sample(value) 80 # compute the variance 81 var = (self.scale ** 2) File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value) 257 """ 258 Argument validation for distribution methods such as `log_prob`, 259 `cdf` and `icdf`. The rightmost dimensions of a value to be (...) 268 distribution's batch and event shapes. 269 """ 270 if not isinstance(value, torch.Tensor): --> 271 raise ValueError('The value argument to log_prob must be a Tensor') 273 event_dim_start = len(value.size()) - len(self._event_shape) 274 if value.size()[event_dim_start:] != self._event_shape: ValueError: Error while computing log_prob_sum at site 'weight': The value argument to log_prob must be a Tensor Trace Shapes: Param Sites: Sample Sites: weight dist | value | ```
- Loading branch information