Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Adding more detailed info to GraphViz renderings #3956

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6504275
treat array of shape (1,) as scalar
Spaak Jun 12, 2020
c6d5cdf
adding option to display parameters of distributions to graphviz gene…
Spaak Jun 12, 2020
b436ee5
adding generic machinery for generating string representations
Spaak Aug 10, 2020
6d0b450
double quotes for strings
Spaak Aug 10, 2020
19e8d77
Update pymc3/distributions/distribution.py
Spaak Aug 11, 2020
1b228f6
Merge branch 'more-graph-details' of github.com:Spaak/pymc3 into more…
Spaak Aug 11, 2020
1e047eb
restoring return None behavior of TransformedDistribution::_repr_latex_
Spaak Aug 11, 2020
4aa4df1
extra . for import
Spaak Aug 11, 2020
7173261
renaming _distr_parameters() to _distr_parameters_for_repr() to avoid…
Spaak Aug 11, 2020
a8e8701
replacing old _repr_latex_ functionality with new one
Spaak Aug 11, 2020
451cc05
adding new repr functionality to Deterministic
Spaak Aug 11, 2020
4cfaf06
replacing old with new str repr functionality in PyMC3Variable
Spaak Aug 11, 2020
cb12783
ensure that TransformedDistribution does not mess up its str repr
Spaak Aug 11, 2020
8f79407
new str repr functionality in Model
Spaak Aug 11, 2020
8641934
adding unit tests for new __str__ functionality
Spaak Aug 11, 2020
39488b2
updating graphviz rendering with new repr functionality
Spaak Aug 11, 2020
8026f27
updating unit tests to reflect new graph info
Spaak Aug 11, 2020
3589c89
ensure that variable keys always use the simple plain variable name, …
Spaak Aug 11, 2020
00d9f1c
removing unused import
Spaak Aug 11, 2020
273f1b6
ensure usage of plain string (i.e. names) of variables rather than us…
Spaak Aug 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

import numbers
import contextvars
import inspect
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable

import numpy as np
import theano.tensor as tt
from theano import function
from pymc3.util import get_variable_name
import theano
from ..memoize import memoize
from ..model import (
Expand Down Expand Up @@ -134,9 +136,48 @@ def getattr_value(self, val):

return val

def _repr_latex_(self, name=None, dist=None):
def _distr_parameters(self):
    """Return the parameter names of this distribution.

    For example, "mu" and "sigma" for Normal. The names are used when
    generating string (and LaTeX etc.) representations of Distribution
    objects.

    By default the names are read from the argument list of ``__init__``,
    skipping its first entry. Subclasses can override this method if
    necessary (e.g. to avoid including "sd" and "tau").
    """
    init_spec = inspect.getfullargspec(self.__init__)
    # Slice rather than unpack so an empty argument list still yields [].
    return init_spec.args[1:]

def _distr_name(self):
    """Return the class name of this object, used as its display name."""
    klass = self.__class__
    return klass.__name__

def _str_repr(self, name=None, dist=None, formatting='plain'):
    """Build a one-line string representation of a distribution.

    Parameters
    ----------
    name
        Variable name shown on the left-hand side; ``'[unnamed]'`` is used
        when it is None.
    dist
        Object whose parameter values and distribution name are rendered;
        ``self`` is used when it is None. The list of parameter *names* is
        always taken from ``self._distr_parameters()``.
    formatting
        ``"latex"`` produces LaTeX markup wrapped in ``$...$``; any other
        value produces the plain-text form.

    Returns
    -------
    str
        ``name ~ DistrName(param=value, ...)`` in the requested format.
    """
    target = self if dist is None else dist
    label = '[unnamed]' if name is None else name

    keys = self._distr_parameters()
    rendered = [get_variable_name(getattr(target, key)) for key in keys]

    # Only the templates differ between the two output formats.
    if formatting == "latex":
        separator = ",~"
        item_template = r"\mathit{{{name}}}={value}"
        full_template = r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$"
    else:
        # 'plain' is default option
        separator = ", "
        item_template = "{name}={value}"
        full_template = "{var_name} ~ {distr_name}({params})"

    joined = separator.join(
        [item_template.format(name=key, value=val) for key, val in zip(keys, rendered)]
    )
    return full_template.format(
        var_name=label, distr_name=target._distr_name(), params=joined
    )

def __str__(self, **kwargs):
    """Return the plain-text representation from ``_str_repr``.

    Extra keyword arguments are passed through to ``_str_repr`` unchanged.
    """
    plain_text = self._str_repr(formatting='plain', **kwargs)
    return plain_text

def _repr_latex_(self, **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
return None
return self._str_repr(formatting='latex', **kwargs)

def logp_nojac(self, *args, **kwargs):
"""Return the logp, but do not include a jacobian term for transforms.
Expand Down
35 changes: 26 additions & 9 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .util import get_default_varnames
from .model import ObservedRV
import pymc3 as pm
from pymc3.util import get_variable_name


class ModelGraph:
Expand All @@ -33,6 +34,10 @@ def __init__(self, model):
self.var_list = self.model.named_vars.values()
self.transform_map = {v.transformed: v.name for v in self.var_list if hasattr(v, 'transformed')}
self._deterministics = None
self._distr_params = {
'Normal': ['mu', 'sigma'],
'Uniform': ['lower', 'upper'],
}

def get_deterministics(self, var):
"""Compute the deterministic nodes of the graph, **not** including var itself."""
Expand Down Expand Up @@ -120,7 +125,7 @@ def update_input_map(key: str, val: Set[VarName]):
pass
return input_map

def _make_node(self, var_name, graph):
def _make_node(self, var_name, graph, include_prior_params):
"""Attaches the given variable to a graphviz Digraph"""
v = self.model[var_name]

Expand All @@ -146,9 +151,21 @@ def _make_node(self, var_name, graph):
distribution = 'Deterministic'
attrs['shape'] = 'box'

graph.node(var_name.replace(':', '&'),
'{var_name}\n~\n{distribution}'.format(var_name=var_name, distribution=distribution),
**attrs)
node_text = '{var_name}\n~\n{distribution}'.format(var_name=var_name, distribution=distribution)
if include_prior_params and distribution in self._distr_params:
param_strings = []
for param in self._distr_params[distribution]:
val = get_variable_name(getattr(v.distribution, param))
if type(val) is str and len(val) > 100:
val = '<long expression>'
try:
val = '{val:.3g}'.format(val=float(val))
except ValueError:
pass
param_strings.append('{param}={val}'.format(param=param,
val=val))
node_text += '(' + ', '.join(param_strings) + ')'
graph.node(var_name.replace(':', '&'), node_text, **attrs)

def get_plates(self):
""" Rough but surprisingly accurate plate detection.
Expand Down Expand Up @@ -181,7 +198,7 @@ def get_plates(self):
plates[shape].add(var_name)
return plates

def make_graph(self):
def make_graph(self, include_prior_params=False):
"""Make graphviz Digraph of PyMC3 model

Returns
Expand All @@ -203,20 +220,20 @@ def make_graph(self):
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name='cluster' + label) as sub:
for var_name in var_names:
self._make_node(var_name, sub)
self._make_node(var_name, sub, include_prior_params)
# plate label goes bottom right
sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')
else:
for var_name in var_names:
self._make_node(var_name, graph)
self._make_node(var_name, graph, include_prior_params)

for key, values in self.make_compute_graph().items():
for value in values:
graph.edge(value.replace(':', '&'), key.replace(':', '&'))
return graph


def model_to_graphviz(model=None):
def model_to_graphviz(model=None, **kwargs):
"""Produce a graphviz Digraph from a PyMC3 model.

Requires graphviz, which may be installed most easily with
Expand All @@ -228,4 +245,4 @@ def model_to_graphviz(model=None):
for more information.
"""
model = pm.modelcontext(model)
return ModelGraph(model).make_graph()
return ModelGraph(model).make_graph(**kwargs)
2 changes: 1 addition & 1 deletion pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_variable_name(variable):
except IndexError:
pass
value = variable.eval()
if not value.shape:
if not value.shape or value.shape == (1,):
return asscalar(value)
return "array"
return r"\text{%s}" % name
Expand Down