diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index e817a764961..b8ff4de5379 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -563,30 +563,30 @@ def draw_values(params, point=None, size=None): # specified in the point. Need to find the node-inputs, their # parents and children to replace them. leaf_nodes = {} - named_nodes_parents = {} - named_nodes_children = {} + named_nodes_descendents = {} + named_nodes_ancestors = {} for _, param in symbolic_params: if hasattr(param, 'name'): # Get the named nodes under the `param` node - nn, nnp, nnc = get_named_nodes_and_relations(param) + nn, nnd, nna = get_named_nodes_and_relations(param) leaf_nodes.update(nn) # Update the discovered parental relationships - for k in nnp.keys(): - if k not in named_nodes_parents.keys(): - named_nodes_parents[k] = nnp[k] + for k in nnd.keys(): + if k not in named_nodes_descendents.keys(): + named_nodes_descendents[k] = nnd[k] else: - named_nodes_parents[k].update(nnp[k]) + named_nodes_descendents[k].update(nnd[k]) # Update the discovered child relationships - for k in nnc.keys(): - if k not in named_nodes_children.keys(): - named_nodes_children[k] = nnc[k] + for k in nna.keys(): + if k not in named_nodes_ancestors.keys(): + named_nodes_ancestors[k] = nna[k] else: - named_nodes_children[k].update(nnc[k]) - stack = [k for k, v in named_nodes_children.items() if len(v) == 0] + named_nodes_ancestors[k].update(nna[k]) # Init givens and the stack of nodes to try to `_draw_value` from givens = {p.name: (p, v) for (p, size), v in drawn.items() if getattr(p, 'name', None) is not None} + stack = list(leaf_nodes.values()) while stack: next_ = stack.pop(0) if (next_, size) in drawn: @@ -607,7 +607,7 @@ def draw_values(params, point=None, size=None): # of TensorConstants or SharedVariables, we must add them # to the stack or risk evaluating deterministics with the # wrong values (issue #3354) - stack.extend([node for node in named_nodes_parents[next_] + stack.extend([node for node in named_nodes_descendents[next_] if isinstance(node, (ObservedRV, MultiObservedRV)) and (node, size) not in drawn]) @@ -616,7 +616,7 @@ def draw_values(params, point=None, size=None): # If the node does not have a givens value, try to draw it. # The named node's children givens values must also be taken # into account. - children = named_nodes_children[next_] + children = named_nodes_ancestors[next_] temp_givens = [givens[k] for k in givens if k in children] try: # This may fail for autotransformed RVs, which don't @@ -631,7 +631,7 @@ def draw_values(params, point=None, size=None): # The node failed, so we must add the node's parents to # the stack of nodes to try to draw from. We exclude the # nodes in the `params` list. - stack.extend([node for node in named_nodes_parents[next_] + stack.extend([node for node in named_nodes_descendents[next_] if node is not None and (node, size) not in drawn]) @@ -655,8 +655,8 @@ def draw_values(params, point=None, size=None): # This may set values for certain nodes in the drawn # dictionary, but they don't get added to the givens # dictionary. Here, we try to fix that. - if param in named_nodes_children: - for node in named_nodes_children[param]: + if param in named_nodes_ancestors: + for node in named_nodes_ancestors[param]: if ( node.name not in givens and (node, size) in drawn diff --git a/pymc3/model.py b/pymc3/model.py index 00423e5514c..2c292d52e00 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -91,6 +91,7 @@ def incorporate_methods(source, destination, methods, default=None, else: setattr(destination, method, None) + def get_named_nodes_and_relations(graph): """Get the named nodes in a theano graph (i.e., nodes whose name attribute is not None) along with their relationships (i.e., the @@ -102,64 +103,70 @@ def get_named_nodes_and_relations(graph): graph - a theano node Returns: - leaf_nodes: A dictionary of name:node pairs, of the named nodes that - are also leafs of the graph - node_parents: A dictionary of node:set([parents]) pairs. Each key is + leafs: A dictionary of name:node pairs, of the named nodes that + have no named ancestors in the provided theano graph. + descendents: A dictionary of node:set([parents]) pairs. Each key is a theano named node, and the corresponding value is the set of - theano named nodes that are parents of the node. These parental - relations skip unnamed intermediate nodes. - node_children: A dictionary of node:set([children]) pairs. Each key + theano named nodes that are direct descendents of the node in the + supplied ``graph``. These relations skip unnamed intermediate nodes. + ancestors: A dictionary of node:set([ancestors]) pairs. Each key is a theano named node, and the corresponding value is the set - of theano named nodes that are children of the node. These child - relations skip unnamed intermediate nodes. + of theano named nodes that are direct ancestors in the of the node in + the supplied ``graph``. These relations skip unnamed intermediate + nodes. """ if graph.name is not None: - node_parents = {graph: set()} - node_children = {graph: set()} + ancestors = {graph: set()} + descendents = {graph: set()} else: - node_parents = {} - node_children = {} - return _get_named_nodes_and_relations(graph, None, {}, node_parents, node_children) - -def _get_named_nodes_and_relations(graph, parent, leaf_nodes, - node_parents, node_children): + ancestors = {} + descendents = {} + descendents, ancestors = _get_named_nodes_and_relations( + graph, None, ancestors, descendents + ) + leafs = { + node.name: node for node, ancestor in ancestors.items() + if len(ancestor) == 0 + } + return leafs, descendents, ancestors + + +def _get_named_nodes_and_relations(graph, descendent, descendents, ancestors): if getattr(graph, 'owner', None) is None: # Leaf node if graph.name is not None: # Named leaf node - leaf_nodes.update({graph.name: graph}) - if parent is not None: # Is None for the root node + if descendent is not None: # Is None for the first node try: - node_parents[graph].add(parent) + descendents[graph].add(descendent) except KeyError: - node_parents[graph] = {parent} - node_children[parent].add(graph) + descendents[graph] = {descendent} + ancestors[descendent].add(graph) else: - node_parents[graph] = set() + descendents[graph] = set() # Flag that the leaf node has no children - node_children[graph] = set() + ancestors[graph] = set() else: # Intermediate node if graph.name is not None: # Intermediate named node - if parent is not None: # Is only None for the root node + if descendent is not None: # Is only None for the root node try: - node_parents[graph].add(parent) + descendents[graph].add(descendent) except KeyError: - node_parents[graph] = {parent} - node_children[parent].add(graph) + descendents[graph] = {descendent} + ancestors[descendent].add(graph) else: - node_parents[graph] = set() - # The current node will be set as the parent of the next + descendents[graph] = set() + # The current node will be set as the descendent of the next # nodes only if it is a named node - parent = graph + descendent = graph # Init the nodes children to an empty set - node_children[graph] = set() + ancestors[graph] = set() for i in graph.owner.inputs: - temp_nodes, temp_inter, temp_tree = \ - _get_named_nodes_and_relations(i, parent, leaf_nodes, - node_parents, node_children) - leaf_nodes.update(temp_nodes) - node_parents.update(temp_inter) - node_children.update(temp_tree) - return leaf_nodes, node_parents, node_children + temp_desc, temp_ances = _get_named_nodes_and_relations( + i, descendent, descendents, ancestors + ) + descendents.update(temp_desc) + ancestors.update(temp_ances) + return descendents, ancestors class Context: