From 908d256dcef207c8f6911797f095b9cccaf0be47 Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Mon, 15 Jun 2015 10:52:27 -0400 Subject: [PATCH 01/18] Added a new likelihood weighter class into the experimental package. This class will eventually replace all of the sampling in Importance sampling, Forward sampling, particle filtering, and probability of evidence sampling. It also improves on the existing likelihood weighting by allowing weights to be pushed through non-caching chains (currently only support caching chains). Right now though, it is slightly slower because of the non-caching chain change. So it will not be enabled until caching is removed from chains and put into control of the algorithms. --- .../sampling/LikelihoodWeighter.scala | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala b/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala new file mode 100644 index 00000000..baf70cbf --- /dev/null +++ b/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala @@ -0,0 +1,133 @@ +package com.cra.figaro.experimental.sampling + +import com.cra.figaro.language._ +import scala.annotation.tailrec +import com.cra.figaro.algorithm.sampling.Importance + +class LikelihoodWeighter(universe: Universe) { + + def computeWeight(elementsToVisit: Set[Element[_]]): Double = { + traverse(List(), elementsToVisit, 0.0, Set(), scala.collection.mutable.Map[Dist[_, _], Int]()) + } + + @tailrec + private final def traverse(currentStack: List[(Element[_], Option[_])], + elementsToVisit: Set[Element[_]], + currentWeight: Double, + visited: Set[Element[_]], DistMap: scala.collection.mutable.Map[Dist[_, _], Int]): Double = { + + // If everything is empty, just return the weight + if (elementsToVisit.isEmpty && currentStack.isEmpty) { + currentWeight + } + // If the current stack is empty, we are free to choose any element to traverse. Pick the head of the set + else if (currentStack.isEmpty) { + traverse(List((elementsToVisit.head, getObservation(elementsToVisit.head, None))), elementsToVisit.tail, currentWeight, visited, DistMap) + } + // If the head of the stack has already been visited or in another universe, we don't need to do anything, go to the next element + else if (visited.contains(currentStack.head._1) || currentStack.head._1.universe != universe) { + traverse(currentStack.tail, elementsToVisit, currentWeight, visited, DistMap) + } + // Otherwise, we need to process the top element on the stack + else { + val (currElem, currObs) = currentStack.head + + currElem match { + case d: Dist[_, _] => + val parents = d match { + case dc: CompoundDist[_] => dc.probs.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None))) + case _ => List() + } + val rand = d.generateRandomness() + val index = d.selectIndex(rand) + val resultElement = d.outcomeArray(index) + val nextHead = List((resultElement, getObservation(resultElement, currObs)), (currElem, None)) + + if (parents.nonEmpty) { + traverse(parents ::: currentStack, elementsToVisit, currentWeight, visited, DistMap) + } else if (visited.contains(resultElement) && currObs.nonEmpty) { + traverse(nextHead ::: currentStack.tail, elementsToVisit, undoWeight(currentWeight, resultElement), visited - resultElement, DistMap += (d -> index)) + } else if (!visited.contains(resultElement)) { + traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited, DistMap += (d -> index)) + } else { + d.value = if (DistMap.contains(d)) d.finishGeneration(DistMap(d)) else d.finishGeneration(index) + DistMap -= d + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) + } + case c: Chain[_, _] => + if (!visited.contains(c.parent)) { + traverse((c.parent, getObservation(c.parent, None)) +: currentStack, elementsToVisit, currentWeight, visited, DistMap) + } else { + val next = c.get(c.parent.value) + val nextHead = List((next, getObservation(next, currObs)), (currElem, None)) + if (visited.contains(next) && currObs.nonEmpty) { + // we did this in the wrong order, and have to repropagate the result for likelihood weighting + traverse(nextHead ::: currentStack.tail, elementsToVisit, undoWeight(currentWeight, next), visited - next, DistMap) + } else if (!visited.contains(next)) { + traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited, DistMap) + } else { + c.value = next.value + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) + } + } + case _ => + val args = (currElem.args ::: currElem.elementsIAmContingentOn.toList) + // Find all the arguments of the element that have not been visited + val remainingArgs = args.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None))) + // if there are args unvisited, push those args to the top of the stack + if (remainingArgs.nonEmpty) { + traverse(remainingArgs ::: currentStack, elementsToVisit, currentWeight, visited, DistMap) + } else { + // else, we can now process this element and move on to the next item + currElem.randomness = currElem.generateRandomness() + currElem.value = currElem.generateValue(currElem.randomness) + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) + } + } + + } + } + + def getObservation(element: Element[_], observation: Option[_]) = { + (observation, element.observation) match { + case (None, None) => None + case (Some(obs), None) => Some(obs) + case (None, Some(obs)) => Some(obs) + case (Some(obs1), Some(obs2)) if obs1 == obs2 => Some(obs1) + case _ => throw Importance.Reject // incompatible observations + } + } + + def computeNextWeight(currentWeight: Double, element: Element[_], obs: Option[_]): Double = { + val nextWeight = if (obs.isEmpty) { + if (!element.condition(element.value)) throw Importance.Reject + currentWeight + } else { + element match { + case f: CompoundFlip => { + element.value = obs.get.asInstanceOf[element.Value] + if (obs.get.asInstanceOf[Boolean]) currentWeight + math.log(f.prob.value) + else currentWeight + math.log(1 - f.prob.value) + } + case e: HasDensity[_] => { + element.value = obs.get.asInstanceOf[element.Value] + val density = element.asInstanceOf[HasDensity[element.Value]].density(obs.asInstanceOf[Option[element.Value]].get) + currentWeight + math.log(density) + } + case _ => { + if (!element.condition(element.value)) throw Importance.Reject + currentWeight + } + } + } + nextWeight + element.constraint(element.value) + } + + def undoWeight(weight: Double, elem: Element[_]) = weight - computeNextWeight(0.0, elem, elem.observation) + +} + + From dcb165ef472047edd30e4f3e8e7c217a49543b16 Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Tue, 16 Jun 2015 16:48:26 -0400 Subject: [PATCH 02/18] Removed caching from chains, implemented class to perform caching, and added changes to MH accordingly. --- .../figaro/algorithm/sampling/Forward.scala | 53 +++++--- .../sampling/MetropolisHastings.scala | 87 ++++++------ .../sampling/MetropolisHastingsAnnealer.scala | 4 +- .../scala/com/cra/figaro/language/Chain.scala | 38 ++++-- .../com/cra/figaro/language/Element.scala | 4 +- .../com/cra/figaro/language/Universe.scala | 4 - .../figaro/library/decision/Decision.scala | 2 +- .../scala/com/cra/figaro/util/Cache.scala | 128 ++++++++++++++++++ .../test/example/OpenUniverseTest.scala | 3 +- .../figaro/test/language/ElementsTest.scala | 63 +++++---- 10 files changed, 274 insertions(+), 112 deletions(-) create mode 100644 Figaro/src/main/scala/com/cra/figaro/util/Cache.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala index e1b73a16..75247fc0 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala @@ -14,6 +14,9 @@ package com.cra.figaro.algorithm.sampling import com.cra.figaro.language._ +import com.cra.figaro.util.ChainCache +import com.cra.figaro.util.Cache +import com.cra.figaro.util.NoCache /** * A forward sampler that generates a state by generating values for elements, making sure to generate all the @@ -21,22 +24,30 @@ import com.cra.figaro.language._ */ object Forward { /** - * Sample the universe by generating a value for each element of the universe. + * Sample the universe by generating a value for each element of the universe. Return a cache object. */ - def apply(implicit universe: Universe): Unit = apply(false)(universe) - - def apply(useObservation: Boolean)(implicit universe: Universe): Unit = { - // avoiding recursion + def apply(universe: Universe): Cache = { + apply(universe, new NoCache(universe)) + } + + /** + * Sample the universe by generating a value for each element of the universe, and provide a cache object. Return a cache object. + */ + def apply(universe: Universe, cache: Cache): Cache = { + // avoiding recursion var state = Set[Element[_]]() var elementsRemaining = universe.activeElements while (!elementsRemaining.isEmpty) { - if (elementsRemaining.head.active) state = sampleInState(elementsRemaining.head, state, universe, useObservation) + if (elementsRemaining.head.active) state = sampleInState(elementsRemaining.head, state, universe, cache) elementsRemaining = elementsRemaining.tail } + cache } - - def apply[T](element: Element[T], useObservation: Boolean = false) = { - sampleInState(element, Set[Element[_]](), element.universe, useObservation) + + def apply[T](element: Element[T]) = { + val noCache = new NoCache(element.universe) + sampleInState(element, Set[Element[_]](), element.universe, noCache) + noCache } private type State = Set[Element[_]] @@ -45,7 +56,7 @@ object Forward { * To allow this algorithm to be used for dependent universes, we make sure elements in a different universe are not * sampled. */ - private def sampleInState[T](element: Element[T], state: State, universe: Universe, useObservation: Boolean): State = { + private def sampleInState[T](element: Element[T], state: State, universe: Universe, cache: Cache): State = { if (element.universe != universe || (state contains element)) state else { val (state1, sampledValue) = { @@ -58,7 +69,7 @@ object Forward { var resultState = state var probsRemaining = dc.probs while (!probsRemaining.isEmpty) { - resultState = sampleInState(probsRemaining.head, resultState, universe, useObservation) + resultState = sampleInState(probsRemaining.head, resultState, universe, cache) probsRemaining = probsRemaining.tail } resultState @@ -66,20 +77,23 @@ object Forward { } val rand = d.generateRandomness() val index = d.selectIndex(rand) - val state2 = sampleInState(d.outcomeArray(index), state1, universe, useObservation) + val state2 = sampleInState(d.outcomeArray(index), state1, universe, cache) (state2, d.finishGeneration(index)) case c: Chain[_, _] => - val state1 = sampleInState(c.parent, state, universe, useObservation) - val result = c.get(c.parent.value) - val state2 = sampleInState(result, state1, universe, useObservation) + val state1 = sampleInState(c.parent, state, universe, cache) + val result = cache(c) match { + case Some(r) => r + case _ => c.get(c.parent.value) + } + val state2 = sampleInState(result, state1, universe, cache) (state2, result.value) case _ => // avoiding recursion var state1 = state var initialArgs = (element.args ::: element.elementsIAmContingentOn.toList).toSet - var argsRemaining = initialArgs + var argsRemaining = initialArgs while (!argsRemaining.isEmpty) { - state1 = sampleInState(argsRemaining.head, state1, universe, useObservation) + state1 = sampleInState(argsRemaining.head, state1, universe, cache) val newArgs = element.args.filter(!initialArgs.contains(_)) initialArgs = initialArgs ++ newArgs argsRemaining = argsRemaining.tail ++ newArgs @@ -88,10 +102,7 @@ object Forward { (state1, element.value) } } - element.value = (useObservation, element.observation) match { - case (true, Some(v)) => v - case _ => sampledValue.asInstanceOf[T] - } + element.value = sampledValue.asInstanceOf[T] state1 + element } } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala index e727243a..e4bf5915 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala @@ -57,10 +57,12 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc */ var debug = false - private def newState: State = State(Map(), Map(), 0.0, 0.0, scala.collection.mutable.Set()) + private def newState: State = State(Map(), Map(), 0.0, 0.0, scala.collection.mutable.Set(), List()) private val fastTargets = targets.toSet + protected var chainCache: Cache = new ChainCache(universe) + /* * We continually update the values of elements while making a proposal. In order to be able to undo it, we need to * store the old value. @@ -68,33 +70,24 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc * We keep track of which elements do not have their condition satisfied by the new proposal. */ private def attemptChange[T](state: State, elem: Element[T]): State = { - val newValue = { - // Don't generate a new value for an observed element because it won't agree with the observation - // For a compound element we can't do this because we have to condition the arguments by the - // probability of generating the correct value. - if (elem.observation.isEmpty || !elem.isInstanceOf[Atomic[_]]) elem.generateValue(elem.randomness) - else elem.observation.get - } + + // Don't generate a new value for an observed element because it won't agree with the observation + // For a compound element we can't do this because we have to condition the arguments by the + // probability of generating the correct value. + val newValue = if (elem.observation.isEmpty || !elem.isInstanceOf[Atomic[_]]) { + chainCache(elem) match { + case None => elem.generateValue(elem.randomness) + case Some(result) => + if (result.value == null) result.generate + result.value + } + } else elem.observation.get + // if an old value is already stored, don't overwrite it - val newOldValues = - if (state.oldValues contains elem) state.oldValues; else state.oldValues + (elem -> elem.value) - if (elem.value != newValue) { - val newDissatisfied = - if (elem.condition(newValue)) state.dissatisfied -= elem; else state.dissatisfied += elem - elem.value = newValue - State(newOldValues, state.oldRandomness, state.proposalProb, state.modelProb, newDissatisfied) - } else { - // We need to make sure to add the element to the dissatisfied set if its condition is not satisfied, - // even if the value has not changed, because we compare the dissatisfied set with the old dissatisfied set - // when deciding whether to accept the proposal. - val newDissatisfied = - if (elem.condition(newValue)) { - state.dissatisfied - elem - } else { - state.dissatisfied + elem - } - State(newOldValues, state.oldRandomness, state.proposalProb, state.modelProb, newDissatisfied) - } + val newOldValues = if (state.oldValues contains elem) state.oldValues; else state.oldValues + (elem -> elem.value) + val newDissatisfied = if (elem.condition(newValue)) state.dissatisfied -= elem; else state.dissatisfied += elem + elem.value = newValue + State(newOldValues, state.oldRandomness, state.proposalProb, state.modelProb, newDissatisfied, state.visitOrder :+ elem) } private def propose[T](state: State, elem: Element[T]): State = { @@ -112,7 +105,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc } val newProb = state.proposalProb + log(proposalProb) elem.randomness = randomness - State(state.oldValues, newOldRandomness, newProb, state.modelProb + log(modelProb), state.dissatisfied) + State(state.oldValues, newOldRandomness, newProb, state.modelProb + log(modelProb), state.dissatisfied, state.visitOrder) } val result = attemptChange(state1, elem) if (debug) println("old randomness = " + oldRandomness + @@ -143,7 +136,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc val newOldRandomness2 = if (newOldRandomness1 contains elem2) newOldRandomness1 else newOldRandomness1 + (elem2 -> oldRandomness2) - State(state.oldValues, newOldRandomness2, state.proposalProb, state.modelProb, state.dissatisfied) + State(state.oldValues, newOldRandomness2, state.proposalProb, state.modelProb, state.dissatisfied, state.visitOrder) } val state2 = attemptChange(state1, elem1) val result = attemptChange(state2, elem2) @@ -222,10 +215,10 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc if (currentStack.isEmpty && currentArgs.isEmpty && updateQ.isEmpty) state - else if (currentStack.isEmpty && currentArgs.isEmpty && updateQ.nonEmpty) { + else if (currentStack.isEmpty && currentArgs.isEmpty && updateQ.nonEmpty) { val argsRemaining = universe.uses(updateQ.head).intersect(updateQ.tail) updateMany(state, List(updateQ.head), argsRemaining.toSet, updateQ.tail -- argsRemaining) - } else if (currentStack.nonEmpty && currentArgs.isEmpty) { + } else if (currentStack.nonEmpty && currentArgs.isEmpty) { val newState = updateOne(state, currentStack.head) updateMany(newState, currentStack.tail, currentArgs, updateQ) } else { @@ -283,17 +276,20 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc protected def undo(state: State): Unit = { if (debug) println("Rejecting!\n") - state.oldValues foreach (setValue(_)) - state.oldRandomness foreach (setRandomness(_)) - - /* Have to call generateValue on chains after a rejection to restore the old resulting - * element. We can't do this above because we have to ensure the value of parent is restored before we - * do this. - */ - for ((elem, value) <- state.oldValues) { + + state.visitOrder.foreach { elem => elem match { - case c: Chain[_, _] => c.generateValue - case _ => + case c: Chain[_, _] => { + val result = + chainCache(c) match { + case Some(result) => c.value = result.value.asInstanceOf[c.Value] + case None => throw new AlgorithmException + } + } + case _ => { + if (state.oldRandomness.contains(elem)) elem.randomness = state.oldRandomness(elem).asInstanceOf[elem.Randomness] + elem.value = state.oldValues(elem).asInstanceOf[elem.Value] + } } } } @@ -322,7 +318,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc protected def mhStep(): State = { val newStateUnconstrained = proposeAndUpdate() val newState = State(newStateUnconstrained.oldValues, newStateUnconstrained.oldRandomness, - newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied) + newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied, newStateUnconstrained.visitOrder) if (decideToAccept(newState)) { accepts += 1 accept(newState) @@ -354,7 +350,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc protected def doInitialize(): Unit = { // Need to prime the universe to make sure all elements have a generated value - Forward(false)(universe) + chainCache = Forward(universe, chainCache) initConstrainedValues() dissatisfied = universe.conditionedElements.toSet filter (!_.conditionSatisfied) for { i <- 1 to burnIn } mhStep() @@ -383,7 +379,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc for { i <- 1 to numSamples } { val newStateUnconstrained = proposeAndUpdate() val state1 = State(newStateUnconstrained.oldValues, newStateUnconstrained.oldRandomness, - newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied) + newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied, newStateUnconstrained.visitOrder) if (decideToAccept(state1)) { accepts += 1 // collect results for the new state and restore the original state @@ -524,5 +520,6 @@ object MetropolisHastings { oldRandomness: Map[Element[_], Any], proposalProb: Double, modelProb: Double, - dissatisfied: scala.collection.mutable.Set[Element[_]]) + dissatisfied: scala.collection.mutable.Set[Element[_]], + visitOrder: List[Element[_]]) } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala index ed361e91..0c2e78c5 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala @@ -73,7 +73,7 @@ abstract class MetropolisHastingsAnnealer(universe: Universe, proposalScheme: Pr override protected def mhStep(): State = { val newStateUnconstrained = proposeAndUpdate() val newState = State(newStateUnconstrained.oldValues, newStateUnconstrained.oldRandomness, - newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied) + newStateUnconstrained.proposalProb, newStateUnconstrained.modelProb + computeScores, newStateUnconstrained.dissatisfied, newStateUnconstrained.visitOrder ) if (decideToAccept(newState)) { accepts += 1 accept(newState) @@ -114,7 +114,7 @@ abstract class MetropolisHastingsAnnealer(universe: Universe, proposalScheme: Pr } override def doInitialize(): Unit = { - Forward(false)(universe) + chainCache = Forward(universe, chainCache) initConstrainedValues() dissatisfied = universe.conditionedElements.toSet filter (!_.conditionSatisfied) currentEnergy = universe.constrainedElements.map(_.constraintValue).sum diff --git a/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala b/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala index d920cc9c..b297655c 100644 --- a/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala +++ b/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala @@ -32,7 +32,9 @@ import scala.collection.mutable.Set class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], cacheSize: Int, collection: ElementCollection) extends Deterministic[U](name, collection) { - def args: List[Element[_]] = if (active && resultElement != null) List(parent) ::: List(resultElement) else List(parent) + +// def args: List[Element[_]] = if (active && resultElement != null) List(parent) ::: List(resultElement) else List(parent) + def args: List[Element[_]] = List(parent) protected def cpd = fcn @@ -46,7 +48,7 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c /** * The current result element that arises from the current value of the parent. */ - var resultElement: Element[U] = _ + //var resultElement: Element[U] = _ /* Data structures for the Chain. The cache stores previously generated result elements. The Context data * structures store the elements that were created in this context. We also stored newly created elements @@ -56,16 +58,17 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c * contexts. When Chain gets the distribution over the child, it first pushes the context, and pops the context afterward, to mark * any generated elements as being generated in the context of this Chain. */ - lazy private[figaro] val cache: Map[T, Element[U]] = Map() - lazy private[figaro] val myMappedContextContents: Map[T, Set[Element[_]]] = Map() - lazy private[figaro] val elemInContext: Map[Element[_], T] = Map() + //lazy private[figaro] val cache: Map[T, Element[U]] = Map() + //lazy private[figaro] val myMappedContextContents: Map[T, Set[Element[_]]] = Map() + //lazy private[figaro] val elemInContext: Map[Element[_], T] = Map() - private var lastParentValue: T = _ + //private var lastParentValue: T = _ /* Must override clear temporary for Chains. We can never leave the chain in an uninitialized state. That is, * the chain MUST ALWAYS have a valid element to return. So when clearing temporaries we clear everything * except the current context. */ + /* override def clearContext() = { myMappedContextContents.keys.foreach(c => if (c != lastParentValue) resizeCache(c)) } @@ -84,12 +87,14 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c myMappedContextContents(elemInContext(e)) -= e elemInContext -= e } + * + */ def generateValue() = { if (parent.value == null) parent.generate() - val resultElem = get(parent.value) - if (resultElem.value == null) resultElem.generate() - resultElem.value + val resultElement = get(parent.value) + if (resultElement.value == null) resultElement.generate() + resultElement.value } /* Computes the new result. If the cache contains a VALID element for this parent value, then return the @@ -103,6 +108,7 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c /** * Get the distribution over the result corresponding to the given parent value. Takes care of all bookkeeping including caching. */ + /* def get(parentValue: T): Element[U] = { val lruParent = lastParentValue lastParentValue = parentValue @@ -125,12 +131,20 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c resultElement = newResult newResult } + * + */ + def get(parentValue: T): Element[U] = { + val result = getResult(parentValue) + universe.registerUses(this, result) + result + } /** * Get the distribution over the result corresponding to the given parent value. This call is UNCACHED, * meaning it will not be stored in the Chain's cache, and subsequent calls using the same parentValue * could return different elements. */ + /* def getUncached(parentValue: T): Element[U] = { if (lastParentValue == null || lastParentValue != parentValue) { myMappedContextContents.getOrElseUpdate(parentValue, Set()) @@ -140,6 +154,9 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c } resultElement } + * + */ + def getUncached(parentValue: T): Element[U] = get(parentValue) // All elements created in cpd will be created in this Chain's context with a subContext of parentValue private def getResult(parentValue: T): Element[U] = { @@ -152,6 +169,7 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c /* Current replacement scheme just drops the last element in the cache. The dropped element must be deactivated, * and removed from the context data structures. */ + /* protected def resizeCache(dropValue: T) = { cache -= dropValue if (myMappedContextContents.contains(dropValue)) { @@ -160,6 +178,8 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c myMappedContextContents -= dropValue } } + * + */ override def toString = "Chain(" + parent + ", " + cpd + ")" } diff --git a/Figaro/src/main/scala/com/cra/figaro/language/Element.scala b/Figaro/src/main/scala/com/cra/figaro/language/Element.scala index 0daed4d9..e02c7a34 100644 --- a/Figaro/src/main/scala/com/cra/figaro/language/Element.scala +++ b/Figaro/src/main/scala/com/cra/figaro/language/Element.scala @@ -163,9 +163,9 @@ abstract class Element[T](val name: Name[T], val collection: ElementCollection) throw new NoSuchElementException } else myDirectContextContents - private[language] def addContextContents(e: Element[_]): Unit = myDirectContextContents += e + private[figaro] def addContextContents(e: Element[_]): Unit = myDirectContextContents += e - private[language] def removeContextContents(e: Element[_]): Unit = myDirectContextContents -= e + private[figaro] def removeContextContents(e: Element[_]): Unit = myDirectContextContents -= e /** * Returns true if this element is temporary, that is, was created in the context of another element. diff --git a/Figaro/src/main/scala/com/cra/figaro/language/Universe.scala b/Figaro/src/main/scala/com/cra/figaro/language/Universe.scala index f86cec35..926016fc 100644 --- a/Figaro/src/main/scala/com/cra/figaro/language/Universe.scala +++ b/Figaro/src/main/scala/com/cra/figaro/language/Universe.scala @@ -170,8 +170,6 @@ class Universe(val parentElements: List[Element[_]] = List()) extends ElementCol private[language] def activate(element: Element[_]): Unit = { if (element.active) throw new IllegalArgumentException("Activating active element") -// if (element.args exists (!_.active)) -// throw new IllegalArgumentException("Attempting to activate element with inactive argument") element.args.filter(!_.active).foreach(activate(_)) myActiveElements.add(element) if (!element.isInstanceOf[Deterministic[_]]) myStochasticElements.add(element) @@ -182,8 +180,6 @@ class Universe(val parentElements: List[Element[_]] = List()) extends ElementCol } element.args foreach (registerUses(element, _)) element.active = true -// myRecursiveUsedBy.clear -// myRecursiveUses.clear } private[language] def deactivate(element: Element[_]): Unit = { diff --git a/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala b/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala index 83f5361c..8e8929a8 100644 --- a/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala @@ -104,7 +104,7 @@ abstract class Decision[T, U](name: Name[U], parent: Element[T], private var fcn // Have to nullify the last result even if parents the same since the function changed clearContext // Have to clear the last element in the cache since clearTempory always leaves an element in the cache - if (cache.nonEmpty) resizeCache(cache.last._1) + //if (cache.nonEmpty) resizeCache(cache.last._1) // Have to remove the expansion of the universe since it is out of data LazyValues.clear(universe) // Must regenerate a new value since the cache should never be empty diff --git a/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala b/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala new file mode 100644 index 00000000..94863463 --- /dev/null +++ b/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala @@ -0,0 +1,128 @@ +package com.cra.figaro.util + +import com.cra.figaro.language._ +import scala.collection.mutable.Map +import scala.collection.mutable.Set +import scala.collection.generic.Shrinkable + +/** + * Abstract class to manage caching of element generation for a universe. This class can be used + * by algorithms to manage caching of chains. + */ +abstract class Cache(universe: Universe) extends Shrinkable[Element[_]] { + + /** + * Return the next element from the generative process defined by element. If no process + * is found, return None + */ + def apply[T](element: Element[T]): Option[Element[T]] + + universe.register(this) + + /** + * Clear any caching + */ + def clear(): Unit +} + +/** A Cache class which performs no caching */ +class NoCache(universe: Universe) extends Cache(universe) { + def apply[T](element: Element[T]): Option[Element[T]] = None + def clear() = {} + def -=(element: Element[_]) = this +} + +/** + * A class which implements caching for caching and non-caching chains. + * + */ +class ChainCache(universe: Universe) extends Cache(universe) { + + val ccCache: Map[Element[_], Map[Any, Element[_]]] = Map() + val ccInvertedCache: Map[Element[_], Map[Element[_], Any]] = Map() + + val nccCache: Map[Element[_], List[(Any, Element[_], Set[Element[_]])]] = Map() + + def apply[T](element: Element[T]): Option[Element[T]] = { + element match { + case c: CachingChain[_, T] => { + doCachingChain(c) + } + case c: NonCachingChain[_, T] => { + doNonCachingChain(c) + } + case _ => None + } + } + + def doCachingChain[U, T](c: CachingChain[U, T]): Option[Element[T]] = { + val cachedElems = ccCache.getOrElseUpdate(c, Map()) + val cachedValue = cachedElems.get(c.parent.value) + if (!cachedValue.isEmpty) cachedValue.asInstanceOf[Option[Element[T]]] + else { + val result = c.get(c.parent.value) + cachedElems += (c.parent.value -> result) + val invertedElems = ccInvertedCache.getOrElseUpdate(result, Map()) + invertedElems += (c -> c.parent.value) + Some(result) + } + } + + def doNonCachingChain[U, T](c: NonCachingChain[U, T]): Option[Element[T]] = { + val nccElems = nccCache.getOrElse(c, List()) + if (nccElems.isEmpty) { + val result = c.get(c.parent.value) + nccCache += (c -> List((c.parent.value, result, Set()))) + Some(result) + } else if (c.parent.value == nccElems.head._1) Some(nccElems.head._2.asInstanceOf[Element[T]]) + else if (nccElems.size > 1 && c.parent.value == nccElems.last._1) { + val oldContext = c.directContextContents.clone -- nccElems.last._3 + val head = (nccElems.last._1, nccElems.last._2, Set[Element[_]]()) + val last = (nccElems.head._1, nccElems.head._2, oldContext) + //oldContext.foreach(c.removeContextContents(_)) + nccCache += (c -> List(head, last)) + Some(nccElems.last._2.asInstanceOf[Element[T]]) + } else { + if (nccElems.size == 2) universe.deactivate(nccElems.last._3) + val oldContext = c.directContextContents.clone + //oldContext.foreach(c.removeContextContents(_)) + val result = c.get(c.parent.value) + val head = (c.parent.value, result, Set[Element[_]]()) + val last = (nccElems.head._1, nccElems.head._2, oldContext) + nccCache += (c -> List(head, last)) + Some(result) + } + } + + /* + def set(element: Element[_], result: Element[_]): Unit = { + element match { + case c: CachingChain[_, _] => { + ccCache.getOrElseUpdate(c, Map()) += (c.parent.value -> result) + } + case c: NonCachingChain[_, _] => { + + } + case _ => () + } + } + * + */ + + def -=(element: Element[_]) = { + ccCache -= element + nccCache -= element + val invertValue = ccInvertedCache.get(element) + if (invertValue.nonEmpty) invertValue.get.foreach(e => ccCache(e._1) -= e._2) + ccInvertedCache -= element + this + } + + def clear() = { + ccCache.clear() + ccInvertedCache.clear() + nccCache.clear() + universe.deregister(this) + } +} + diff --git a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala index dd830831..03c84ea0 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala @@ -58,7 +58,8 @@ class OpenUniverseTest extends WordSpec with Matchers { (0.25, () => ProposalScheme(numSources)), (0.25, () => ProposalScheme(sources.items(random.nextInt(numSources.value)))), (0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).sourceNum)), - (0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).position.resultElement))) + (0.25, () => ProposalScheme.default)) + //(0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).position))) sample1.position.addCondition((y: Double) => y >= 0.5 && y < 0.8) sample2.position.addCondition((y: Double) => y >= 0.5 && y < 0.8) diff --git a/Figaro/src/test/scala/com/cra/figaro/test/language/ElementsTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/language/ElementsTest.scala index aff4d68d..eaf2442b 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/language/ElementsTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/language/ElementsTest.scala @@ -91,6 +91,8 @@ class ElementsTest extends WordSpec with Matchers { NonCachingChain(f1, fn).toString should equal("Chain(" + f1 + ", " + fn + ")") } + // No more caching, removed + /* "return a cached result" in { var sum = 0 def fn(b: Boolean) = { @@ -103,7 +105,11 @@ class ElementsTest extends WordSpec with Matchers { c.get(true) sum should equal(2) } + * + */ + // No more caching, removed + /* "call the CPD when the cache is full" in { Universe.createNew() var sum = 0 @@ -122,9 +128,29 @@ class ElementsTest extends WordSpec with Matchers { // see implementation of NonCachingChain and use of oldParentValue sum should be > (3) } + * + */ + "call the CPD for each Chain access" in { + Universe.createNew() + var sum = 0 + def fn(b: Int) = { + sum += 1 + Constant(b) + } + val f1 = Uniform(0, 1, 2) + val c = NonCachingChain(f1, fn _) + sum = 0 + c.get(0) + c.get(1) + c.get(2) + c.get(0) + sum should equal(4) + } } - "managing the context" should { + // No more local context of chains + /* + "managing the context" should { "store new elements in the correct subContext" in { Universe.createNew() val c = Chain(Flip(0.5), (b: Boolean) => if (b) Constant(0) else Constant(1)) @@ -135,7 +161,7 @@ class ElementsTest extends WordSpec with Matchers { c.myMappedContextContents(false).size should equal(1) c.elemInContext(c.myMappedContextContents(false).head) should equal(false) } - + "remove deactivated elements from context when resizing the cache" in { Universe.createNew() val c = NonCachingChain(Uniform(0, 1, 2), (b: Int) => Constant(b)) @@ -145,7 +171,7 @@ class ElementsTest extends WordSpec with Matchers { c.directContextContents.size should equal(2) c.elemInContext.size should equal(2) } - + "remove deactivated elements from context when removing temporaries" in { Universe.createNew() val c = CachingChain(com.cra.figaro.library.atomic.discrete.Uniform(0, 10), (b: Int) => Constant(b)) @@ -156,7 +182,7 @@ class ElementsTest extends WordSpec with Matchers { c.directContextContents.size should equal(1) c.elemInContext.size should equal(1) } - + "only remove elements defined in subContext" in { Universe.createNew() def fcn(b: Int) = { @@ -176,6 +202,8 @@ class ElementsTest extends WordSpec with Matchers { Universe.universe.contextContents(c) forall (_.active) should equal(true) } } + * + */ } "A chain with two parents" when { @@ -213,22 +241,18 @@ class ElementsTest extends WordSpec with Matchers { "evaluate the CPD each time get is called" in { Universe.createNew() var sum = 0 - def fn(b: (Boolean, Boolean)) = { + def fn(b1: Boolean, b2: Boolean): Element[Boolean] = { sum += 1 - Constant(b._1 && b._2) + Constant(b1 && b2) } val f1 = Flip(0.5) val f2 = Flip(0.5) f1.set(true) f2.set(false) - val c = new Chain("", ^^(f1, f2), fn, 1, Universe.universe) - //val c = NonCachingChain(f1, f2, fn _) - sum = 0 + val c = Chain(f1, f2, fn) c.get(true, true) c.get(false, false) c.get(true, true) - // either 3 or 4 depending on whether the value on initialization is true or false - // see implementation of NonCachingChain and use of oldParentValue sum should equal(3) } } @@ -265,22 +289,7 @@ class ElementsTest extends WordSpec with Matchers { CachingChain(f1, f2, fn).toString should equal("Chain(Apply(" + f1 + ", " + f2 + ", " + fn + "), )") } //((((((((((80 * .50) * 1.08) + 82.4 * .50) * 1.08) + 84.8 * .50) * 1.08) + 87.3 * .50) * 1.08) + 89.9 * .50) * 1.08) - "evaluate the CPD only once for each input" in { - var sum = 0 - def fn(b1: Boolean, b2: Boolean) = { - sum += 1 - Constant(b1 && b2) - } - val f1 = Flip(0.5) - val f2 = Flip(0.5) - f1.set(true) - f2.set(true) - val c = CachingChain(f1, f2, fn) - c.get(true, true) - c.get(false, false) - c.get(true, true) - sum should equal(2) - } + } } From 61aa26a26687d2c342a2c1b74b7311309c860745 Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Thu, 25 Jun 2015 13:57:16 -0400 Subject: [PATCH 03/18] Changes for issue #296. Chains no longer implement caching. Caching can be accomplished by using a com.cra.figaro.util.Cache class, which can be used as part of an algorithm that needs caching. Chains are purely functional now, meaning they do not cache the last result element. Repeated calls to a chain with the same parent value will call the chain function repeatedly. Various changes to other files to support this. --- .../decision/DecisionImportance.scala | 1 + .../filtering/ParParticleFilter.scala | 10 +- .../algorithm/filtering/ParticleFilter.scala | 22 +-- .../sampling/MetropolisHastings.scala | 3 +- .../sampling/ProbEvidenceSampler.scala | 6 +- .../scala/com/cra/figaro/language/Chain.scala | 135 ++---------------- .../figaro/library/decision/Decision.scala | 15 +- .../scala/com/cra/figaro/util/Cache.scala | 103 +++++++++---- .../figaro/test/algorithm/AlgorithmTest.scala | 2 +- .../lazyfactored/LazyValuesTest.scala | 2 +- .../algorithm/sampling/ImportanceTest.scala | 4 + .../sampling/ParImportanceTest.scala | 3 + .../test/example/OpenUniverseTest.scala | 2 +- .../test/library/compound/CompoundTest.scala | 9 +- .../test/library/decision/DecisionTest.scala | 14 -- .../com/cra/figaro/test/util/CacheTest.scala | 112 +++++++++++++++ 16 files changed, 245 insertions(+), 198 deletions(-) create mode 100644 Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/decision/DecisionImportance.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/decision/DecisionImportance.scala index b8aef04e..ef8a0fe6 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/decision/DecisionImportance.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/decision/DecisionImportance.scala @@ -75,6 +75,7 @@ abstract class DecisionImportance[T, U] private (override val universe: Universe // override doSample so can update the local utilities override protected def doSample(): Unit = { val s = sample() + universe.clearTemporaries() totalWeight = logSum(s._1, totalWeight) allWeightsSeen foreach (updateWeightSeenForTarget(s, _)) allUtilitiesSeen foreach (updateWeightSeenForTargetNoLog((math.exp(s._1) * utilitySum, s._2), _)) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala index 937d8075..b8bf787f 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala @@ -3,6 +3,8 @@ package com.cra.figaro.algorithm.filtering import com.cra.figaro.language._ import scala.collection.parallel.ParSeq import com.cra.figaro.algorithm.filtering.ParticleFilter.WeightedParticle +import com.cra.figaro.util.ChainCache +import com.cra.figaro.util.Cache /** * A parallel one-time particle filter. Distributes the work of generating particles at each time step over a specified @@ -42,7 +44,7 @@ class ParOneTimeParticleFilter(static: () => Universe, initial: () => Universe, * @param windows the UniverseWindows to sample from * @param weightedParticleCreator a function that generates a WeightedParticle, given a UniverseWindow and an index */ - private def genParticles(windows: Seq[UniverseWindow], weightedParticleCreator: (UniverseWindow, Int) => WeightedParticle): Seq[WeightedParticle] = { + private def genParticles(windows: Seq[(UniverseWindow, Cache)], weightedParticleCreator: ((UniverseWindow, Cache), Int) => WeightedParticle): Seq[WeightedParticle] = { val parWindows = windows.par val particles = parWindows zip indices flatMap { case(window, (start, end)) => (start to end) map { i => weightedParticleCreator(window, i) } @@ -58,13 +60,15 @@ class ParOneTimeParticleFilter(static: () => Universe, initial: () => Universe, def run(): Unit = { windows = genInitialWindows() - val particles = genParticles(windows, (w, _) => initialWeightedParticle(w.static, w.current)) + val windowWithCaches = windows.map(w => (w, new ChainCache(w.current))) + val particles = genParticles(windowWithCaches, (w, _) => initialWeightedParticle(w._1.static, w._1.current, w._2)) doTimeStep(particles) } def advanceTime(evidence: Seq[NamedEvidence[_]] = List()): Unit = { val newWindows = advanceUniverseWindows(windows) - val particles = genParticles(newWindows, (w, i) => addWeightedParticle(evidence, i, w)) + val newWindowsWithCaches = newWindows.map(w => (w, new ChainCache(w.current))) + val particles = genParticles(newWindowsWithCaches, (w, i) => addWeightedParticle(evidence, i, w._1, w._2)) doTimeStep(particles) windows = newWindows } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala index 09b4b345..fb906e26 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala @@ -42,7 +42,7 @@ trait ParticleFilter { val beliefState: ParticleFilter.BeliefState = Array.fill(numParticles)(null) protected var logProbEvidence: Double = 0.0 - + /** * Returns the expectation of the element referred to by the reference * under the given function at the current time point. @@ -76,8 +76,8 @@ trait ParticleFilter { * TODO: previous state could be replaced by the static universe (or a universe window) */ - protected def makeWeightedParticle(previousState: State, currentUniverse: Universe): ParticleFilter.WeightedParticle = { - Forward(currentUniverse) + protected def makeWeightedParticle(previousState: State, currentUniverse: Universe, cache: Cache): ParticleFilter.WeightedParticle = { + Forward(currentUniverse, cache) // avoiding recursion // satisfied if all conditioned elements are satisfied @@ -91,7 +91,7 @@ trait ParticleFilter { val snapshot = new Snapshot snapshot.store(currentUniverse) - val state = new State(snapshot, previousState.static) + val state = new State(snapshot, previousState.static) (weight, state) } @@ -118,21 +118,21 @@ trait ParticleFilter { logProbEvidence = logProbEvidence + scala.math.log(sum / numParticles) } - protected def addWeightedParticle(evidence: Seq[NamedEvidence[_]], index: Int, universes: UniverseWindow): ParticleFilter.WeightedParticle = { + protected def addWeightedParticle(evidence: Seq[NamedEvidence[_]], index: Int, universes: UniverseWindow, cache: Cache): ParticleFilter.WeightedParticle = { val previousState = beliefState(index) previousState.dynamic.restore(universes.previous) previousState.static.restore(universes.static) universes.current.assertEvidence(evidence) - val result = makeWeightedParticle(previousState, universes.current) + val result = makeWeightedParticle(previousState, universes.current, cache) result } - protected def initialWeightedParticle(static: Universe, current: Universe): ParticleFilter.WeightedParticle = { + protected def initialWeightedParticle(static: Universe, current: Universe, cache: Cache): ParticleFilter.WeightedParticle = { Forward(static) val staticSnapshot = new Snapshot staticSnapshot.store(static) val state = new State(new Snapshot, staticSnapshot) - makeWeightedParticle(state, current) + makeWeightedParticle(state, current, cache) } /* @@ -196,7 +196,8 @@ class OneTimeParticleFilter(static: Universe = new Universe(), initial: Universe * Begin the particle filter, determining the initial distribution. */ def run(): Unit = { - doTimeStep((i: Int) => initialWeightedParticle(static, currentUniverse)) + val chainCache = new ChainCache(currentUniverse) + doTimeStep((i: Int) => initialWeightedParticle(static, currentUniverse, chainCache)) } /** @@ -206,7 +207,8 @@ class OneTimeParticleFilter(static: Universe = new Universe(), initial: Universe val currentWindow = new UniverseWindow(previousUniverse, currentUniverse, static) val newWindow = advanceUniverse(currentWindow, transition) - doTimeStep((i: Int) => addWeightedParticle(evidence, i, newWindow)) + val chainCache = new ChainCache(newWindow.current) + doTimeStep((i: Int) => addWeightedParticle(evidence, i, newWindow, chainCache)) previousUniverse = newWindow.previous currentUniverse = newWindow.current } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala index e4bf5915..7bc3bdae 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala @@ -279,8 +279,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc state.visitOrder.foreach { elem => elem match { - case c: Chain[_, _] => { - val result = + case c: Chain[_, _] => { chainCache(c) match { case Some(result) => c.value = result.value.asInstanceOf[c.Value] case None => throw new AlgorithmException diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala index 99b8ddb1..235ae611 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala @@ -17,6 +17,8 @@ import com.cra.figaro.algorithm._ import com.cra.figaro.language._ import scala.language.existentials import com.cra.figaro.util.logSum +import com.cra.figaro.util.ChainCache +import com.cra.figaro.util.Cache /** * Algorithm that computes probability of evidence using forward sampling. @@ -35,12 +37,14 @@ abstract class ProbEvidenceSampler(override val universe: Universe, override val totalWeight = 0.0 } + + protected var chainCache: Cache = new ChainCache(universe) /* * To protect against underflow, the probabilities are computed in log-space. */ protected def doSample(): Unit = { - Forward(universe) + Forward(universe, chainCache) //Some values in log constraints may be negative infinity. val weight = universe.constrainedElements.map(_.constraintValue).sum diff --git a/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala b/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala index b297655c..a865dadc 100644 --- a/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala +++ b/Figaro/src/main/scala/com/cra/figaro/language/Chain.scala @@ -22,22 +22,19 @@ import scala.collection.mutable.Set * A Chain(parent, fcn) represents the process that first generates a value for the parent, then * applies fcn to get a new Element, and finally generates a value from that new Element. * - * Chain is the common base class for caching and non-caching chains. - * All chains have a cache, whose size is specified by the cacheSize argument. - * When a parent value is encountered, first the cache is checked to see if the result element is known. - * If it is not, the resulting element is generated from scratch by calling fcn. + * Chain is the common base class for caching and non-caching chains. There is no functional difference + * between caching and non-caching chains. Algorithms can use the distinction to implement a caching procedure. * * @param parent The parent element of the chain */ -class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], cacheSize: Int, collection: ElementCollection) +class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], collection: ElementCollection) extends Deterministic[U](name, collection) { - -// def args: List[Element[_]] = if (active && resultElement != null) List(parent) ::: List(resultElement) else List(parent) - def args: List[Element[_]] = List(parent) + + def args: List[Element[_]] = List(parent) protected def cpd = fcn - + private[figaro] val chainFunction = fcn /** @@ -45,51 +42,6 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c */ type ParentType = T - /** - * The current result element that arises from the current value of the parent. - */ - //var resultElement: Element[U] = _ - - /* Data structures for the Chain. The cache stores previously generated result elements. The Context data - * structures store the elements that were created in this context. We also stored newly created elements - * in a subContext, which is based on the value of the parent. - * Because Elements might be stored in sets or maps by algorithms, we need a way to allow the elements to be removed - * from the set or map so they can be garbage collected. Universe provides a way to achieve this through the use of - * contexts. When Chain gets the distribution over the child, it first pushes the context, and pops the context afterward, to mark - * any generated elements as being generated in the context of this Chain. - */ - //lazy private[figaro] val cache: Map[T, Element[U]] = Map() - //lazy private[figaro] val myMappedContextContents: Map[T, Set[Element[_]]] = Map() - //lazy private[figaro] val elemInContext: Map[Element[_], T] = Map() - - //private var lastParentValue: T = _ - - /* Must override clear temporary for Chains. We can never leave the chain in an uninitialized state. That is, - * the chain MUST ALWAYS have a valid element to return. So when clearing temporaries we clear everything - * except the current context. - */ - /* - override def clearContext() = { - myMappedContextContents.keys.foreach(c => if (c != lastParentValue) resizeCache(c)) - } - - /* Override context control for chain data structures */ - override def directContextContents: Set[Element[_]] = if (!active) { - throw new NoSuchElementException - } else Set(myMappedContextContents.values.flatten.toList: _*) - - override private[figaro] def addContextContents(e: Element[_]) = { - myMappedContextContents(lastParentValue) += e - elemInContext += (e -> lastParentValue) - } - - override private[figaro] def removeContextContents(e: Element[_]) = { - myMappedContextContents(elemInContext(e)) -= e - elemInContext -= e - } - * - */ - def generateValue() = { if (parent.value == null) parent.generate() val resultElement = get(parent.value) @@ -97,66 +49,16 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c resultElement.value } - /* Computes the new result. If the cache contains a VALID element for this parent value, then return the - * the element. Otherwise, we need to create one. First, we create an entry in the ContextContents since - * any elements created in this context will be stored in the subContext of parentValue. Then, if the cache - * is full, we resize the cache to make room for a new result element. Apply the function of the chain, and - * stored the new result in the cache. If the element created is a new element, then registers it uses - * in the Universe. - */ - /** - * Get the distribution over the result corresponding to the given parent value. Takes care of all bookkeeping including caching. + * Get the distribution over the result corresponding to the given parent value. */ - /* - def get(parentValue: T): Element[U] = { - val lruParent = lastParentValue - lastParentValue = parentValue - val cacheValue = cache.get(parentValue) - - val newResult = - if (!cacheValue.isEmpty && cacheValue.get.active) cacheValue.get - else { - myMappedContextContents += (parentValue -> Set()) - if (cache.size >= cacheSize && cacheValue.isEmpty) { - val dropValue = if (cache.last._1 != lruParent) cache.last._1 else cache.head._1 - resizeCache(dropValue) - } - val result = getResult(parentValue) - cache += (parentValue -> result) - universe.registerUses(this, result) - result - } - - resultElement = newResult - newResult - } - * - */ def get(parentValue: T): Element[U] = { val result = getResult(parentValue) universe.registerUses(this, result) result } - /** - * Get the distribution over the result corresponding to the given parent value. This call is UNCACHED, - * meaning it will not be stored in the Chain's cache, and subsequent calls using the same parentValue - * could return different elements. - */ - /* - def getUncached(parentValue: T): Element[U] = { - if (lastParentValue == null || lastParentValue != parentValue) { - myMappedContextContents.getOrElseUpdate(parentValue, Set()) - lastParentValue = parentValue - resultElement = getResult(parentValue) - universe.registerUses(this, resultElement) - } - resultElement - } - * - */ - def getUncached(parentValue: T): Element[U] = get(parentValue) + private[figaro] def getUncached(parentValue: T): Element[U] = get(parentValue) // All elements created in cpd will be created in this Chain's context with a subContext of parentValue private def getResult(parentValue: T): Element[U] = { @@ -165,22 +67,7 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c universe.popContext(this) result } - - /* Current replacement scheme just drops the last element in the cache. The dropped element must be deactivated, - * and removed from the context data structures. - */ - /* - protected def resizeCache(dropValue: T) = { - cache -= dropValue - if (myMappedContextContents.contains(dropValue)) { - universe.deactivate(myMappedContextContents(dropValue)) - elemInContext --= myMappedContextContents(dropValue) - myMappedContextContents -= dropValue - } - } - * - */ - + override def toString = "Chain(" + parent + ", " + cpd + ")" } @@ -188,13 +75,13 @@ class Chain[T, U](name: Name[U], val parent: Element[T], fcn: T => Element[U], c * A NonCachingChain is an implementation of Chain with a single element cache. */ class NonCachingChain[T, U](name: Name[U], parent: Element[T], cpd: T => Element[U], collection: ElementCollection) - extends Chain(name, parent, cpd, 2, collection) + extends Chain(name, parent, cpd, collection) /** * A CachingChain is an implementation of Chain with a 1000 element cache. */ class CachingChain[T, U](name: Name[U], parent: Element[T], cpd: T => Element[U], collection: ElementCollection) - extends Chain(name, parent, cpd, Int.MaxValue, collection) + extends Chain(name, parent, cpd, collection) object NonCachingChain { /** Create a NonCaching chain of 1 argument. */ diff --git a/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala b/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala index 8e8929a8..a746cf3d 100644 --- a/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/decision/Decision.scala @@ -44,8 +44,7 @@ import scala.collection.mutable.Set * noncaching uses an approximate policy algorithm (based on kNN) and should be used for continuous * parents or discrete parents with a very large range. */ -abstract class Decision[T, U](name: Name[U], parent: Element[T], private var fcn: T => Element[U], cacheSize: Int, collection: ElementCollection) - extends Chain(name, parent, fcn, cacheSize, collection) with PolicyMaker[T, U] { +trait Decision[T, U] extends Chain[T, U] with PolicyMaker[T, U] { /** * The parent type. @@ -57,6 +56,8 @@ abstract class Decision[T, U](name: Name[U], parent: Element[T], private var fcn */ type DValue = this.Value + var fcn: T => Element[U] + /** * The decision function. fcn is declared as a var and can change depending on the policy. */ @@ -110,14 +111,14 @@ abstract class Decision[T, U](name: Name[U], parent: Element[T], private var fcn // Must regenerate a new value since the cache should never be empty generateValue() } - } + /** * Abstract class for a NonCachingDecision. It is abstract because makePolicy has not been defined yet. */ -abstract class NonCachingDecision[T, U](name: Name[U], parent: Element[T], fcn: T => Element[U], collection: ElementCollection) - extends Decision(name, parent, fcn, 1, collection) { +abstract class NonCachingDecision[T, U](name: Name[U], parent: Element[T], var fcn: T => Element[U], collection: ElementCollection) + extends NonCachingChain(name, parent, fcn, collection) with Decision[T, U] { override def toString = "NonCachingDecision(" + parent + ", " + this.cpd + ")" } @@ -125,8 +126,8 @@ abstract class NonCachingDecision[T, U](name: Name[U], parent: Element[T], fcn: /** * Abstract class for a CachingDecision. It is abstract because makePolicy has not been defined yet. */ -abstract class CachingDecision[T, U](name: Name[U], parent: Element[T], fcn: T => Element[U], collection: ElementCollection) - extends Decision(name, parent, fcn, 1000, collection) { +abstract class CachingDecision[T, U](name: Name[U], parent: Element[T], var fcn: T => Element[U], collection: ElementCollection) + extends CachingChain(name, parent, fcn, collection) with Decision[T, U] { override def toString = "CachingDecision(" + parent + ", " + this.cpd + ")" } diff --git a/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala b/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala index 94863463..3d1457ff 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala @@ -5,22 +5,22 @@ import scala.collection.mutable.Map import scala.collection.mutable.Set import scala.collection.generic.Shrinkable -/** +/** * Abstract class to manage caching of element generation for a universe. This class can be used - * by algorithms to manage caching of chains. - */ + * by algorithms to manage caching of chains. + */ abstract class Cache(universe: Universe) extends Shrinkable[Element[_]] { /** * Return the next element from the generative process defined by element. If no process - * is found, return None + * is found, return None */ - def apply[T](element: Element[T]): Option[Element[T]] + def apply[T](element: Element[T]): Option[Element[T]] universe.register(this) /** - * Clear any caching + * Clear any caching */ def clear(): Unit } @@ -34,15 +34,39 @@ class NoCache(universe: Universe) extends Cache(universe) { /** * A class which implements caching for caching and non-caching chains. - * + * + * For caching chains, the result of the Chain's function is cached for each value of the parent element + * that is queried. This cache is infinitely large. + * + * For non-caching chains, we only "cache" two resulting elements of the chain. The cache is actually + * a 2-element stack, where the top of the stack represents the most recent element for the chain, and the + * bottom of the stack represents the last element (and parent value) used. This is primarily to benefit + * MH; if a proposal is rejected, we want to switch a chain back to where it was without much overhead. + * */ class ChainCache(universe: Universe) extends Cache(universe) { - val ccCache: Map[Element[_], Map[Any, Element[_]]] = Map() - val ccInvertedCache: Map[Element[_], Map[Element[_], Any]] = Map() + /* Caching chain cache that maps from an element to a map of parent values and resulting elements */ + private[figaro] val ccCache: Map[Element[_], Map[Any, Element[_]]] = Map() - val nccCache: Map[Element[_], List[(Any, Element[_], Set[Element[_]])]] = Map() + /* The inverted cache. This maps from result elements back to the chain that uses them. This is needed + * to properly clean up deactivated elements + */ + private[figaro] val ccInvertedCache: Map[Element[_], Map[Element[_], Any]] = Map() + /* + * The non-caching chain "cache". This is a map from elements to a list of: + * (parent value, result element, Set of elements created in the context of the parent value) + * The Set is needed since once a parent value falls off the stack, we have to clear all the elements + * created in the context of that parent value or else we will have memory leaks + */ + private[figaro] val nccCache: Map[Element[_], List[(Any, Element[_], Set[Element[_]])]] = Map() + + /** + * Retrieve any cached element generated from the current value of the supplied element. Returns None if + * the element does not generate another element. + * + */ def apply[T](element: Element[T]): Option[Element[T]] = { element match { case c: CachingChain[_, T] => { @@ -55,11 +79,18 @@ class ChainCache(universe: Universe) extends Cache(universe) { } } - def doCachingChain[U, T](c: CachingChain[U, T]): Option[Element[T]] = { + /* + * Retrieves an element from the caching chain cache, or inserts a new one if none is found for + * the value of this element + */ + private def doCachingChain[U, T](c: CachingChain[U, T]): Option[Element[T]] = { + val cachedElems = ccCache.getOrElseUpdate(c, Map()) val cachedValue = cachedElems.get(c.parent.value) if (!cachedValue.isEmpty) cachedValue.asInstanceOf[Option[Element[T]]] else { + // If the value of the element is not found in the cache, generate a new element by calling the chain, + // add it to the cache and the inverted value, and return val result = c.get(c.parent.value) cachedElems += (c.parent.value -> result) val invertedElems = ccInvertedCache.getOrElseUpdate(result, Map()) @@ -68,24 +99,43 @@ class ChainCache(universe: Universe) extends Cache(universe) { } } - def doNonCachingChain[U, T](c: NonCachingChain[U, T]): Option[Element[T]] = { + /* + * Retrieves an element for a non-caching chain. This is not really a cache but rather, + * for each element, a 2-deep stack is maintained that has the current result element of the chain + * at the top of the stack, and the last result element at the bottom of the stack. This is for use in + * MH. When a proposal is made, the chain may change its value. In such a case, we don't want to lose + * the current result element in case the proposal is rejected, so it is moved to the back of the stack. + * If the proposal is reject, the chain is regenerated and the old element is restored to the top of the stack. + * + */ + private def doNonCachingChain[U, T](c: NonCachingChain[U, T]): Option[Element[T]] = { val nccElems = nccCache.getOrElse(c, List()) + if (nccElems.isEmpty) { + // If no element has been stored for this chain, generate a value and store it in the stack for this element val result = c.get(c.parent.value) nccCache += (c -> List((c.parent.value, result, Set()))) Some(result) - } else if (c.parent.value == nccElems.head._1) Some(nccElems.head._2.asInstanceOf[Element[T]]) - else if (nccElems.size > 1 && c.parent.value == nccElems.last._1) { + } else if (c.parent.value == nccElems.head._1) { + // If the current value of the parent matches the value at the top of the stack, return the element at the top of the stack + Some(nccElems.head._2.asInstanceOf[Element[T]]) + } else if (nccElems.size > 1 && c.parent.value == nccElems.last._1) { + // If the current value matches the value at the bottom of the stack, then we need to do a swap; move the back to the front + // and the front to the back + + // Store the elements in the context of the top of the stack. This is the current context of the chain minus the context + // of the value at the back of the stack val oldContext = c.directContextContents.clone -- nccElems.last._3 + // swap the head and last positions val head = (nccElems.last._1, nccElems.last._2, Set[Element[_]]()) val last = (nccElems.head._1, nccElems.head._2, oldContext) - //oldContext.foreach(c.removeContextContents(_)) nccCache += (c -> List(head, last)) Some(nccElems.last._2.asInstanceOf[Element[T]]) } else { + // Otherwise, we have a new parent value. In this case, we drop the bottom of the stack, deactivate the element + // and the context of the value we are dropping. if (nccElems.size == 2) universe.deactivate(nccElems.last._3) val oldContext = c.directContextContents.clone - //oldContext.foreach(c.removeContextContents(_)) val result = c.get(c.parent.value) val head = (c.parent.value, result, Set[Element[_]]()) val last = (nccElems.head._1, nccElems.head._2, oldContext) @@ -94,21 +144,9 @@ class ChainCache(universe: Universe) extends Cache(universe) { } } - /* - def set(element: Element[_], result: Element[_]): Unit = { - element match { - case c: CachingChain[_, _] => { - ccCache.getOrElseUpdate(c, Map()) += (c.parent.value -> result) - } - case c: NonCachingChain[_, _] => { - - } - case _ => () - } - } - * - */ - + /** + * Removes an element from the cache. This is needed to properly clean up elements as they are deactivated. + */ def -=(element: Element[_]) = { ccCache -= element nccCache -= element @@ -118,6 +156,9 @@ class ChainCache(universe: Universe) extends Cache(universe) { this } + /** + * Clears the cache of all stored elements. + */ def clear() = { ccCache.clear() ccInvertedCache.clear() diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/AlgorithmTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/AlgorithmTest.scala index 8c1f29d2..8d016d71 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/AlgorithmTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/AlgorithmTest.scala @@ -165,7 +165,7 @@ class AlgorithmTest extends WordSpec with Matchers { val init = Universe.universe.activeElements.size alg.kill val after = Universe.universe.activeElements.size - after should equal (400*3) + after should equal (400*2) } } diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/lazyfactored/LazyValuesTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/lazyfactored/LazyValuesTest.scala index a0be8248..ce232719 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/lazyfactored/LazyValuesTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/lazyfactored/LazyValuesTest.scala @@ -94,7 +94,7 @@ class LazyValuesTest extends WordSpec with Matchers { values(elem1, 1) values(elem1, 1) values(elem1, 0) - a should equal (1) + a should equal (2) } "use the old result when called twice on the same universe" in { diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala index 696178a2..65e87d3b 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala @@ -475,6 +475,8 @@ class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { i.kill() } + /* These tests are no longer valid. Since there is a hidden dependency, we can't support this */ + /* "resample elements inside class defined in a chain" in { Universe.createNew() class temp { @@ -500,6 +502,8 @@ class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { //alg.probability(b, true) should be (0.9 +- .01) } + * + */ "not suffer from stack overflow with small probability of success" taggedAs (Performance) in { Universe.createNew() diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ParImportanceTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ParImportanceTest.scala index afb7042a..e412bb02 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ParImportanceTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ParImportanceTest.scala @@ -258,6 +258,7 @@ class ParImportanceTest extends WordSpec with Matchers with PrivateMethodTester i2.kill() } + /* Test is not valid "resample elements inside class defined in a chain" in { val gen = () => { val universe = Universe.createNew() @@ -273,6 +274,8 @@ class ParImportanceTest extends WordSpec with Matchers with PrivateMethodTester alg.probability("b", true) should be(0.9 +- .01) alg.kill } + * + */ "not suffer from stack overflow with small probability of success" taggedAs (Performance) in { val gen = () => { diff --git a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala index 03c84ea0..819736ca 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala @@ -84,7 +84,7 @@ class OpenUniverseTest extends WordSpec with Matchers { val totalProbSame = (0.0 /: (1 to limitNumSources))(_ + probSame(_)) val totalProbDifferent = (0.0 /: (1 to limitNumSources))(_ + probDifferent(_)) val answer = totalProbSame / (totalProbSame + totalProbDifferent) - val alg = MetropolisHastings(2000000, chooseScheme, 50000, equal) + val alg = MetropolisHastings(200000, chooseScheme, 5000, equal) alg.start() alg.probability(equal, true) should be(answer +- 0.02) alg.kill diff --git a/Figaro/src/test/scala/com/cra/figaro/test/library/compound/CompoundTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/library/compound/CompoundTest.scala index 56fc4c53..25902e32 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/library/compound/CompoundTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/library/compound/CompoundTest.scala @@ -69,14 +69,17 @@ class CompoundTest extends WordSpec with Matchers { "not generate a consequent that's not needed when sampling" in { Universe.createNew() var count = 0 - def makeElem() = { + def makeElem1() = { + Constant(1) + } + def makeElem2() = { count += 1 Constant(1) } - val e = If(Constant(true), makeElem(), makeElem()) + val e = If(Constant(true), makeElem1(), makeElem2()) val alg = Importance(100, e) alg.start() - count should equal (1) + count should equal (0) alg.kill() } } diff --git a/Figaro/src/test/scala/com/cra/figaro/test/library/decision/DecisionTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/library/decision/DecisionTest.scala index 92daa5e2..e58abebf 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/library/decision/DecisionTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/library/decision/DecisionTest.scala @@ -45,20 +45,6 @@ class DecisionTest extends WordSpec with Matchers { } - "A NonCachingDecision" should { - - "remove previous elements when parent value changes" in { - val U = Universe.createNew() - val u1 = Uniform((0 until 5): _*) - val d1 = NonCachingDecision(u1, 5 until 10) - d1.setPolicy((i: Int) => Constant(i)) - val prev = d1.get(u1.value) - val curr = d1.get((u1.value + 1) % 5) - U.uses(d1) should not contain (prev) - prev.active should be(false) - } - - } } diff --git a/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala new file mode 100644 index 00000000..28335975 --- /dev/null +++ b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala @@ -0,0 +1,112 @@ +/* + * MultiSetTest.scala + * Needs description + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ + +package com.cra.figaro.test.util + +import org.scalatest.Matchers +import org.scalatest.WordSpec +import com.cra.figaro.util._ +import com.cra.figaro.language._ +import com.cra.figaro.library.atomic.continuous.Uniform +import com.cra.figaro.library.compound.If + +class CacheTest extends WordSpec with Matchers { + "A chain cache" should { + "correctly retrieve cache elements for caching chains" in { + val u = Universe.createNew() + val cc = new ChainCache(u) + var sum = 0 + def fn(b: Boolean) = { + sum += 1 + Constant(b) + } + val f = Flip(0.5) + val c = CachingChain(f, fn) + f.value = true; cc(c) + f.value = false; cc(c) + f.value = true; cc(c) + sum should equal(2) + } + + "keep the stack at maximum of two for non-caching chains" in { + val u = Universe.createNew() + val cc = new ChainCache(u) + val f = Uniform(0.0, 1.0) + val c = Chain(f, (d: Double) => Constant(d)) + for { _ <- 0 until 10 } { + f.generate + cc(c) + } + cc.nccCache(c).size should equal(2) + } + + "remove deactivated elements from the cache" in { + val u = Universe.createNew() + val cc = new ChainCache(u) + val a1 = Flip(0.1) + val a2 = Flip(0.2) + val s = Flip(0.5) + val c = If(s, a1, a2) + s.value = true; cc(c) + s.value = false; cc(c) + cc.ccCache(c).size should equal(2) + a1.deactivate + cc.ccCache(c).size should equal(1) + } + + "correctly clear the context of elements removed from the stack" in { + val u = Universe.createNew() + val cc = new ChainCache(u) + def fn(d: Double) = { + Flip(d); Flip(d); Flip(d) + } + val f = Uniform(0.0, 1.0) + val c = Chain(f, fn) + + f.generate; cc(c) + c.directContextContents.size should equal(3) + f.generate; cc(c) + c.directContextContents.size should equal(6) + f.generate; cc(c) + c.directContextContents.size should equal(6) + u.activeElements.size should equal(8) + } + + "correctly clear the caches when clearing temporaries" in { + val u = Universe.createNew() + val cc = new ChainCache(u) + def fn(d: Double) = { + Flip(d); Flip(d); Flip(d) + } + val f = Uniform(0.0, 1.0) + val perm = Constant(false) + val fl = Flip(0.5) + val c2 = CachingChain(fl, (b: Boolean) => { + if (b) perm + else { + Flip(f) + } + }) + + for {_ <- 0 until 100} { + f.generate + fl.generate + cc(c2) + } + u.clearTemporaries + u.activeElements.size should equal(4) + cc.ccCache(c2).size should equal(1) + cc.ccInvertedCache(perm).size should equal (1) + } + } +} From 528d2a2d7152464ddeab3f700f25204a9b904c90 Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Tue, 30 Jun 2015 21:32:21 -0400 Subject: [PATCH 04/18] New likelihood weighting for Importance sampling, forward sampling, and particle filtering. Various performance changes to support removal of caching from chains. Also pushed a new proposal method. When one proposes a chain, the proposal is automatically forwarded to the result element of the chain. --- .../filtering/ParParticleFilter.scala | 11 +- .../algorithm/filtering/ParticleFilter.scala | 57 +++-- .../algorithm/sampling/ElementSampler.scala | 1 + .../figaro/algorithm/sampling/Forward.scala | 100 +++----- .../algorithm/sampling/Importance.scala | 224 ++---------------- .../sampling/LikelihoodWeighter.scala | 205 ++++++++++++++++ .../sampling/MetropolisHastings.scala | 16 +- .../sampling/MetropolisHastingsAnnealer.scala | 2 +- .../sampling/ProbEvidenceSampler.scala | 21 +- .../algorithm/sampling/WeightedSampler.scala | 3 +- .../sampling/LikelihoodWeighter.scala | 133 ----------- .../com/cra/figaro/library/cache/Cache.scala | 40 ++++ .../cache/MHCache.scala} | 49 +--- .../figaro/library/cache/PermanentCache.scala | 81 +++++++ .../algorithm/sampling/ImportanceTest.scala | 20 +- .../sampling/LikelihoodWeighterTest.scala | 109 +++++++++ .../test/example/OpenUniverseTest.scala | 3 +- .../com/cra/figaro/test/util/CacheTest.scala | 21 +- 18 files changed, 575 insertions(+), 521 deletions(-) create mode 100644 Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala delete mode 100644 Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala create mode 100644 Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala rename Figaro/src/main/scala/com/cra/figaro/{util/Cache.scala => library/cache/MHCache.scala} (85%) create mode 100644 Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala create mode 100644 Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala index b8bf787f..521e65ce 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala @@ -3,8 +3,9 @@ package com.cra.figaro.algorithm.filtering import com.cra.figaro.language._ import scala.collection.parallel.ParSeq import com.cra.figaro.algorithm.filtering.ParticleFilter.WeightedParticle -import com.cra.figaro.util.ChainCache -import com.cra.figaro.util.Cache +import com.cra.figaro.library.cache.PermanentCache +import com.cra.figaro.library.cache.Cache +import com.cra.figaro.algorithm.sampling.LikelihoodWeighter /** * A parallel one-time particle filter. Distributes the work of generating particles at each time step over a specified @@ -44,7 +45,7 @@ class ParOneTimeParticleFilter(static: () => Universe, initial: () => Universe, * @param windows the UniverseWindows to sample from * @param weightedParticleCreator a function that generates a WeightedParticle, given a UniverseWindow and an index */ - private def genParticles(windows: Seq[(UniverseWindow, Cache)], weightedParticleCreator: ((UniverseWindow, Cache), Int) => WeightedParticle): Seq[WeightedParticle] = { + private def genParticles(windows: Seq[(UniverseWindow, LikelihoodWeighter)], weightedParticleCreator: ((UniverseWindow, LikelihoodWeighter), Int) => WeightedParticle): Seq[WeightedParticle] = { val parWindows = windows.par val particles = parWindows zip indices flatMap { case(window, (start, end)) => (start to end) map { i => weightedParticleCreator(window, i) } @@ -60,14 +61,14 @@ class ParOneTimeParticleFilter(static: () => Universe, initial: () => Universe, def run(): Unit = { windows = genInitialWindows() - val windowWithCaches = windows.map(w => (w, new ChainCache(w.current))) + val windowWithCaches = windows.map(w => (w, new LikelihoodWeighter(w.current, new PermanentCache(w.current)))) val particles = genParticles(windowWithCaches, (w, _) => initialWeightedParticle(w._1.static, w._1.current, w._2)) doTimeStep(particles) } def advanceTime(evidence: Seq[NamedEvidence[_]] = List()): Unit = { val newWindows = advanceUniverseWindows(windows) - val newWindowsWithCaches = newWindows.map(w => (w, new ChainCache(w.current))) + val newWindowsWithCaches = newWindows.map(w => (w, new LikelihoodWeighter(w.current, new PermanentCache(w.current)))) val particles = genParticles(newWindowsWithCaches, (w, i) => addWeightedParticle(evidence, i, w._1, w._2)) doTimeStep(particles) windows = newWindows diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala index fb906e26..72953399 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParticleFilter.scala @@ -16,6 +16,9 @@ package com.cra.figaro.algorithm.filtering import com.cra.figaro.algorithm.sampling._ import com.cra.figaro.language._ import com.cra.figaro.util._ +import com.cra.figaro.library.cache.PermanentCache +import com.cra.figaro.library.cache.Cache +import com.cra.figaro.algorithm.sampling.LikelihoodWeighter /** * An abstract class of particle filters. @@ -35,14 +38,14 @@ import com.cra.figaro.util._ * @param transition The transition function that returns a new universe from a static and previous universe, respectively. */ trait ParticleFilter { - + val numParticles: Int /** The belief about the state of the system at the current point in time. */ val beliefState: ParticleFilter.BeliefState = Array.fill(numParticles)(null) protected var logProbEvidence: Double = 0.0 - + /** * Returns the expectation of the element referred to by the reference * under the given function at the current time point. @@ -76,22 +79,18 @@ trait ParticleFilter { * TODO: previous state could be replaced by the static universe (or a universe window) */ - protected def makeWeightedParticle(previousState: State, currentUniverse: Universe, cache: Cache): ParticleFilter.WeightedParticle = { - Forward(currentUniverse, cache) - // avoiding recursion - - // satisfied if all conditioned elements are satisfied - val satisfied = currentUniverse.conditionedElements.forall { x => x.conditionSatisfied } - - //multiply weights together in log space if satisfied - val weight = - if (satisfied) { - math.exp(currentUniverse.constrainedElements.foldLeft(0.0)((b,a) => b + a.constraintValue)) - } else 0.0 + protected def makeWeightedParticle(previousState: State, currentUniverse: Universe, lw: LikelihoodWeighter): ParticleFilter.WeightedParticle = { + + val weight = try { + math.exp(lw.computeWeight(currentUniverse.activeElements)) + } catch { + case Importance.Reject => 0.0 + } val snapshot = new Snapshot snapshot.store(currentUniverse) - val state = new State(snapshot, previousState.static) + val state = new State(snapshot, previousState.static) + currentUniverse.clearTemporaries (weight, state) } @@ -118,21 +117,21 @@ trait ParticleFilter { logProbEvidence = logProbEvidence + scala.math.log(sum / numParticles) } - protected def addWeightedParticle(evidence: Seq[NamedEvidence[_]], index: Int, universes: UniverseWindow, cache: Cache): ParticleFilter.WeightedParticle = { + protected def addWeightedParticle(evidence: Seq[NamedEvidence[_]], index: Int, universes: UniverseWindow, lw: LikelihoodWeighter): ParticleFilter.WeightedParticle = { val previousState = beliefState(index) previousState.dynamic.restore(universes.previous) previousState.static.restore(universes.static) universes.current.assertEvidence(evidence) - val result = makeWeightedParticle(previousState, universes.current, cache) + val result = makeWeightedParticle(previousState, universes.current, lw) result } - protected def initialWeightedParticle(static: Universe, current: Universe, cache: Cache): ParticleFilter.WeightedParticle = { + protected def initialWeightedParticle(static: Universe, current: Universe, lw: LikelihoodWeighter): ParticleFilter.WeightedParticle = { Forward(static) val staticSnapshot = new Snapshot staticSnapshot.store(static) val state = new State(new Snapshot, staticSnapshot) - makeWeightedParticle(state, current, cache) + makeWeightedParticle(state, current, lw) } /* @@ -159,7 +158,7 @@ trait ParticleFilter { logProbEvidence } - /** + /** * The computed probability of evidence. */ def probEvidence(): Double = { @@ -179,10 +178,10 @@ trait ParticleFilter { */ class OneTimeParticleFilter(static: Universe = new Universe(), initial: Universe, transition: (Universe, Universe) => Universe, val numParticles: Int) extends Filtering(static, initial, transition) with ParticleFilter with OneTimeFiltering { - + var currentUniverse: Universe = initial var previousUniverse: Universe = _ - + private def doTimeStep(weightedParticleCreator: Int => ParticleFilter.WeightedParticle) { val weightedParticles = for { i <- 0 until numParticles } yield weightedParticleCreator(i) @@ -196,19 +195,19 @@ class OneTimeParticleFilter(static: Universe = new Universe(), initial: Universe * Begin the particle filter, determining the initial distribution. */ def run(): Unit = { - val chainCache = new ChainCache(currentUniverse) - doTimeStep((i: Int) => initialWeightedParticle(static, currentUniverse, chainCache)) + val lw = new LikelihoodWeighter(currentUniverse, new PermanentCache(currentUniverse)) + doTimeStep((i: Int) => initialWeightedParticle(static, currentUniverse, lw)) } /** * Advance the filtering one time step, conditioning on the given evidence at the new time point. */ def advanceTime(evidence: Seq[NamedEvidence[_]] = List()): Unit = { - + val currentWindow = new UniverseWindow(previousUniverse, currentUniverse, static) val newWindow = advanceUniverse(currentWindow, transition) - val chainCache = new ChainCache(newWindow.current) - doTimeStep((i: Int) => addWeightedParticle(evidence, i, newWindow, chainCache)) + val lw = new LikelihoodWeighter(newWindow.current, new PermanentCache(newWindow.current)) + doTimeStep((i: Int) => addWeightedParticle(evidence, i, newWindow, lw)) previousUniverse = newWindow.previous currentUniverse = newWindow.current } @@ -258,14 +257,14 @@ object ParticleFilter { /** Weighted particles, consisting of a weight and a state. */ type WeightedParticle = (Double, State) - + /** Reference to parallel implementation. */ def par = ParParticleFilter } /** - * A class representing a single window in time, with a current universe, a previous universe, + * A class representing a single window in time, with a current universe, a previous universe, * and a static universe. */ class UniverseWindow(val previous: Universe, val current: Universe, val static: Universe) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala index 23a45c0b..423725bd 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala @@ -26,6 +26,7 @@ abstract class ElementSampler(target: Element[_]) extends BaseUnweightedSampler( def sample(): (Boolean, Sample) = { Forward(target) + universe.clearTemporaries (true, Map[Element[_], Any](target -> target.value)) } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala index 75247fc0..8ae576e2 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala @@ -14,9 +14,14 @@ package com.cra.figaro.algorithm.sampling import com.cra.figaro.language._ -import com.cra.figaro.util.ChainCache -import com.cra.figaro.util.Cache -import com.cra.figaro.util.NoCache +import com.cra.figaro.library.cache.Cache +import com.cra.figaro.library.cache.NoCache +import com.cra.figaro.algorithm.sampling.LikelihoodWeighter + + +class ForwardWeighter(universe: Universe, cache: Cache) extends LikelihoodWeighter(universe, cache) { + override def rejectionAction() = () +} /** * A forward sampler that generates a state by generating values for elements, making sure to generate all the @@ -26,84 +31,33 @@ object Forward { /** * Sample the universe by generating a value for each element of the universe. Return a cache object. */ - def apply(universe: Universe): Cache = { + def apply(universe: Universe): Double = { apply(universe, new NoCache(universe)) } /** * Sample the universe by generating a value for each element of the universe, and provide a cache object. Return a cache object. */ - def apply(universe: Universe, cache: Cache): Cache = { - // avoiding recursion - var state = Set[Element[_]]() - var elementsRemaining = universe.activeElements - while (!elementsRemaining.isEmpty) { - if (elementsRemaining.head.active) state = sampleInState(elementsRemaining.head, state, universe, cache) - elementsRemaining = elementsRemaining.tail - } - cache + def apply(universe: Universe, cache: Cache): Double = { + val lw = new ForwardWeighter(universe, cache) + try { + lw.computeWeight(universe.activeElements) + } catch { + case Importance.Reject => Double.NegativeInfinity + } } - def apply[T](element: Element[T]) = { - val noCache = new NoCache(element.universe) - sampleInState(element, Set[Element[_]](), element.universe, noCache) - noCache - } - - private type State = Set[Element[_]] - - /* - * To allow this algorithm to be used for dependent universes, we make sure elements in a different universe are not - * sampled. + /** + * Sample only part of the model originating from a single element */ - private def sampleInState[T](element: Element[T], state: State, universe: Universe, cache: Cache): State = { - if (element.universe != universe || (state contains element)) state - else { - val (state1, sampledValue) = { - element match { - case d: Dist[_, _] => - val state1 = - d match { - case dc: CompoundDist[_] => - // avoiding recursion - var resultState = state - var probsRemaining = dc.probs - while (!probsRemaining.isEmpty) { - resultState = sampleInState(probsRemaining.head, resultState, universe, cache) - probsRemaining = probsRemaining.tail - } - resultState - case _ => state - } - val rand = d.generateRandomness() - val index = d.selectIndex(rand) - val state2 = sampleInState(d.outcomeArray(index), state1, universe, cache) - (state2, d.finishGeneration(index)) - case c: Chain[_, _] => - val state1 = sampleInState(c.parent, state, universe, cache) - val result = cache(c) match { - case Some(r) => r - case _ => c.get(c.parent.value) - } - val state2 = sampleInState(result, state1, universe, cache) - (state2, result.value) - case _ => - // avoiding recursion - var state1 = state - var initialArgs = (element.args ::: element.elementsIAmContingentOn.toList).toSet - var argsRemaining = initialArgs - while (!argsRemaining.isEmpty) { - state1 = sampleInState(argsRemaining.head, state1, universe, cache) - val newArgs = element.args.filter(!initialArgs.contains(_)) - initialArgs = initialArgs ++ newArgs - argsRemaining = argsRemaining.tail ++ newArgs - } - element.generate - (state1, element.value) - } - } - element.value = sampledValue.asInstanceOf[T] - state1 + element - } + def apply[T](element: Element[T]): Double = { + val noCache = new NoCache(element.universe) + val lw = new ForwardWeighter(element.universe, noCache) + try { + lw.computeWeight(List(element)) + } catch { + case Importance.Reject => Double.NegativeInfinity + } } + } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Importance.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Importance.scala index a3c07376..3ab2225d 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Importance.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Importance.scala @@ -21,6 +21,8 @@ import scala.collection.mutable.{ Set, Map } import com.cra.figaro.experimental.particlebp.AutomaticDensityEstimator import com.cra.figaro.algorithm.factored.ParticleGenerator import com.cra.figaro.algorithm.sampling.parallel.ParImportance +import com.cra.figaro.algorithm.sampling.LikelihoodWeighter +import com.cra.figaro.library.cache.PermanentCache /** * Importance samplers. @@ -29,55 +31,12 @@ abstract class Importance(universe: Universe, targets: Element[_]*) extends WeightedSampler(universe, targets: _*) { import Importance.State -/* - * Likelihood weighting works by propagating observations through Dists and Chains - * to the variables they depend on. If we don't make sure we sample those Dists and - * Chains first, we may end up sampling those other elements without the correct - * observations. To avoid this, we keep track of all these dependencies. - * The dependencies map contains all the elements that could propagate an - * observation to any given element. - * Note: the dependencies map is only concerned with active elements that are present - * at the beginning of sampling (even though we get a new active elements list each sample). - * Temporary elements will always be created after the element that could propagate - * an observation to them, because that propagation has to go through a permanent - * element. - * Therefore, we can generate the dependencies map once before all the samples are generated. - */ - - // Calling Values on a continuous element requires that a ParticleGenerator exist in the universe. - // For a continuous element, we don't actually want to generate any dependencies, so we have the - // ParticleGenerator return zero sample. - ParticleGenerator(universe, new AutomaticDensityEstimator, 1, 0) - private val dependencies = scala.collection.mutable.Map[Element[_], Set[Element[_]]]() - private def makeDependencies() = { - for { - element <- universe.activeElements - } { - element match { - case d: Dist[_,_] => - for { o <- d.outcomes } { dependencies += o -> (dependencies.getOrElse(o, Set()) + d) } - //In principle, we should create a dependency from the result element of a chain to - //the chain. If the result element is a permanent element, this could matter. - //However, in most relevant cases (e.g., compound versions of atomic elements), the - //outcome element will be temporary and automatically generated after the chain, - //so we don't need the dependency. On the other hand, creating the dependency requires - //a call to values which is slow and dangerous. - // - //Unfortunately, the above intuition doesn't hold. The EM with importance tests fail - //to terminate unless we add dependencies to chains. We're still searching for a better - //way to determine the possible outcome elements than to call values on the parent. - case c: CachingChain[_,_] => - val outcomes = Values(universe)(c.parent).map(c.get(_)) - for { o <- outcomes } { dependencies += o -> (dependencies.getOrElse(o, Set()) + c) } - case _ => () - } - } - } - makeDependencies() + val lw = new LikelihoodWeighter(universe, new PermanentCache(universe)) private var numRejections = 0 private var logSuccessWeight = 0.0 private var numSamples = 0 + def getSamples() = numSamples override protected def resetCounts() { super.resetCounts() @@ -98,15 +57,15 @@ abstract class Importance(universe: Universe, targets: Element[_]*) val activeElements = universe.activeElements val resultOpt: Option[Sample] = try { - val state = State() - activeElements.foreach(e => if (e.active) sampleOne(state, e, None)) + val weight = lw.computeWeight(activeElements) val bindings = targets map (elem => elem -> elem.value) - Some((state.weight, Map(bindings: _*))) + Some((weight, Map(bindings: _*))) } catch { case Importance.Reject => None } + universe.clearTemporaries() resultOpt match { case Some(x) => logSuccessWeight = logSum(logSuccessWeight, x._1) @@ -118,150 +77,11 @@ abstract class Importance(universe: Universe, targets: Element[_]*) } } - /* - * Sample the value of an element. If it has already been assigned in the state, the current assignment is - * used and the state is unchanged. Also, an element in a different universe is not sampled; instead its state is - * used directly - * - * This is made private[figaro] to allow easy testing - */ - private[figaro] def sampleOne[T](state: State, element: Element[T], observation: Option[T]): T = { - /* - * We have to make sure to sample any elements this element depends on first so we can get the right - * observation for this element. - */ - dependencies.getOrElse(element, Set()).filter(!state.assigned.contains(_)).foreach(sampleOne(state, _, None)) - if (element.universe != universe || (state.assigned contains element)) { - element.value - } - else { - state.assigned += element - sampleFresh(state, element, observation) - } - } - - /* - * Sample a fresh value of an element, assuming it has not yet been assigned a value in the state. This sampling - * takes into account the condition and constraint on the element. If the condition is violated, the entire sampling - * process is rejected. This function returns the state including the new assignment to the element with the weight - * of the constraint multiplied in. - */ - private def sampleFresh[T](state: State, element: Element[T], observation: Option[T]): T = { - val fullObservation = (observation, element.observation) match { - case (None, None) => None - case (Some(obs), None) => Some(obs) - case (None, Some(obs)) => Some(obs) - case (Some(obs1), Some(obs2)) if obs1 == obs2 => Some(obs1) - case _ => throw Importance.Reject // incompatible observations - } - val value: T = - if (fullObservation.isEmpty || !element.isInstanceOf[HasDensity[_]]) { - val result = sampleValue(state, element, fullObservation) - if (!element.condition(result)) throw Importance.Reject - result - } else { - // Optimize the common case of an observation on an atomic element. - // This partially implements likelihood weighting by clamping the element to its - // desired value and multiplying the weight by the density of the value. - // This can dramatically reduce the number of rejections. - val obs = fullObservation.get - - sampleArgs(element, state: State, Set[Element[_]](element.args:_*)) - - // I'm not quite sure why we have to call sampleValue here when we're about to set the value of this element to obs. - // If I remove this call, the test "should correctly resample an element's arguments when the arguments change during samples" - // fails. - sampleValue(state, element, Some(obs)) - val density = element.asInstanceOf[HasDensity[T]].density(obs) - state.weight += math.log(density) - obs - } - element.value = value - state.weight += element.constraint(value) - value - } - - /* - * Sample the value of an element according to its generative model, without considering the condition or constraint. - * Since sampling the value of this element might also involve sampling the values of related elements, the state - * must be updated and returned. - * - * Most elements can simply be handled by sampling values for the arguments and generating values for this - * element. Dist is an exception, because not all the outcomes need to be generated, but we only know which one - * after we have sampled the randomness of the Dist. For this reason, we write special code to handle Dists. - * For Chain, we also write special code to avoid calling get twice. - * - * We propagate observations on Chains and Dists to their possible outcome elements. This ensures that instead of - * sampling these elements and then checking whether the observation is satisfied, we set their values to the - * required ones. This implements likelihood weighting and leads to faster convergence of the algorithm. - */ - private def sampleValue[T](state: State, element: Element[T], observation: Option[T]): T = { - element match { - case d: Dist[_, _] => - d match { - case dc: CompoundDist[_] => dc.probs foreach (sampleOne(state, _, None)) - case _ => () - } - val rand = d.generateRandomness() - val index = d.selectIndex(rand) - sampleOne(state, d.outcomeArray(index), observation) - d.value = d.finishGeneration(index) - d.value - case c: Chain[_, _] => - val parentValue = sampleOne(state, c.parent, None) - val next = c.get(parentValue) - c.value = sampleOne(state, next, observation) - c.value - case f: CompoundFlip => - val probValue = sampleOne(state, f.prob, None) - observation match { - case Some(true) => - state.weight += math.log(probValue) - true - case Some(false) => - state.weight += math.log(1 - probValue) - false - case _ => - val result = random.nextDouble() < probValue - f.value = result - result - } - case f: ParameterizedFlip => - val probValue = sampleOne(state, f.parameter, None) - - observation match { - case Some(true) => - true - case Some(false) => - false - case _ => - val result = random.nextDouble() < probValue - f.value = result - result - } - - case _ => - val args = (element.args ::: element.elementsIAmContingentOn.toList) - sampleArgs(element, state: State, Set[Element[_]](args:_*)) - element.randomness = element.generateRandomness() - element.value = element.generateValue(element.randomness) - element.value - } - } - - @tailrec - private def sampleArgs(element: Element[_], state: State, args: Set[Element[_]]): Unit = { - args foreach (sampleOne(state, _, None)) - val newArgs = element.args.filter(!state.assigned.contains(_)) - if (newArgs.nonEmpty) - sampleArgs(element, state, Set[Element[_]](newArgs: _*)) - } - - /** + /** * The computed probability of evidence. */ def logProbEvidence: Double = { - logSuccessWeight - Math.log(numSamples + numRejections) + logSuccessWeight - Math.log(numSamples + numRejections) } } @@ -292,20 +112,20 @@ object Importance { new Importance(universe, targets: _*) with OneTimeProbQuerySampler with ProbEvidenceQuery { val numSamples = myNumSamples - /** - * Use one-time sampling to compute the probability of the given named evidence. - * Takes the conditions and constraints in the model as part of the model definition. - * This method takes care of creating and running the necessary algorithms. - */ - override def probabilityOfEvidence(evidence: List[NamedEvidence[_]]): Double = { - val logPartition = logProbEvidence - universe.assertEvidence(evidence) - if (active) kill() - start() - Math.exp(logProbEvidence - logPartition) - } + /** + * Use one-time sampling to compute the probability of the given named evidence. + * Takes the conditions and constraints in the model as part of the model definition. + * This method takes care of creating and running the necessary algorithms. + */ + override def probabilityOfEvidence(evidence: List[NamedEvidence[_]]): Double = { + val logPartition = logProbEvidence + universe.assertEvidence(evidence) + if (active) kill() + start() + Math.exp(logProbEvidence - logPartition) + } - } + } /** * Use IS to compute the probability that the given element satisfies the given predicate. diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala new file mode 100644 index 00000000..dec31672 --- /dev/null +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala @@ -0,0 +1,205 @@ +package com.cra.figaro.algorithm.sampling + +import com.cra.figaro.language._ +import scala.annotation.tailrec +import com.cra.figaro.library.cache.Cache +import scala.collection.mutable.Set + +/* + * Likelihood weighting works by propagating observations through Dists and Chains + * to the variables they depend on. If we don't make sure we sample those Dists and + * Chains first, we may end up sampling those other elements without the correct + * observations. To avoid this, we keep track of all these dependencies. + * The dependencies map contains all the elements that could propagate an + * observation to any given element. + * To avoid calling values on Chains, we dynamically build the list of dependencies. + * If we encounter the result element of a chain before it's parent, we redo the sampling + * on the result and undo any weight associated with that element. + * + */ + +/** + * A class that implements sampling via likelihood weighting on a set of elements. + */ +class LikelihoodWeighter(universe: Universe, cache: Cache) { + + /* Stores the dependencies between elements for likelihood weighting */ + private[figaro] val dependencies = scala.collection.mutable.Map[Element[_], Set[Element[_]]]() + universe.register(dependencies) + + /** + * Sample each element in the list of elements and compute their likelihood weight + */ + def computeWeight(elementsToVisit: List[Element[_]]): Double = { + // Do all dependencies first, then no need to check for them in the main traversal loop + val visited: Set[Element[_]] = Set() + val dependencyWeights = traverse(List(), dependencies.values.flatten.toList, 0.0, visited) + val remaining = elementsToVisit.filterNot(visited.contains(_)) + traverse(List(), remaining, dependencyWeights, visited) + } + + /* + * Traverse the elements in generative order, and return the weight + */ + @tailrec + private[figaro] final def traverse(currentStack: List[(Element[_], Option[_], Option[Element[_]])], + elementsToVisit: List[Element[_]], currentWeight: Double, visited: Set[Element[_]]): Double = { + + // If everything is empty, just return the weight + if (elementsToVisit.isEmpty && currentStack.isEmpty) { + currentWeight + + } // If the current stack is empty, take the head of the elements to visit as the next element + else if (currentStack.isEmpty) { + traverse(List((elementsToVisit.head, getObservation(elementsToVisit.head, None), None)), elementsToVisit.tail, currentWeight, visited) + + } // If the head of the stack has already been visited, is not active, or is in another universe, we don't need to do anything, go to the next element + else if (!currentStack.head._1.active || visited.contains(currentStack.head._1) || currentStack.head._1.universe != universe) { + traverse(currentStack.tail, elementsToVisit, currentWeight, visited += currentStack.head._1) + + } // Otherwise, we need to process the top element on the stack + else { + val (currElem, currObs, currResult) = currentStack.head + + currElem match { + case dist: Dist[_, _] => + val parents = dist match { + case dc: CompoundDist[_] => dc.probs.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None), None)) + case _ => List() + } + + if (parents.nonEmpty) { + traverse(parents ::: currentStack, elementsToVisit, currentWeight, visited) + } else { + val rand = dist.generateRandomness() + val index = dist.selectIndex(rand) + val resultElement = if (currResult.isEmpty) dist.outcomeArray(index) else currResult.get + val nextHead = List((resultElement, getObservation(resultElement, currObs), None), (currElem, None, Some(resultElement))) + + if (visited.contains(resultElement) && currObs.nonEmpty) { + // we did this in the wrong order, and have to repropagate the result for likelihood weighting, and add it to the dependency map so we don't do this incorrectly next time + val elementsToRedo = findDependentElements(dist, resultElement) + val newWeight = (currentWeight /: elementsToRedo)((c: Double, n: Element[_]) => undoWeight(c, n)) + traverse(nextHead ::: currentStack.tail, elementsToRedo.toList ::: elementsToVisit, newWeight, visited --= elementsToRedo) + } else if (!visited.contains(resultElement)) { + traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited) + } else { + dist.value = resultElement.value.asInstanceOf[dist.Value] + dependencies.getOrElseUpdate(resultElement, Set()) += dist + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit, nextWeight, visited += currElem) + } + } + + case chain: Chain[_, _] => + if (!visited.contains(chain.parent)) { + traverse((chain.parent, getObservation(chain.parent, None), None) :: currentStack, elementsToVisit, currentWeight, visited) + } else { + val next = if (currResult.isEmpty) cache(chain).get else currResult.get + val nextHead = List((next, getObservation(next, currObs), None), (currElem, None, Some(next))) + if (visited.contains(next) && currObs.nonEmpty) { + // we did this in the wrong order, and have to repropagate the result for likelihood weighting + val elementsToRedo = findDependentElements(chain, next) + val newWeight = (currentWeight /: elementsToRedo)((c: Double, n: Element[_]) => undoWeight(c, n)) + traverse(nextHead ::: currentStack.tail, elementsToRedo.toList ::: elementsToVisit, newWeight, visited --= elementsToRedo) + } else if (!visited.contains(next)) { + traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited) + } else { + chain match { + case cc: CachingChain[_, _] => if (!next.isTemporary) dependencies.getOrElseUpdate(next, Set()) += chain + case _ => () + } + chain.value = next.value.asInstanceOf[chain.Value] + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit, nextWeight, visited += currElem) + } + } + case _ => + val args = (currElem.args ::: currElem.elementsIAmContingentOn.toList) + // Find all the arguments of the element that have not been visited + val remainingArgs = args.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None), None)) + // if there are args unvisited, push those args to the top of the stack + if (remainingArgs.nonEmpty) { + traverse(remainingArgs ::: currentStack, elementsToVisit, currentWeight, visited) + } else { + // else, we can now process this element and move on to the next item + currElem.randomness = currElem.generateRandomness() + currElem.value = currElem.generateValue(currElem.randomness) + val nextWeight = computeNextWeight(currentWeight, currElem, currObs) + traverse(currentStack.tail, elementsToVisit, nextWeight, visited += currElem) + } + } + } + } + + + /* + * Finds the set of elements that need to be resampled when the likelihood weighting went in the wrong order + */ + private def findDependentElements(elem: Element[_], result: Element[_]) = { + val chainUsedBy = universe.usedBy(elem) + elem + val resultUseBy = universe.usedBy(result) + result + resultUseBy -- chainUsedBy + } + + /* + * Get the observation on an element, merging with any propagated observation from likelihood weighting + */ + protected def getObservation(element: Element[_], observation: Option[_]): Option[Any] = { + (observation, element.observation) match { + case (None, None) => None + case (Some(obs), None) => Some(obs) + case (None, Some(obs)) => Some(obs) + case (Some(obs1), Some(obs2)) if obs1 == obs2 => Some(obs1) + case _ => { // incompatible observations + rejectionAction() + None + } + } + } + + /* + * Compute the current weight of the model by incorporating the weight of the current sampled element + * If there is no observation on the element, the weight is the current weight plus the constraint on the element. + * If there is a condition on the element (not an observation), then we throw a rejection if the condition is not met. + * + * If there is an observation on the element, we implement likelihood weighting. If the element has a density + * function, we add the log density to the current weight. If it doesn't have a density, we check to see if + * it satisfies the observation + */ + private[figaro] def computeNextWeight(currentWeight: Double, element: Element[_], obs: Option[_]): Double = { + val nextWeight = if (obs.isEmpty) { + if (!element.condition(element.value)) rejectionAction() + currentWeight + } else { + element match { + case f: CompoundFlip => { + element.value = obs.get.asInstanceOf[element.Value] + if (obs.get.asInstanceOf[Boolean]) currentWeight + math.log(f.prob.value) + else currentWeight + math.log(1 - f.prob.value) + } + case e: HasDensity[_] => { + element.value = obs.get.asInstanceOf[element.Value] + val density = element.asInstanceOf[HasDensity[element.Value]].density(obs.asInstanceOf[Option[element.Value]].get) + currentWeight + math.log(density) + } + case _ => { + if (!element.condition(element.value)) rejectionAction() + currentWeight + } + } + } + nextWeight + element.constraint(element.value) + } + + /* Action to take on a rejection. By default it throws an Importance.Reject exception, but this can be overriden for another behavior */ + protected def rejectionAction(): Unit = throw Importance.Reject + + /* + * Undo the application of this elements weight if we did likelihood weighting in the wrong order + */ + private def undoWeight(weight: Double, elem: Element[_]) = weight - computeNextWeight(0.0, elem, elem.observation) + +} + + diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala index 7bc3bdae..369c9bf4 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastings.scala @@ -20,6 +20,7 @@ import scala.collection.mutable.Map import scala.language.existentials import scala.math.log import scala.annotation.tailrec +import com.cra.figaro.library.cache._ /** * Metropolis-Hastings samplers. @@ -61,7 +62,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc private val fastTargets = targets.toSet - protected var chainCache: Cache = new ChainCache(universe) + protected var chainCache: Cache = new MHCache(universe) /* * We continually update the values of elements while making a proposal. In order to be able to undo it, we need to @@ -162,11 +163,11 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc case FinalScheme(elem) => propose(state, elem()) case TypedScheme(first, rest) => val firstElem = first() - val state1 = propose(state, firstElem) + val state1 = propose(state, proposeChainCheck(firstElem)) continue(state1, rest(firstElem.value)) case UntypedScheme(first, rest) => val firstElem = first() - val state1 = propose(state, firstElem) + val state1 = propose(state, proposeChainCheck(firstElem)) continue(state1, rest) case ds: DisjointScheme => val (probs, schemes) = ds.choices.toList.unzip @@ -174,9 +175,14 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc runStep(state, choice()) case SwitchScheme(first, rest) => val (elem1, elem2) = first() - val state1 = switch(state, elem1, elem2) + val state1 = switch(state, proposeChainCheck(elem1), proposeChainCheck(elem2)) continue(state1, rest) } + + private def proposeChainCheck(elem: Element[_]): Element[_] = { + val e = chainCache(elem) + if (e.isEmpty) elem else e.get.asInstanceOf[Element[_]] + } protected def runScheme(): State = runStep(newState, proposalScheme) @@ -349,7 +355,7 @@ abstract class MetropolisHastings(universe: Universe, proposalScheme: ProposalSc protected def doInitialize(): Unit = { // Need to prime the universe to make sure all elements have a generated value - chainCache = Forward(universe, chainCache) + Forward(universe, chainCache) initConstrainedValues() dissatisfied = universe.conditionedElements.toSet filter (!_.conditionSatisfied) for { i <- 1 to burnIn } mhStep() diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala index 0c2e78c5..fb1b8975 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/MetropolisHastingsAnnealer.scala @@ -114,7 +114,7 @@ abstract class MetropolisHastingsAnnealer(universe: Universe, proposalScheme: Pr } override def doInitialize(): Unit = { - chainCache = Forward(universe, chainCache) + Forward(universe, chainCache) initConstrainedValues() dissatisfied = universe.conditionedElements.toSet filter (!_.conditionSatisfied) currentEnergy = universe.constrainedElements.map(_.constraintValue).sum diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala index 235ae611..471034c0 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbEvidenceSampler.scala @@ -17,8 +17,7 @@ import com.cra.figaro.algorithm._ import com.cra.figaro.language._ import scala.language.existentials import com.cra.figaro.util.logSum -import com.cra.figaro.util.ChainCache -import com.cra.figaro.util.Cache +import com.cra.figaro.library.cache.PermanentCache /** * Algorithm that computes probability of evidence using forward sampling. @@ -38,21 +37,21 @@ abstract class ProbEvidenceSampler(override val universe: Universe, override val } - protected var chainCache: Cache = new ChainCache(universe) + val lw = new LikelihoodWeighter(universe, new PermanentCache(universe)) /* * To protect against underflow, the probabilities are computed in log-space. */ protected def doSample(): Unit = { - Forward(universe, chainCache) - - //Some values in log constraints may be negative infinity. - val weight = universe.constrainedElements.map(_.constraintValue).sum - - val satisfied = universe.conditionedElements forall (_.conditionSatisfied) - + totalWeight += 1 - if (satisfied) successWeight = logSum(successWeight, weight) + + try { + val weight = lw.computeWeight(universe.activeElements) + successWeight = logSum(successWeight, weight) + } catch { + case Importance.Reject => () + } universe.clearTemporaries() // avoid memory leaks } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/WeightedSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/WeightedSampler.scala index 0465850f..af6bacaa 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/WeightedSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/WeightedSampler.scala @@ -67,8 +67,7 @@ abstract class WeightedSampler(override val universe: Universe, targets: Element } protected def doSample(): Unit = { - val s = sample() - universe.clearTemporaries() + val s = sample() totalWeight = logSum(s._1, totalWeight) allWeightsSeen foreach (updateWeightSeenForTarget(s, _)) } diff --git a/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala b/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala deleted file mode 100644 index baf70cbf..00000000 --- a/Figaro/src/main/scala/com/cra/figaro/experimental/sampling/LikelihoodWeighter.scala +++ /dev/null @@ -1,133 +0,0 @@ -package com.cra.figaro.experimental.sampling - -import com.cra.figaro.language._ -import scala.annotation.tailrec -import com.cra.figaro.algorithm.sampling.Importance - -class LikelihoodWeighter(universe: Universe) { - - def computeWeight(elementsToVisit: Set[Element[_]]): Double = { - traverse(List(), elementsToVisit, 0.0, Set(), scala.collection.mutable.Map[Dist[_, _], Int]()) - } - - @tailrec - private final def traverse(currentStack: List[(Element[_], Option[_])], - elementsToVisit: Set[Element[_]], - currentWeight: Double, - visited: Set[Element[_]], DistMap: scala.collection.mutable.Map[Dist[_, _], Int]): Double = { - - // If everything is empty, just return the weight - if (elementsToVisit.isEmpty && currentStack.isEmpty) { - currentWeight - } - // If the current stack is empty, we are free to choose any element to traverse. Pick the head of the set - else if (currentStack.isEmpty) { - traverse(List((elementsToVisit.head, getObservation(elementsToVisit.head, None))), elementsToVisit.tail, currentWeight, visited, DistMap) - } - // If the head of the stack has already been visited or in another universe, we don't need to do anything, go to the next element - else if (visited.contains(currentStack.head._1) || currentStack.head._1.universe != universe) { - traverse(currentStack.tail, elementsToVisit, currentWeight, visited, DistMap) - } - // Otherwise, we need to process the top element on the stack - else { - val (currElem, currObs) = currentStack.head - - currElem match { - case d: Dist[_, _] => - val parents = d match { - case dc: CompoundDist[_] => dc.probs.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None))) - case _ => List() - } - val rand = d.generateRandomness() - val index = d.selectIndex(rand) - val resultElement = d.outcomeArray(index) - val nextHead = List((resultElement, getObservation(resultElement, currObs)), (currElem, None)) - - if (parents.nonEmpty) { - traverse(parents ::: currentStack, elementsToVisit, currentWeight, visited, DistMap) - } else if (visited.contains(resultElement) && currObs.nonEmpty) { - traverse(nextHead ::: currentStack.tail, elementsToVisit, undoWeight(currentWeight, resultElement), visited - resultElement, DistMap += (d -> index)) - } else if (!visited.contains(resultElement)) { - traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited, DistMap += (d -> index)) - } else { - d.value = if (DistMap.contains(d)) d.finishGeneration(DistMap(d)) else d.finishGeneration(index) - DistMap -= d - val nextWeight = computeNextWeight(currentWeight, currElem, currObs) - traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) - } - case c: Chain[_, _] => - if (!visited.contains(c.parent)) { - traverse((c.parent, getObservation(c.parent, None)) +: currentStack, elementsToVisit, currentWeight, visited, DistMap) - } else { - val next = c.get(c.parent.value) - val nextHead = List((next, getObservation(next, currObs)), (currElem, None)) - if (visited.contains(next) && currObs.nonEmpty) { - // we did this in the wrong order, and have to repropagate the result for likelihood weighting - traverse(nextHead ::: currentStack.tail, elementsToVisit, undoWeight(currentWeight, next), visited - next, DistMap) - } else if (!visited.contains(next)) { - traverse(nextHead ::: currentStack.tail, elementsToVisit, currentWeight, visited, DistMap) - } else { - c.value = next.value - val nextWeight = computeNextWeight(currentWeight, currElem, currObs) - traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) - } - } - case _ => - val args = (currElem.args ::: currElem.elementsIAmContingentOn.toList) - // Find all the arguments of the element that have not been visited - val remainingArgs = args.filterNot(visited.contains(_)).map(e => (e, getObservation(e, None))) - // if there are args unvisited, push those args to the top of the stack - if (remainingArgs.nonEmpty) { - traverse(remainingArgs ::: currentStack, elementsToVisit, currentWeight, visited, DistMap) - } else { - // else, we can now process this element and move on to the next item - currElem.randomness = currElem.generateRandomness() - currElem.value = currElem.generateValue(currElem.randomness) - val nextWeight = computeNextWeight(currentWeight, currElem, currObs) - traverse(currentStack.tail, elementsToVisit - currElem, nextWeight, visited + currElem, DistMap) - } - } - - } - } - - def getObservation(element: Element[_], observation: Option[_]) = { - (observation, element.observation) match { - case (None, None) => None - case (Some(obs), None) => Some(obs) - case (None, Some(obs)) => Some(obs) - case (Some(obs1), Some(obs2)) if obs1 == obs2 => Some(obs1) - case _ => throw Importance.Reject // incompatible observations - } - } - - def computeNextWeight(currentWeight: Double, element: Element[_], obs: Option[_]): Double = { - val nextWeight = if (obs.isEmpty) { - if (!element.condition(element.value)) throw Importance.Reject - currentWeight - } else { - element match { - case f: CompoundFlip => { - element.value = obs.get.asInstanceOf[element.Value] - if (obs.get.asInstanceOf[Boolean]) currentWeight + math.log(f.prob.value) - else currentWeight + math.log(1 - f.prob.value) - } - case e: HasDensity[_] => { - element.value = obs.get.asInstanceOf[element.Value] - val density = element.asInstanceOf[HasDensity[element.Value]].density(obs.asInstanceOf[Option[element.Value]].get) - currentWeight + math.log(density) - } - case _ => { - if (!element.condition(element.value)) throw Importance.Reject - currentWeight - } - } - } - nextWeight + element.constraint(element.value) - } - - def undoWeight(weight: Double, elem: Element[_]) = weight - computeNextWeight(0.0, elem, elem.observation) - -} - - diff --git a/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala new file mode 100644 index 00000000..3a5afc22 --- /dev/null +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala @@ -0,0 +1,40 @@ +package com.cra.figaro.library.cache + +import com.cra.figaro.language._ +import scala.collection.generic.Shrinkable + +/** + * Abstract class to manage caching of element generation for a universe. This class can be used + * by algorithms to manage caching of chains. + */ +abstract class Cache(universe: Universe) extends Shrinkable[Element[_]] { + + /** + * Return the next element from the generative process defined by element. If no process + * is found, return None + */ + def apply[T](element: Element[T]): Option[Element[T]] + + universe.register(this) + + /** + * Clear any caching + */ + def clear(): Unit + +} + +/** A Cache class which performs no caching */ +class NoCache(universe: Universe) extends Cache(universe) { + def apply[T](element: Element[T]): Option[Element[T]] = { + element match { + case c: Chain[_,T] => Some(c.get(c.parent.value)) + case _ => None + + } + } + def clear() = {} + def -=(element: Element[_]) = this +} + + diff --git a/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala similarity index 85% rename from Figaro/src/main/scala/com/cra/figaro/util/Cache.scala rename to Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala index 3d1457ff..3036a8c4 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/Cache.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala @@ -1,39 +1,11 @@ -package com.cra.figaro.util +package com.cra.figaro.library.cache import com.cra.figaro.language._ import scala.collection.mutable.Map import scala.collection.mutable.Set -import scala.collection.generic.Shrinkable /** - * Abstract class to manage caching of element generation for a universe. This class can be used - * by algorithms to manage caching of chains. - */ -abstract class Cache(universe: Universe) extends Shrinkable[Element[_]] { - - /** - * Return the next element from the generative process defined by element. If no process - * is found, return None - */ - def apply[T](element: Element[T]): Option[Element[T]] - - universe.register(this) - - /** - * Clear any caching - */ - def clear(): Unit -} - -/** A Cache class which performs no caching */ -class NoCache(universe: Universe) extends Cache(universe) { - def apply[T](element: Element[T]): Option[Element[T]] = None - def clear() = {} - def -=(element: Element[_]) = this -} - -/** - * A class which implements caching for caching and non-caching chains. + * A class which implements caching for caching and non-caching chains, specifically designed for MH * * For caching chains, the result of the Chain's function is cached for each value of the parent element * that is queried. This cache is infinitely large. @@ -44,13 +16,13 @@ class NoCache(universe: Universe) extends Cache(universe) { * MH; if a proposal is rejected, we want to switch a chain back to where it was without much overhead. * */ -class ChainCache(universe: Universe) extends Cache(universe) { +class MHCache(universe: Universe) extends Cache(universe) { /* Caching chain cache that maps from an element to a map of parent values and resulting elements */ private[figaro] val ccCache: Map[Element[_], Map[Any, Element[_]]] = Map() /* The inverted cache. This maps from result elements back to the chain that uses them. This is needed - * to properly clean up deactivated elements + * to properly clean up deactivated elements */ private[figaro] val ccInvertedCache: Map[Element[_], Map[Element[_], Any]] = Map() @@ -58,7 +30,7 @@ class ChainCache(universe: Universe) extends Cache(universe) { * The non-caching chain "cache". This is a map from elements to a list of: * (parent value, result element, Set of elements created in the context of the parent value) * The Set is needed since once a parent value falls off the stack, we have to clear all the elements - * created in the context of that parent value or else we will have memory leaks + * created in the context of that parent value or else we will have memory leaks */ private[figaro] val nccCache: Map[Element[_], List[(Any, Element[_], Set[Element[_]])]] = Map() @@ -81,7 +53,7 @@ class ChainCache(universe: Universe) extends Cache(universe) { /* * Retrieves an element from the caching chain cache, or inserts a new one if none is found for - * the value of this element + * the value of this element */ private def doCachingChain[U, T](c: CachingChain[U, T]): Option[Element[T]] = { @@ -106,7 +78,7 @@ class ChainCache(universe: Universe) extends Cache(universe) { * MH. When a proposal is made, the chain may change its value. In such a case, we don't want to lose * the current result element in case the proposal is rejected, so it is moved to the back of the stack. * If the proposal is reject, the chain is regenerated and the old element is restored to the top of the stack. - * + * */ private def doNonCachingChain[U, T](c: NonCachingChain[U, T]): Option[Element[T]] = { val nccElems = nccCache.getOrElse(c, List()) @@ -145,7 +117,7 @@ class ChainCache(universe: Universe) extends Cache(universe) { } /** - * Removes an element from the cache. This is needed to properly clean up elements as they are deactivated. + * Removes an element from the cache. This is needed to properly clean up elements as they are deactivated. */ def -=(element: Element[_]) = { ccCache -= element @@ -157,7 +129,7 @@ class ChainCache(universe: Universe) extends Cache(universe) { } /** - * Clears the cache of all stored elements. + * Clears the cache of all stored elements. */ def clear() = { ccCache.clear() @@ -165,5 +137,4 @@ class ChainCache(universe: Universe) extends Cache(universe) { nccCache.clear() universe.deregister(this) } -} - +} \ No newline at end of file diff --git a/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala new file mode 100644 index 00000000..4487322f --- /dev/null +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala @@ -0,0 +1,81 @@ +package com.cra.figaro.library.cache + +import com.cra.figaro.language._ +import scala.collection.mutable.Map + +/** + * A class which only caches permanent result elements in chain. This class does not cache any non-caching + * chain result elements. Since this class does not implement any element cleanup operations, it is + * best used in an algorithm that clears temporary elements periodically. + * + * + */ +class PermanentCache(universe: Universe) extends Cache(universe) { + + /* Caching chain cache that maps from an element to a map of parent values and resulting elements */ + private[figaro] val ccCache: Map[Element[_], Map[Any, Element[_]]] = Map() + + /* The inverted cache. This maps from result elements back to the chain that uses them. This is needed + * to properly clean up deactivated elements + */ + private[figaro] val ccInvertedCache: Map[Element[_], Map[Element[_], Any]] = Map() + + /** + * Retrieve any cached element generated from the current value of the supplied element. Returns None if + * the element does not generate another element. + * + */ + def apply[T](element: Element[T]): Option[Element[T]] = { + element match { + case c: CachingChain[_, T] => { + doCachingChain(c) + } + case c: NonCachingChain[_, T] => { + Some(c.get(c.parent.value)) + } + case _ => None + } + } + + /* + * Retrieves an element from the caching chain cache, or inserts a new one if none is found for + * the value of this element + */ + private def doCachingChain[U, T](c: Chain[U, T]): Option[Element[T]] = { + + val cachedElems = ccCache.getOrElseUpdate(c, Map()) + val cachedValue = cachedElems.get(c.parent.value) + if (!cachedValue.isEmpty) cachedValue.asInstanceOf[Option[Element[T]]] + else { + // If the value of the element is not found in the cache, generate a new element by calling the chain, + // add it to the cache -only if the result is a permanent element- + val result = c.get(c.parent.value) + if (!result.isTemporary) { + cachedElems += (c.parent.value -> result) + val invertedElems = ccInvertedCache.getOrElseUpdate(result, Map()) + invertedElems += (c -> c.parent.value) + } + Some(result) + } + } + + /** + * Removes an element from the cache. This is needed to properly clean up elements as they are deactivated. + */ + def -=(element: Element[_]) = { + ccCache -= element + val invertValue = ccInvertedCache.get(element) + if (invertValue.nonEmpty) invertValue.get.foreach(e => ccCache(e._1) -= e._2) + ccInvertedCache -= element + this + } + + /** + * Clears the cache of all stored elements. + */ + def clear() = { + ccCache.clear() + ccInvertedCache.clear() + universe.deregister(this) + } +} \ No newline at end of file diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala index 65e87d3b..afa26d8f 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/ImportanceTest.scala @@ -31,19 +31,20 @@ import com.cra.figaro.test.tags.NonDeterministic import scala.language.reflectiveCalls import org.scalatest.Matchers import org.scalatest.{ PrivateMethodTester, WordSpec } +import scala.collection.mutable.Set + class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { "Sampling a value of a single element" should { "reject sampling process if condition violated" in { Universe.createNew() - val target = Flip(0.7) + val target = Flip(1.0) target.observe(false) val numTrials = 100000 val tolerance = 0.01 val imp = Importance(target) - val state = Importance.State() - an[RuntimeException] should be thrownBy { imp.sampleOne(state, target, Some(true)) } + an[RuntimeException] should be thrownBy { imp.lw.traverse(List((target, None, None)), List(), 0.0, Set()) } } @@ -54,9 +55,8 @@ class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { val numTrials = 100000 val tolerance = 0.01 val imp = Importance(target) - val state = Importance.State() - val value = imp.sampleOne(state, target, Some(false)) - value should equal(false) + imp.lw.traverse(List((target, Some(false), None)), List(), 0.0, Set()) + target.value should equal(false) } "for a Constant return the constant with probability 1" in { @@ -211,8 +211,7 @@ class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { Apply(B, (b: List[List[Boolean]]) => b.head) }) val alg = Importance(1, c) - val state = Importance.State() - alg.sampleOne(state, c, None) + alg.lw.computeWeight(List(c)) c.value.asInstanceOf[List[Boolean]].head should be(true || false) } } @@ -637,9 +636,8 @@ class ImportanceTest extends WordSpec with Matchers with PrivateMethodTester { def attempt(): (Double, T) = { try { - val state = Importance.State() - val value = imp.sampleOne(state, target, None) - (state.weight, value.asInstanceOf[T]) + val weight = imp.lw.computeWeight(List(target)) + (weight, target.value.asInstanceOf[T]) } catch { case Importance.Reject => attempt() } diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala new file mode 100644 index 00000000..6faf3dc6 --- /dev/null +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala @@ -0,0 +1,109 @@ +/* + * ImportanceTest.scala + * Importance sampling tests. + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ + +package com.cra.figaro.test.algorithm.sampling + +import org.scalatest._ +import org.scalatest.Matchers +import com.cra.figaro.algorithm._ +import com.cra.figaro.algorithm.sampling.Importance.Reject +import com.cra.figaro.algorithm.sampling._ +import com.cra.figaro.language._ +import com.cra.figaro.library.atomic.continuous._ +import com.cra.figaro.library.atomic._ +import com.cra.figaro.library.atomic.discrete.Binomial +import com.cra.figaro.library.compound._ +import com.cra.figaro.test._ +import com.cra.figaro.util.logSum +import JSci.maths.statistics._ +import com.cra.figaro.test.tags.Performance +import com.cra.figaro.test.tags.NonDeterministic +import scala.language.reflectiveCalls +import org.scalatest.Matchers +import org.scalatest.{ PrivateMethodTester, WordSpec } +import com.cra.figaro.algorithm.sampling.LikelihoodWeighter +import com.cra.figaro.library.cache.NoCache +class LikelihoodWeighterTest extends WordSpec with Matchers with PrivateMethodTester { + + // Note, many test cases are covered under the Importance sampling tests + val normalizer = 1.0 / math.sqrt(2.0 * math.Pi) + + "Running likelihood weighting" should { + "Correctly build dependencies" in { + Universe.createNew() + val result1 = Flip(0.7) + val result2 = Flip(0.5) + + val c1 = If(Flip(0.1), Flip(0.2), result1) + val c2 = If(Flip(0.1), Flip(0.2), result2) + val c3 = If(Flip(0.5), c1, c2) + + val lw = new LikelihoodWeighter(Universe.universe, new NoCache(Universe.universe)) + for { _ <- 0 until 500 } lw.traverse(List(), Universe.universe.activeElements, 0.0, scala.collection.mutable.Set()) + lw.dependencies.contains(result1) should equal(true) + lw.dependencies.contains(result2) should equal(true) + } + + "Undo the weight if a result element was sampled before its parent when the parent has an observation" in { + Universe.createNew() + val result1 = Normal(0, 1) + val result2 = Normal(1, 1) + val c = Chain(Constant(true), (b: Boolean) => if (b) result1 else result2) + result1.addConstraint((d: Double) => Normal.density(1.0, 1.0, 1.0)(d)) + c.observe(0.0) + + val lw = new LikelihoodWeighter(Universe.universe, new NoCache(Universe.universe)) + val weight = lw.traverse(List(), List(result1, result2, c), 0.0, scala.collection.mutable.Set()) + val correct = result1.density(0.0) * Normal.density(1.0, 1, 1)(0) + math.exp(weight) should be(correct +- .00001) + } + + "Not have invalid states for a chain if sampled in the wrong order" in { + Universe.createNew() + val result1 = Normal(0, 1) + val result2 = Normal(1, 1) + val c = Chain(Constant(true), (b: Boolean) => if (b) result1 else result2) + c.observe(0.0) + val f = result1 ++ Constant(0.0) + val lw = new LikelihoodWeighter(Universe.universe, new NoCache(Universe.universe)) + val weight = lw.traverse(List(), List(f, result1, result2, c), 0.0, scala.collection.mutable.Set()) + f.value should equal(0.0) + } + + "Not have invalid states for a dist if sampled in the wrong order" in { + Universe.createNew() + val result1 = Normal(0, 1) + val result2 = Normal(1, 1) + val c = Dist(1.0 -> result1, 0.0 -> result2) + c.observe(0.0) + val f = result1 ++ Constant(0.0) + val lw = new LikelihoodWeighter(Universe.universe, new NoCache(Universe.universe)) + val weight = lw.traverse(List(), List(f, result1, result2, c), 0.0, scala.collection.mutable.Set()) + f.value should equal(0.0) + } + + "Not overflow the stack" in { + def next(count: Int): Element[Boolean] = { + if (count == 0) Constant(true) + else Chain(Flip(0.5), (b: Boolean) => next(count - 1)) + } + Universe.createNew() + val start = next(2000) + val lw = new LikelihoodWeighter(Universe.universe, new NoCache(Universe.universe)) + lw.computeWeight(List(start)) + //an[StackOverflowError] should not be thrownBy { lw.computeWeight(List(start)) } + } + + } + +} diff --git a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala index 819736ca..f8eb1b86 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/example/OpenUniverseTest.scala @@ -58,8 +58,7 @@ class OpenUniverseTest extends WordSpec with Matchers { (0.25, () => ProposalScheme(numSources)), (0.25, () => ProposalScheme(sources.items(random.nextInt(numSources.value)))), (0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).sourceNum)), - (0.25, () => ProposalScheme.default)) - //(0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).position))) + (0.25, () => ProposalScheme(samples(random.nextInt(numSamples)).position))) sample1.position.addCondition((y: Double) => y >= 0.5 && y < 0.8) sample2.position.addCondition((y: Double) => y >= 0.5 && y < 0.8) diff --git a/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala index 28335975..1ed183c4 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala @@ -15,16 +15,21 @@ package com.cra.figaro.test.util import org.scalatest.Matchers import org.scalatest.WordSpec -import com.cra.figaro.util._ -import com.cra.figaro.language._ + +import com.cra.figaro.language.CachingChain +import com.cra.figaro.language.Chain +import com.cra.figaro.language.Constant +import com.cra.figaro.language.Flip +import com.cra.figaro.language.Universe import com.cra.figaro.library.atomic.continuous.Uniform +import com.cra.figaro.library.cache.MHCache import com.cra.figaro.library.compound.If class CacheTest extends WordSpec with Matchers { - "A chain cache" should { + "A MH cache" should { "correctly retrieve cache elements for caching chains" in { val u = Universe.createNew() - val cc = new ChainCache(u) + val cc = new MHCache(u) var sum = 0 def fn(b: Boolean) = { sum += 1 @@ -40,7 +45,7 @@ class CacheTest extends WordSpec with Matchers { "keep the stack at maximum of two for non-caching chains" in { val u = Universe.createNew() - val cc = new ChainCache(u) + val cc = new MHCache(u) val f = Uniform(0.0, 1.0) val c = Chain(f, (d: Double) => Constant(d)) for { _ <- 0 until 10 } { @@ -52,7 +57,7 @@ class CacheTest extends WordSpec with Matchers { "remove deactivated elements from the cache" in { val u = Universe.createNew() - val cc = new ChainCache(u) + val cc = new MHCache(u) val a1 = Flip(0.1) val a2 = Flip(0.2) val s = Flip(0.5) @@ -66,7 +71,7 @@ class CacheTest extends WordSpec with Matchers { "correctly clear the context of elements removed from the stack" in { val u = Universe.createNew() - val cc = new ChainCache(u) + val cc = new MHCache(u) def fn(d: Double) = { Flip(d); Flip(d); Flip(d) } @@ -84,7 +89,7 @@ class CacheTest extends WordSpec with Matchers { "correctly clear the caches when clearing temporaries" in { val u = Universe.createNew() - val cc = new ChainCache(u) + val cc = new MHCache(u) def fn(d: Double) = { Flip(d); Flip(d); Flip(d) } From 2614f8b3575cda432bd79df24d02466f9ee8cd60 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Wed, 1 Jul 2015 15:11:42 -0400 Subject: [PATCH 05/18] Add scala-swing dependency --- project/Build.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/Build.scala b/project/Build.scala index e359475b..ef430b07 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -85,8 +85,9 @@ object FigaroBuild extends Build { "com.typesafe.akka" %% "akka-actor" % "2.3.8", "org.scalanlp" %% "breeze" % "0.10", "io.argonaut" %% "argonaut" % "6.0.4", - "com.storm-enroute" %% "scalameter" % "0.6" % "provided", "org.prefuse" % "prefuse" % "beta-20071021", + "org.scala-lang.modules" %% "scala-swing" % "1.0.1", + "com.storm-enroute" %% "scalameter" % "0.6" % "provided", "org.scalatest" %% "scalatest" % "2.2.4" % "provided, test" )) // Copy all managed dependencies to \lib_managed directory From 4b5f29cd2d119a31d2bb84fdf105d06b8ba2e334 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Wed, 1 Jul 2015 15:16:17 -0400 Subject: [PATCH 06/18] Add scala-swing dependency --- project/Build.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/Build.scala b/project/Build.scala index e359475b..47378df6 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -85,8 +85,9 @@ object FigaroBuild extends Build { "com.typesafe.akka" %% "akka-actor" % "2.3.8", "org.scalanlp" %% "breeze" % "0.10", "io.argonaut" %% "argonaut" % "6.0.4", - "com.storm-enroute" %% "scalameter" % "0.6" % "provided", "org.prefuse" % "prefuse" % "beta-20071021", + "org.scala-lang.modules" %% "scala-swing" % "1.0.1" + "com.storm-enroute" %% "scalameter" % "0.6" % "provided", "org.scalatest" %% "scalatest" % "2.2.4" % "provided, test" )) // Copy all managed dependencies to \lib_managed directory From bf9879e01a60a9aee68a585e14c822735ea08768 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Thu, 2 Jul 2015 10:56:53 -0400 Subject: [PATCH 07/18] Updated for v3.3 --- Figaro/META-INF/MANIFEST.MF | 2 +- Figaro/figaro_build.properties | 2 +- FigaroExamples/META-INF/MANIFEST.MF | 4 ++-- project/Build.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Figaro/META-INF/MANIFEST.MF b/Figaro/META-INF/MANIFEST.MF index 3550236f..a7e60963 100644 --- a/Figaro/META-INF/MANIFEST.MF +++ b/Figaro/META-INF/MANIFEST.MF @@ -2,7 +2,7 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: Figaro Bundle-SymbolicName: com.cra.figaro -Bundle-Version: 3.2.1 +Bundle-Version: 3.3.0 Export-Package: com.cra.figaro.algorithm, com.cra.figaro.algorithm.decision, com.cra.figaro.algorithm.decision.index, diff --git a/Figaro/figaro_build.properties b/Figaro/figaro_build.properties index 811a88b7..3ae169ff 100644 --- a/Figaro/figaro_build.properties +++ b/Figaro/figaro_build.properties @@ -1 +1 @@ -version=3.2.1.0 +version=3.3.0.0 diff --git a/FigaroExamples/META-INF/MANIFEST.MF b/FigaroExamples/META-INF/MANIFEST.MF index 0f4d08d9..7b360a6b 100644 --- a/FigaroExamples/META-INF/MANIFEST.MF +++ b/FigaroExamples/META-INF/MANIFEST.MF @@ -2,8 +2,8 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: FigaroExamples Bundle-SymbolicName: com.cra.figaro.examples -Bundle-Version: 3.2.1 -Require-Bundle: com.cra.figaro;bundle-version="3.2.1", +Bundle-Version: 3.3.0 +Require-Bundle: com.cra.figaro;bundle-version="3.3.0", org.scala-lang.scala-library Bundle-Vendor: Charles River Analytics Bundle-RequiredExecutionEnvironment: JavaSE-1.6 diff --git a/project/Build.scala b/project/Build.scala index 47378df6..b4226df3 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -24,7 +24,7 @@ object FigaroBuild extends Build { override val settings = super.settings ++ Seq( organization := "com.cra.figaro", description := "Figaro: a language for probablistic programming", - version := "3.2.1.0", + version := "3.3.0.0", scalaVersion := "2.11.6", crossPaths := true, publishMavenStyle := true, From 3191b450d1a7c419fa2c05d9623e964f845ee0a1 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Thu, 2 Jul 2015 10:58:57 -0400 Subject: [PATCH 08/18] Updated for v4.0 --- Figaro/META-INF/MANIFEST.MF | 2 +- Figaro/figaro_build.properties | 2 +- FigaroExamples/META-INF/MANIFEST.MF | 4 ++-- project/Build.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Figaro/META-INF/MANIFEST.MF b/Figaro/META-INF/MANIFEST.MF index 3550236f..87dbb207 100644 --- a/Figaro/META-INF/MANIFEST.MF +++ b/Figaro/META-INF/MANIFEST.MF @@ -2,7 +2,7 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: Figaro Bundle-SymbolicName: com.cra.figaro -Bundle-Version: 3.2.1 +Bundle-Version: 4.0.0 Export-Package: com.cra.figaro.algorithm, com.cra.figaro.algorithm.decision, com.cra.figaro.algorithm.decision.index, diff --git a/Figaro/figaro_build.properties b/Figaro/figaro_build.properties index 811a88b7..9052a6eb 100644 --- a/Figaro/figaro_build.properties +++ b/Figaro/figaro_build.properties @@ -1 +1 @@ -version=3.2.1.0 +version=4.0.0.0 diff --git a/FigaroExamples/META-INF/MANIFEST.MF b/FigaroExamples/META-INF/MANIFEST.MF index 0f4d08d9..07c9c94c 100644 --- a/FigaroExamples/META-INF/MANIFEST.MF +++ b/FigaroExamples/META-INF/MANIFEST.MF @@ -2,8 +2,8 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: FigaroExamples Bundle-SymbolicName: com.cra.figaro.examples -Bundle-Version: 3.2.1 -Require-Bundle: com.cra.figaro;bundle-version="3.2.1", +Bundle-Version: 4.0.0 +Require-Bundle: com.cra.figaro;bundle-version="4.0.0", org.scala-lang.scala-library Bundle-Vendor: Charles River Analytics Bundle-RequiredExecutionEnvironment: JavaSE-1.6 diff --git a/project/Build.scala b/project/Build.scala index ef430b07..f3bd227d 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -24,7 +24,7 @@ object FigaroBuild extends Build { override val settings = super.settings ++ Seq( organization := "com.cra.figaro", description := "Figaro: a language for probablistic programming", - version := "3.2.1.0", + version := "4.0.0.0", scalaVersion := "2.11.6", crossPaths := true, publishMavenStyle := true, From 003c2b12a38eb4e64532d69bf8edd78e074b5407 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Thu, 2 Jul 2015 11:05:05 -0400 Subject: [PATCH 09/18] Missing comma --- project/Build.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/Build.scala b/project/Build.scala index b4226df3..234eabe0 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -86,7 +86,7 @@ object FigaroBuild extends Build { "org.scalanlp" %% "breeze" % "0.10", "io.argonaut" %% "argonaut" % "6.0.4", "org.prefuse" % "prefuse" % "beta-20071021", - "org.scala-lang.modules" %% "scala-swing" % "1.0.1" + "org.scala-lang.modules" %% "scala-swing" % "1.0.1", "com.storm-enroute" %% "scalameter" % "0.6" % "provided", "org.scalatest" %% "scalatest" % "2.2.4" % "provided, test" )) From fc0465c1177fcc6cfc8adcd9e254363e73b65665 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Wed, 8 Jul 2015 13:37:58 -0400 Subject: [PATCH 10/18] ScalaDoc fixes --- .../figaro/util/visualization/ResultsGUI.scala | 3 +-- .../distribution/DistributionRenderer.scala | 3 ++- .../visualization/reduction/DataReduction.scala | 15 ++++++++++++--- .../util/visualization/results/ResultsData.scala | 3 +-- .../util/visualization/results/ResultsTable.scala | 3 +-- .../util/visualization/results/ResultsView.scala | 3 +-- .../sampling/LikelihoodWeighterTest.scala | 2 +- .../com/cra/figaro/test/util/CacheTest.scala | 2 +- 8 files changed, 20 insertions(+), 14 deletions(-) diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/ResultsGUI.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/ResultsGUI.scala index 36e132cd..f9982a8c 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/ResultsGUI.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/ResultsGUI.scala @@ -1,7 +1,6 @@ /* * ResultsGUI.scala - * The main controller for visualizations. - * Coordinates data input and display as well as user interaction with displays. + * The main controller for visualizations. Coordinates data input and display as well as user interaction with displays. * * Created By: Glenn Takata (gtakata@cra.com) * Creation Date: Mar 16, 2015 diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/distribution/DistributionRenderer.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/distribution/DistributionRenderer.scala index c056ea92..72d1841a 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/distribution/DistributionRenderer.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/distribution/DistributionRenderer.scala @@ -9,7 +9,8 @@ * See http://www.cra.com or email figaro@cra.com for information. * * See http://www.github.com/p2t2/figaro for a copy of the software license. - */package com.cra.figaro.util.visualization.distribution + */ +package com.cra.figaro.util.visualization.distribution import java.awt.Graphics2D import java.awt.Shape diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/reduction/DataReduction.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/reduction/DataReduction.scala index 45e6a455..2e6fe72c 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/reduction/DataReduction.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/reduction/DataReduction.scala @@ -1,10 +1,19 @@ +/* + * DataReduction.scala + * Setup and display distributions based on continuous element data + * + * Created By: Glenn Takata (gtakata@cra.com) + * Creation Date: Jul 6, 2015 + * + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ package com.cra.figaro.util.visualization.reduction import scala.collection._ -/** - * @author gtakata - */ object DataReduction { def binToDistribution(data: List[(Double, Double)]): List[(Double, Double)] = { if (data.size > 50) { diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsData.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsData.scala index 6832c34c..4f11c5a4 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsData.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsData.scala @@ -1,7 +1,6 @@ /* * ResultsData.scala - * Trait and classes representing data input by the user. - * Includes discrete (distribution List) and continuous (element) + * Trait and classes representing data input by the user. Includes discrete (distribution List) and continuous (element) * * Created By: Glenn Takata (gtakata@cra.com) * Creation Date: Apr 9, 2015 diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsTable.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsTable.scala index 37c12a07..735924b3 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsTable.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsTable.scala @@ -1,7 +1,6 @@ /* * ResultsTable.scala - * Visual element for a table to display user inputs. - * Includes discrete (distribution List) and continuous (element) + * Visual element for a table to display user inputs. Includes discrete (distribution List) and continuous (element) * * Created By: Glenn Takata (gtakata@cra.com) * Creation Date: Apr 9, 2015 diff --git a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsView.scala b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsView.scala index 7466026a..fce1ef4b 100644 --- a/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsView.scala +++ b/Figaro/src/main/scala/com/cra/figaro/util/visualization/results/ResultsView.scala @@ -1,7 +1,6 @@ /* * ResultsView.scala - * A visual component to display a table of user data. - * Includes discrete (distribution List) and continuous (element) + * A visual component to display a table of user data. Includes discrete (distribution List) and continuous (element) * * Created By: Glenn Takata (gtakata@cra.com) * Creation Date: Mar 16, 2015 diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala index 6faf3dc6..de451562 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/sampling/LikelihoodWeighterTest.scala @@ -1,5 +1,5 @@ /* - * ImportanceTest.scala + * LikelihoodWeighterTest.scala * Importance sampling tests. * * Created By: Avi Pfeffer (apfeffer@cra.com) diff --git a/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala index 1ed183c4..1c0a1041 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/util/CacheTest.scala @@ -1,5 +1,5 @@ /* - * MultiSetTest.scala + * CacheTest.scala * Needs description * * Created By: Avi Pfeffer (apfeffer@cra.com) From 866856bd9e09c596d1dfdd34c5c4d80aa04d32d7 Mon Sep 17 00:00:00 2001 From: lfkellogg Date: Wed, 8 Jul 2015 13:41:41 -0400 Subject: [PATCH 11/18] Add file headers. --- .../figaro/algorithm/filtering/ParFiltering.scala | 10 +++++----- .../algorithm/filtering/ParParticleFilter.scala | 13 +++++++++++++ .../algorithm/sampling/ProbQuerySampler.scala | 13 +++++++++++++ .../algorithm/sampling/parallel/ParSampler.scala | 13 +++++++++++++ .../sampling/parallel/ParSamplingAlgorithm.scala | 6 +++--- .../algorithm/filtering/ParParticleFilterTest.scala | 10 +++++----- 6 files changed, 52 insertions(+), 13 deletions(-) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParFiltering.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParFiltering.scala index f64f29d1..abb3c692 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParFiltering.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParFiltering.scala @@ -1,11 +1,11 @@ /* - * Filtering.scala - * Filtering algorithms. + * ParFiltering.scala + * A parallel version of Filtering. * - * Created By: Avi Pfeffer (apfeffer@cra.com) - * Creation Date: Jan 1, 2009 + * Created By: Lee Kellogg (lkellogg@cra.com) + * Creation Date: Jun 2, 2015 * - * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. * See http://www.cra.com or email figaro@cra.com for information. * * See http://www.github.com/p2t2/figaro for a copy of the software license. diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala index 521e65ce..1bad2601 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/filtering/ParParticleFilter.scala @@ -1,3 +1,16 @@ +/* + * ParParticleFilter.scala + * A parallel one-time particle filter. + * + * Created By: Lee Kellogg (lkellogg@cra.com) + * Creation Date: Jun 2, 2015 + * + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ + package com.cra.figaro.algorithm.filtering import com.cra.figaro.language._ diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbQuerySampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbQuerySampler.scala index 5e005413..e50a8fae 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbQuerySampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ProbQuerySampler.scala @@ -1,3 +1,16 @@ +/* + * ProbQuerySampler.scala + * Sampling algorithms that use projected samples to compute conditional probabilities. + * + * Created By: Lee Kellogg (lkellog@cra.com) + * Creation Date: June 2, 2015 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ + package com.cra.figaro.algorithm.sampling import com.cra.figaro.algorithm.ProbQueryAlgorithm diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSampler.scala index f4f9553e..a5e89576 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSampler.scala @@ -1,3 +1,16 @@ +/* + * ParSampler.scala + * Parallel version of a sampling algorithm. + * + * Created By: Lee Kellogg (lkellogg@cra.com) + * Creation Date: Jun 2, 2015 + * + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ + package com.cra.figaro.algorithm.sampling.parallel import com.cra.figaro.language._ diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSamplingAlgorithm.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSamplingAlgorithm.scala index 05b6c155..980f31cd 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSamplingAlgorithm.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/parallel/ParSamplingAlgorithm.scala @@ -1,9 +1,9 @@ /* - * ParAlgorithm.scala - * Parallel algorithms. + * ParSamplingAlgorithm.scala + * Parallel sampling algorithms. * * Created By: Lee Kellogg (lkellog@cra.com) - * Creation Date: May 11, 2015 + * Creation Date: June 2, 2015 * * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. * See http://www.cra.com or email figaro@cra.com for information. diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/filtering/ParParticleFilterTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/filtering/ParParticleFilterTest.scala index e7e3ce1d..3cd6c653 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/filtering/ParParticleFilterTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/filtering/ParParticleFilterTest.scala @@ -1,11 +1,11 @@ /* - * ParticleFilterTest.scala - * Particle filter tests. + * ParParticleFilterTest.scala + * Parallel particle filter tests. * - * Created By: Avi Pfeffer (apfeffer@cra.com) - * Creation Date: Jan 1, 2009 + * Created By: Lee Kellogg (lkellogg@cra.com) + * Creation Date: Jun 2, 2015 * - * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. * See http://www.cra.com or email figaro@cra.com for information. * * See http://www.github.com/p2t2/figaro for a copy of the software license. From c46c89b570a368008bab89e7e90b383119c5f2cf Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Wed, 8 Jul 2015 14:07:28 -0400 Subject: [PATCH 12/18] More ScalaDoc fixes --- .../algorithm/sampling/LikelihoodWeighter.scala | 12 ++++++++++++ .../scala/com/cra/figaro/library/cache/Cache.scala | 12 ++++++++++++ .../scala/com/cra/figaro/library/cache/MHCache.scala | 12 ++++++++++++ .../cra/figaro/library/cache/PermanentCache.scala | 12 ++++++++++++ 4 files changed, 48 insertions(+) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala index dec31672..4fcd0c79 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala @@ -1,3 +1,15 @@ +/* + * LikelihoodWeighter.scala + * Likelihood weighting works by propagating observations through Dists and Chains to the variables they depend on. + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ package com.cra.figaro.algorithm.sampling import com.cra.figaro.language._ diff --git a/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala index 3a5afc22..e48b8728 100644 --- a/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/Cache.scala @@ -1,3 +1,15 @@ +/* + * Cache.scala + * Abstract class to manage caching of element generation for a universe. + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ package com.cra.figaro.library.cache import com.cra.figaro.language._ diff --git a/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala index 3036a8c4..2390c6ad 100644 --- a/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/MHCache.scala @@ -1,3 +1,15 @@ +/* + * MHCache.scala + * Implements caching for caching and non-caching chains, specifically designed for MH. + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ package com.cra.figaro.library.cache import com.cra.figaro.language._ diff --git a/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala b/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala index 4487322f..f4391c24 100644 --- a/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala +++ b/Figaro/src/main/scala/com/cra/figaro/library/cache/PermanentCache.scala @@ -1,3 +1,15 @@ +/* + * PermanentCache.scala + * Only caches permanent result elements in chain. + * + * Created By: Avi Pfeffer (apfeffer@cra.com) + * Creation Date: Jan 1, 2009 + * + * Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ package com.cra.figaro.library.cache import com.cra.figaro.language._ From 458c9a46545a1de10e1836dfc19db0dc76f9620d Mon Sep 17 00:00:00 2001 From: Glenn Takata Date: Thu, 9 Jul 2015 11:09:39 -0400 Subject: [PATCH 13/18] Use util.TailCall to elimination recursion in eliminateInOrder --- .../factored/VariableElimination.scala | 44 +++++++++++--- .../algorithm/factored/VERecursionTest.scala | 58 +++++++++++++++++++ 2 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala index ec030306..ccfc03ca 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala @@ -123,18 +123,48 @@ trait VariableElimination[T] extends FactoredAlgorithm[T] with OneTime { } } + // protected def eliminateInOrder( + // order: List[Variable[_]], + // factors: MultiSet[Factor[T]], + // map: FactorMap[T]): MultiSet[Factor[T]] = + // order match { + // case Nil => + // factors + // case first :: rest => + // eliminate(first, factors, map) + // eliminateInOrder(rest, factors, map) + // } + + import scala.util.control.TailCalls._ + + // Wraps the TailRec class and returns the result protected def eliminateInOrder( order: List[Variable[_]], factors: MultiSet[Factor[T]], - map: FactorMap[T]): MultiSet[Factor[T]] = + map: FactorMap[T]): MultiSet[Factor[T]] = { + + callEliminateInOrder(order, factors, map).result + } + + /* + * TailRec class turns a tail-recursive method into a while loop + * + * The result needs to be extracted explicitly + */ + private def callEliminateInOrder( + order: List[Variable[_]], + factors: MultiSet[Factor[T]], + map: FactorMap[T]): TailRec[MultiSet[Factor[T]]] = { + order match { case Nil => - factors + done(factors) case first :: rest => eliminate(first, factors, map) - eliminateInOrder(rest, factors, map) + tailcall(callEliminateInOrder(rest, factors, map)) } - + } + private[figaro] def ve(): Unit = { //expand() val (neededElements, _) = getNeededElements(starterElements, Int.MaxValue) @@ -197,8 +227,8 @@ class ProbQueryVariableElimination(override val universe: Universe, targets: Ele val showTiming: Boolean, val dependentUniverses: List[(Universe, List[NamedEvidence[_]])], val dependentAlgorithm: (Universe, List[NamedEvidence[_]]) => () => Double) - extends OneTimeProbQuery - with ProbabilisticVariableElimination { + extends OneTimeProbQuery + with ProbabilisticVariableElimination { val targetElements = targets.toList lazy val queryTargets = targets.toList @@ -236,7 +266,7 @@ class ProbQueryVariableElimination(override val universe: Universe, targets: Ele dist.toStream } - /** + /** * Computes the expectation of a given function for single target element. */ def computeExpectation[T](target: Element[T], function: T => Double): Double = { diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala new file mode 100644 index 00000000..33f37dd9 --- /dev/null +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala @@ -0,0 +1,58 @@ +/* + * VERecursionTest.scala + * Variable elimination test of tail recursion . + * + * Created By: Glenn Takata (gtakata@cra.com) + * Creation Date: Jul 7, 2015 + * + * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. + * See http://www.cra.com or email figaro@cra.com for information. + * + * See http://www.github.com/p2t2/figaro for a copy of the software license. + */ +package com.cra.figaro.test.algorithm.factored + +import org.scalatest.Matchers +import org.scalatest.{ WordSpec, PrivateMethodTester } +import scala.util.Random +import com.cra.figaro.language.{ Apply, Element, Flip, Universe } +import com.cra.figaro.library.atomic.continuous.{Uniform} +import com.cra.figaro.library.compound.{ If } +import com.cra.figaro.algorithm.factored.VariableElimination + +/** + * @author Glenn Takata (gtakata@cra.com) + * + */ +class VERecursionTest extends WordSpec with Matchers { + "Running VariableElimination" should { + "with a very wide model produce the correct result" in { + Universe.createNew() + var root = Flip(0.5) + + val rand = new Random(System.currentTimeMillis) + for (_ <- 0 until 10000) { + val v = If(root, Flip(0.5), Flip(0.5)) + if ( rand.nextBoolean) { + v.observe(true) + } + else { + v.observe(false) + } + } + test(root, (r: Boolean) => r == true, 0.50) + } + } + + def test[T](target: Element[T], predicate: T => Boolean, prob: Double) { + val tolerance = 0.00001 + val algorithm = VariableElimination(target) + algorithm.start() + + val dist = algorithm.distribution(target).toList + println("\nTarget distribution: " + dist.mkString(",") + "\n") + + algorithm.probability(target, predicate) should be(prob +- tolerance) + algorithm.kill() + } +} \ No newline at end of file From 011683ac5fad5f826d2dc5ae27dcad1441b170ed Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Fri, 10 Jul 2015 08:48:29 -0400 Subject: [PATCH 14/18] Fixes to VE recursion. Moved test into VETest, made shorter. --- .../factored/VariableElimination.scala | 45 +++----------- .../algorithm/factored/VERecursionTest.scala | 58 ------------------- .../test/algorithm/factored/VETest.scala | 15 +++++ 3 files changed, 23 insertions(+), 95 deletions(-) delete mode 100644 Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala index ccfc03ca..35e18b3a 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala @@ -123,48 +123,19 @@ trait VariableElimination[T] extends FactoredAlgorithm[T] with OneTime { } } - // protected def eliminateInOrder( - // order: List[Variable[_]], - // factors: MultiSet[Factor[T]], - // map: FactorMap[T]): MultiSet[Factor[T]] = - // order match { - // case Nil => - // factors - // case first :: rest => - // eliminate(first, factors, map) - // eliminateInOrder(rest, factors, map) - // } - - import scala.util.control.TailCalls._ - - // Wraps the TailRec class and returns the result - protected def eliminateInOrder( + @tailrec + protected final def eliminateInOrder( order: List[Variable[_]], factors: MultiSet[Factor[T]], - map: FactorMap[T]): MultiSet[Factor[T]] = { - - callEliminateInOrder(order, factors, map).result - } - - /* - * TailRec class turns a tail-recursive method into a while loop - * - * The result needs to be extracted explicitly - */ - private def callEliminateInOrder( - order: List[Variable[_]], - factors: MultiSet[Factor[T]], - map: FactorMap[T]): TailRec[MultiSet[Factor[T]]] = { - + map: FactorMap[T]): MultiSet[Factor[T]] = order match { case Nil => - done(factors) + factors case first :: rest => eliminate(first, factors, map) - tailcall(callEliminateInOrder(rest, factors, map)) + eliminateInOrder(rest, factors, map) } - } - + private[figaro] def ve(): Unit = { //expand() val (neededElements, _) = getNeededElements(starterElements, Int.MaxValue) @@ -227,8 +198,8 @@ class ProbQueryVariableElimination(override val universe: Universe, targets: Ele val showTiming: Boolean, val dependentUniverses: List[(Universe, List[NamedEvidence[_]])], val dependentAlgorithm: (Universe, List[NamedEvidence[_]]) => () => Double) - extends OneTimeProbQuery - with ProbabilisticVariableElimination { + extends OneTimeProbQuery + with ProbabilisticVariableElimination { val targetElements = targets.toList lazy val queryTargets = targets.toList diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala deleted file mode 100644 index 33f37dd9..00000000 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VERecursionTest.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * VERecursionTest.scala - * Variable elimination test of tail recursion . - * - * Created By: Glenn Takata (gtakata@cra.com) - * Creation Date: Jul 7, 2015 - * - * Copyright 2015 Avrom J. Pfeffer and Charles River Analytics, Inc. - * See http://www.cra.com or email figaro@cra.com for information. - * - * See http://www.github.com/p2t2/figaro for a copy of the software license. - */ -package com.cra.figaro.test.algorithm.factored - -import org.scalatest.Matchers -import org.scalatest.{ WordSpec, PrivateMethodTester } -import scala.util.Random -import com.cra.figaro.language.{ Apply, Element, Flip, Universe } -import com.cra.figaro.library.atomic.continuous.{Uniform} -import com.cra.figaro.library.compound.{ If } -import com.cra.figaro.algorithm.factored.VariableElimination - -/** - * @author Glenn Takata (gtakata@cra.com) - * - */ -class VERecursionTest extends WordSpec with Matchers { - "Running VariableElimination" should { - "with a very wide model produce the correct result" in { - Universe.createNew() - var root = Flip(0.5) - - val rand = new Random(System.currentTimeMillis) - for (_ <- 0 until 10000) { - val v = If(root, Flip(0.5), Flip(0.5)) - if ( rand.nextBoolean) { - v.observe(true) - } - else { - v.observe(false) - } - } - test(root, (r: Boolean) => r == true, 0.50) - } - } - - def test[T](target: Element[T], predicate: T => Boolean, prob: Double) { - val tolerance = 0.00001 - val algorithm = VariableElimination(target) - algorithm.start() - - val dist = algorithm.distribution(target).toList - println("\nTarget distribution: " + dist.mkString(",") + "\n") - - algorithm.probability(target, predicate) should be(prob +- tolerance) - algorithm.kill() - } -} \ No newline at end of file diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VETest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VETest.scala index f41ad157..4c4c9d18 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VETest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/VETest.scala @@ -382,6 +382,21 @@ class VETest extends WordSpec with Matchers { ve.kill } + "with a very wide model produce the correct result" in { + Universe.createNew() + var root = Flip(0.5) + + val rand = new scala.util.Random(System.currentTimeMillis) + for (_ <- 0 until 1000) { + val v = If(root, Flip(0.5), Flip(0.5)) + if (rand.nextBoolean) { + v.observe(true) + } else { + v.observe(false) + } + } + test(root, (r: Boolean) => r == true, 0.50) + } } "MPEVariableElimination" should { From 90508b0697930c6fae3a8b8433d5512e908d561e Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Fri, 10 Jul 2015 10:09:31 -0400 Subject: [PATCH 15/18] Reverted back to tail calls --- .../factored/VariableElimination.scala | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala index 35e18b3a..9c13a5bf 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/factored/VariableElimination.scala @@ -21,6 +21,7 @@ import com.cra.figaro.util._ import annotation.tailrec import scala.collection.mutable.{ Map, Set } import scala.language.postfixOps +import scala.util.control.TailCalls._ /** * Trait of algorithms that perform variable elimination. @@ -123,19 +124,32 @@ trait VariableElimination[T] extends FactoredAlgorithm[T] with OneTime { } } - @tailrec - protected final def eliminateInOrder( + // Wraps the TailRec class and returns the result + protected def eliminateInOrder( order: List[Variable[_]], factors: MultiSet[Factor[T]], - map: FactorMap[T]): MultiSet[Factor[T]] = + map: FactorMap[T]): MultiSet[Factor[T]] = { + callEliminateInOrder(order, factors, map).result + } + + /* + * TailRec class turns a tail-recursive method into a while loop + * The result needs to be extracted explicitly + */ + private def callEliminateInOrder( + order: List[Variable[_]], + factors: MultiSet[Factor[T]], + map: FactorMap[T]): TailRec[MultiSet[Factor[T]]] = { order match { case Nil => - factors + done(factors) case first :: rest => eliminate(first, factors, map) - eliminateInOrder(rest, factors, map) + tailcall(callEliminateInOrder(rest, factors, map)) } - + } + + private[figaro] def ve(): Unit = { //expand() val (neededElements, _) = getNeededElements(starterElements, Int.MaxValue) From 58ae42b00f63597269f5403e060e031502b888fa Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Fri, 10 Jul 2015 10:31:42 -0400 Subject: [PATCH 16/18] Version update --- Figaro/META-INF/MANIFEST.MF | 2 +- Figaro/figaro_build.properties | 2 +- FigaroExamples/META-INF/MANIFEST.MF | 4 ++-- project/Build.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Figaro/META-INF/MANIFEST.MF b/Figaro/META-INF/MANIFEST.MF index 87dbb207..a7e60963 100644 --- a/Figaro/META-INF/MANIFEST.MF +++ b/Figaro/META-INF/MANIFEST.MF @@ -2,7 +2,7 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: Figaro Bundle-SymbolicName: com.cra.figaro -Bundle-Version: 4.0.0 +Bundle-Version: 3.3.0 Export-Package: com.cra.figaro.algorithm, com.cra.figaro.algorithm.decision, com.cra.figaro.algorithm.decision.index, diff --git a/Figaro/figaro_build.properties b/Figaro/figaro_build.properties index 9052a6eb..3ae169ff 100644 --- a/Figaro/figaro_build.properties +++ b/Figaro/figaro_build.properties @@ -1 +1 @@ -version=4.0.0.0 +version=3.3.0.0 diff --git a/FigaroExamples/META-INF/MANIFEST.MF b/FigaroExamples/META-INF/MANIFEST.MF index 07c9c94c..7b360a6b 100644 --- a/FigaroExamples/META-INF/MANIFEST.MF +++ b/FigaroExamples/META-INF/MANIFEST.MF @@ -2,8 +2,8 @@ Manifest-Version: 1.0 Bundle-ManifestVersion: 2 Bundle-Name: FigaroExamples Bundle-SymbolicName: com.cra.figaro.examples -Bundle-Version: 4.0.0 -Require-Bundle: com.cra.figaro;bundle-version="4.0.0", +Bundle-Version: 3.3.0 +Require-Bundle: com.cra.figaro;bundle-version="3.3.0", org.scala-lang.scala-library Bundle-Vendor: Charles River Analytics Bundle-RequiredExecutionEnvironment: JavaSE-1.6 diff --git a/project/Build.scala b/project/Build.scala index f3bd227d..234eabe0 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -24,7 +24,7 @@ object FigaroBuild extends Build { override val settings = super.settings ++ Seq( organization := "com.cra.figaro", description := "Figaro: a language for probablistic programming", - version := "4.0.0.0", + version := "3.3.0.0", scalaVersion := "2.11.6", crossPaths := true, publishMavenStyle := true, From db3abe4ad977c1c83cea84e001497548b861339a Mon Sep 17 00:00:00 2001 From: bruttenberg Date: Mon, 13 Jul 2015 16:42:23 -0400 Subject: [PATCH 17/18] Fixes for two small bugs. ElementSampler was wiping temporary elements during Factor sampling creation, leading to only a single sample taken. Forward sampling was setting the value of observed elements instead of generating them fresh, leading to poor sampling in MH. --- .../algorithm/sampling/ElementSampler.scala | 2 +- .../figaro/algorithm/sampling/Forward.scala | 1 + .../sampling/LikelihoodWeighter.scala | 6 +- .../test/algorithm/factored/FactorTest.scala | 280 +++++++++--------- 4 files changed, 152 insertions(+), 137 deletions(-) diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala index 423725bd..b77835f9 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/ElementSampler.scala @@ -26,7 +26,6 @@ abstract class ElementSampler(target: Element[_]) extends BaseUnweightedSampler( def sample(): (Boolean, Sample) = { Forward(target) - universe.clearTemporaries (true, Map[Element[_], Any](target -> target.value)) } @@ -78,6 +77,7 @@ class OneTimeElementSampler(target: Element[_], myNumSamples: Int) doInitialize() super.run() update + universe.clearTemporaries } } diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala index 8ae576e2..dffd0737 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/Forward.scala @@ -21,6 +21,7 @@ import com.cra.figaro.algorithm.sampling.LikelihoodWeighter class ForwardWeighter(universe: Universe, cache: Cache) extends LikelihoodWeighter(universe, cache) { override def rejectionAction() = () + override def setObservation(element: Element[_], obs: Option[_]) = {} } /** diff --git a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala index 4fcd0c79..88f8eb3b 100644 --- a/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala +++ b/Figaro/src/main/scala/com/cra/figaro/algorithm/sampling/LikelihoodWeighter.scala @@ -186,12 +186,12 @@ class LikelihoodWeighter(universe: Universe, cache: Cache) { } else { element match { case f: CompoundFlip => { - element.value = obs.get.asInstanceOf[element.Value] + setObservation(element, obs) if (obs.get.asInstanceOf[Boolean]) currentWeight + math.log(f.prob.value) else currentWeight + math.log(1 - f.prob.value) } case e: HasDensity[_] => { - element.value = obs.get.asInstanceOf[element.Value] + setObservation(element, obs) val density = element.asInstanceOf[HasDensity[element.Value]].density(obs.asInstanceOf[Option[element.Value]].get) currentWeight + math.log(density) } @@ -204,6 +204,8 @@ class LikelihoodWeighter(universe: Universe, cache: Cache) { nextWeight + element.constraint(element.value) } + protected def setObservation(element: Element[_], obs: Option[_]) = element.value = obs.get.asInstanceOf[element.Value] + /* Action to take on a rejection. By default it throws an Importance.Reject exception, but this can be overriden for another behavior */ protected def rejectionAction(): Unit = throw Importance.Reject diff --git a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/FactorTest.scala b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/FactorTest.scala index 00731b05..37139d61 100644 --- a/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/FactorTest.scala +++ b/Figaro/src/test/scala/com/cra/figaro/test/algorithm/factored/FactorTest.scala @@ -42,6 +42,8 @@ import com.cra.figaro.library.atomic.continuous.Normal import com.cra.figaro.library.atomic.continuous.Uniform import com.cra.figaro.library.compound.CPD import com.cra.figaro.algorithm.factored.ParticleGenerator +import com.cra.figaro.library.compound.If +import com.cra.figaro.algorithm.factored.VariableElimination class FactorTest extends WordSpec with Matchers with PrivateMethodTester { @@ -68,20 +70,20 @@ class FactorTest extends WordSpec with Matchers with PrivateMethodTester { Universe.createNew() val e1 = Flip(0.2) Values()(e1) - val v1 = Variable(e1).id + val v1 = Variable(e1).id Variable.clearCache LazyValues.clear(Universe.universe) Values()(e1) val v2 = Variable(e1).id v1 should equal(v2) } - + "always be equal to a variable with the same id" in { Universe.createNew() val e1 = Flip(0.2) Values()(e1) val v1 = Variable(e1) - val v2 = new Variable(ValueSet.withStar(Set[Boolean]())) {override val id = v1.id} + val v2 = new Variable(ValueSet.withStar(Set[Boolean]())) { override val id = v1.id } v1 == v2 should equal(true) } @@ -118,34 +120,34 @@ class FactorTest extends WordSpec with Matchers with PrivateMethodTester { f.set(indices, 0.3) f.get(indices) should equal(0.3) } - + "get updated set of factors for an element when the factors have been updated" in { Universe.createNew() val v1 = Flip(0.5) Values()(v1) val f1 = Factory.make(v1)(0) - val f1mod = f1.mapTo((d: Double) => 2.0*d) + val f1mod = f1.mapTo((d: Double) => 2.0 * d) Factory.updateFactor(v1, List(f1mod)) Factory.make(v1)(0).get(List(0)) should equal(f1mod.get(List(0))) - } - -// "have the first index List be all zeros" in { -// Universe.createNew() -// val e1 = Flip(0.1) -// val e2 = Constant(8) -// val e3 = Select(0.2 -> "a", 0.3 -> "b", 0.5 -> "c") -// val e4 = Flip(0.7) -// Values()(e1) -// Values()(e2) -// Values()(e3) -// Values()(e4) -// val v1 = Variable(e1) -// val v2 = Variable(e2) -// val v3 = Variable(e3) -// val v4 = Variable(e4) -// val f = Factory.simpleMake[Double](List(v1, v2, v3, v4)) -// f.firstIndices should equal(List(0, 0, 0, 0)) -// } + } + + // "have the first index List be all zeros" in { + // Universe.createNew() + // val e1 = Flip(0.1) + // val e2 = Constant(8) + // val e3 = Select(0.2 -> "a", 0.3 -> "b", 0.5 -> "c") + // val e4 = Flip(0.7) + // Values()(e1) + // Values()(e2) + // Values()(e3) + // Values()(e4) + // val v1 = Variable(e1) + // val v2 = Variable(e2) + // val v3 = Variable(e3) + // val v4 = Variable(e4) + // val f = Factory.simpleMake[Double](List(v1, v2, v3, v4)) + // f.firstIndices should equal(List(0, 0, 0, 0)) + // } "have the next index List carry and add correctly" in { Universe.createNew() @@ -354,8 +356,8 @@ class FactorTest extends WordSpec with Matchers with PrivateMethodTester { f.set(List(0, 0, 1), 0.3) f.set(List(1, 0, 1), 0.4) f.set(List(2, 0, 1), 0.5) - val g = f.recordArgMax(v3.asInstanceOf[Variable[Any]], - (x: Double, y: Double) => x < y) + val g = f.recordArgMax(v3.asInstanceOf[Variable[Any]], + (x: Double, y: Double) => x < y) g.variables should equal(List(v1, v2)) g.get(List(0, 0)) should equal(true) g.get(List(1, 0)) should equal(false) @@ -610,7 +612,7 @@ class FactorTest extends WordSpec with Matchers with PrivateMethodTester { for { i <- 0 to 1; j <- 0 to 1 } v1Factor.get(List(v2Index, i, j)) should equal(1.0) v2Factor.get(List(v2Index, 0, v3FalseIndex)) should equal(1.0) v2Factor.get(List(v2Index, 0, v3TrueIndex)) should equal(0.0) - for { i <- 0 to 1} v2Factor.get(List(v1Index, 0, i)) should equal(1.0) + for { i <- 0 to 1 } v2Factor.get(List(v1Index, 0, i)) should equal(1.0) } } @@ -664,119 +666,129 @@ class FactorTest extends WordSpec with Matchers with PrivateMethodTester { "given an atomic not in the factor" should { "automatically sample the element" in { - Universe.createNew() + Universe.createNew() val v1 = Normal(0.0, 1.0) Values()(v1) val factor = Factory.make(v1) factor(0).size should equal(ParticleGenerator.defaultArgSamples) - factor(0).get(List(0)) should equal(1.0/ParticleGenerator.defaultArgSamples) + factor(0).get(List(0)) should equal(1.0 / ParticleGenerator.defaultArgSamples) + } + + "correctly create factors for continuous elements through chains" in { + val uni = Universe.createNew() + val elem = If(Flip(0.3), Uniform(0.0, 1.0), Uniform(1.0, 2.0)) + ParticleGenerator(uni) + val alg = VariableElimination(elem) + alg.start() + alg.distribution(elem).toList.size should be (14) + } } - -// "given a chain" should { -// "produce a conditional selector for each parent value" in { -// Universe.createNew() -// val v1 = Flip(0.2) -// val v2 = Select(0.1 -> 1, 0.9 -> 2) -// val v3 = Constant(3) -// val v4 = Chain(v1, (b: Boolean) => if (b) v2; else v3) -// Values()(v4) -// val v1Vals = Variable(v1).range -// val v2Vals = Variable(v2).range -// val v4Vals = Variable(v4).range -// val v1t = v1Vals indexOf Regular(true) -// val v1f = v1Vals indexOf Regular(false) -// val v21 = v2Vals indexOf Regular(1) -// val v22 = v2Vals indexOf Regular(2) -// val v41 = v4Vals indexOf Regular(1) -// val v42 = v4Vals indexOf Regular(2) -// val v43 = v4Vals indexOf Regular(3) -// -// val factor = Factory.make(v4) -// val List(v4Factor) = Factory.combineFactors(factor, SumProductSemiring, true) -// -// v4Factor.get(List(v1t, v21, 0, v41)) should equal(1.0) -// v4Factor.get(List(v1t, v22, 0, v41)) should equal(0.0) -// v4Factor.get(List(v1t, v21, 0, v42)) should equal(0.0) -// v4Factor.get(List(v1t, v22, 0, v42)) should equal(1.0) -// v4Factor.get(List(v1t, v21, 0, v43)) should equal(0.0) -// v4Factor.get(List(v1t, v22, 0, v43)) should equal(0.0) -// v4Factor.get(List(v1f, v21, 0, v41)) should equal(0.0) -// v4Factor.get(List(v1f, v22, 0, v41)) should equal(0.0) -// v4Factor.get(List(v1f, v21, 0, v42)) should equal(0.0) -// v4Factor.get(List(v1f, v22, 0, v42)) should equal(0.0) -// v4Factor.get(List(v1f, v21, 0, v43)) should equal(1.0) -// v4Factor.get(List(v1f, v22, 0, v43)) should equal(1.0) -// -// } -// -// "produce a conditional selector for each non-temporary parent value" in { -// Universe.createNew() -// val v1 = Flip(0.2) -// val v4 = Chain(v1, (b: Boolean) => if (b) Select(0.1 -> 1, 0.9 -> 2); else Constant(3)) -// Values()(v4) -// val v1Vals = Variable(v1).range -// val v4Vals = Variable(v4).range -// -// val v1t = v1Vals indexOf Regular(true) -// val v1f = v1Vals indexOf Regular(false) -// val v41 = v4Vals indexOf Regular(1) -// val v42 = v4Vals indexOf Regular(2) -// val v43 = v4Vals indexOf Regular(3) -// -// val factor = Factory.make(v4) -// val List(v4Factor) = Factory.combineFactors(factor, SumProductSemiring, true) -// -// v4Factor.get(List(v1t, v41)) should equal(0.1) -// v4Factor.get(List(v1t, v42)) should equal(0.9) -// v4Factor.get(List(v1t, v43)) should equal(0.0) -// v4Factor.get(List(v1f, v41)) should equal(0.0) -// v4Factor.get(List(v1f, v42)) should equal(0.0) -// v4Factor.get(List(v1f, v43)) should equal(1.0) -// } -// } - -// "given a CPD with one argument" should { -// "produce a single factor with a case for each parent value" in { -// Universe.createNew() -// val v1 = Flip(0.2) -// -// val v2 = CPD(v1, false -> Flip(0.1), true -> Flip(0.7)) -// Values()(v2) -// -// val v1Vals = Variable(v1).range -// val v2Vals = Variable(v2).range -// -// val v1t = v1Vals indexOf Regular(true) -// val v1f = v1Vals indexOf Regular(false) -// val v2t = v2Vals indexOf Regular(true) -// val v2f = v2Vals indexOf Regular(false) -// val v3t = 0 -// val v3f = 1 -// val v4t = 0 -// val v4f = 1 -// -// val factor = Factory.make(v2) -// val List(v2Factor) = Factory.combineFactors(factor, SumProductSemiring, true) -// -// v2Factor.get(List(v1t, v3t, v4t, v2t)) should equal(1.0) -// v2Factor.get(List(v1t, v3t, v4f, v2t)) should equal(1.0) -// v2Factor.get(List(v1t, v3f, v4t, v2t)) should equal(0.0) -// v2Factor.get(List(v1t, v3f, v4f, v2t)) should equal(0.0) -// v2Factor.get(List(v1t, v3t, v4t, v2f)) should equal(0.0) -// v2Factor.get(List(v1t, v3t, v4f, v2f)) should equal(0.0) -// v2Factor.get(List(v1t, v3f, v4t, v2f)) should equal(1.0) -// v2Factor.get(List(v1t, v3f, v4f, v2f)) should equal(1.0) -// v2Factor.get(List(v1f, v3t, v4t, v2t)) should equal(1.0) -// v2Factor.get(List(v1f, v3t, v4f, v2t)) should equal(0.0) -// v2Factor.get(List(v1f, v3f, v4t, v2t)) should equal(1.0) -// v2Factor.get(List(v1f, v3f, v4f, v2t)) should equal(0.0) -// v2Factor.get(List(v1f, v3t, v4t, v2f)) should equal(0.0) -// v2Factor.get(List(v1f, v3t, v4f, v2f)) should equal(1.0) -// v2Factor.get(List(v1f, v3f, v4t, v2f)) should equal(0.0) -// v2Factor.get(List(v1f, v3f, v4f, v2f)) should equal(1.0) -// } -// } + + // "given a chain" should { + // "produce a conditional selector for each parent value" in { + // Universe.createNew() + // val v1 = Flip(0.2) + // val v2 = Select(0.1 -> 1, 0.9 -> 2) + // val v3 = Constant(3) + // val v4 = Chain(v1, (b: Boolean) => if (b) v2; else v3) + // Values()(v4) + // val v1Vals = Variable(v1).range + // val v2Vals = Variable(v2).range + // val v4Vals = Variable(v4).range + // val v1t = v1Vals indexOf Regular(true) + // val v1f = v1Vals indexOf Regular(false) + // val v21 = v2Vals indexOf Regular(1) + // val v22 = v2Vals indexOf Regular(2) + // val v41 = v4Vals indexOf Regular(1) + // val v42 = v4Vals indexOf Regular(2) + // val v43 = v4Vals indexOf Regular(3) + // + // val factor = Factory.make(v4) + // val List(v4Factor) = Factory.combineFactors(factor, SumProductSemiring, true) + // + // v4Factor.get(List(v1t, v21, 0, v41)) should equal(1.0) + // v4Factor.get(List(v1t, v22, 0, v41)) should equal(0.0) + // v4Factor.get(List(v1t, v21, 0, v42)) should equal(0.0) + // v4Factor.get(List(v1t, v22, 0, v42)) should equal(1.0) + // v4Factor.get(List(v1t, v21, 0, v43)) should equal(0.0) + // v4Factor.get(List(v1t, v22, 0, v43)) should equal(0.0) + // v4Factor.get(List(v1f, v21, 0, v41)) should equal(0.0) + // v4Factor.get(List(v1f, v22, 0, v41)) should equal(0.0) + // v4Factor.get(List(v1f, v21, 0, v42)) should equal(0.0) + // v4Factor.get(List(v1f, v22, 0, v42)) should equal(0.0) + // v4Factor.get(List(v1f, v21, 0, v43)) should equal(1.0) + // v4Factor.get(List(v1f, v22, 0, v43)) should equal(1.0) + // + // } + // + // "produce a conditional selector for each non-temporary parent value" in { + // Universe.createNew() + // val v1 = Flip(0.2) + // val v4 = Chain(v1, (b: Boolean) => if (b) Select(0.1 -> 1, 0.9 -> 2); else Constant(3)) + // Values()(v4) + // val v1Vals = Variable(v1).range + // val v4Vals = Variable(v4).range + // + // val v1t = v1Vals indexOf Regular(true) + // val v1f = v1Vals indexOf Regular(false) + // val v41 = v4Vals indexOf Regular(1) + // val v42 = v4Vals indexOf Regular(2) + // val v43 = v4Vals indexOf Regular(3) + // + // val factor = Factory.make(v4) + // val List(v4Factor) = Factory.combineFactors(factor, SumProductSemiring, true) + // + // v4Factor.get(List(v1t, v41)) should equal(0.1) + // v4Factor.get(List(v1t, v42)) should equal(0.9) + // v4Factor.get(List(v1t, v43)) should equal(0.0) + // v4Factor.get(List(v1f, v41)) should equal(0.0) + // v4Factor.get(List(v1f, v42)) should equal(0.0) + // v4Factor.get(List(v1f, v43)) should equal(1.0) + // } + // } + + // "given a CPD with one argument" should { + // "produce a single factor with a case for each parent value" in { + // Universe.createNew() + // val v1 = Flip(0.2) + // + // val v2 = CPD(v1, false -> Flip(0.1), true -> Flip(0.7)) + // Values()(v2) + // + // val v1Vals = Variable(v1).range + // val v2Vals = Variable(v2).range + // + // val v1t = v1Vals indexOf Regular(true) + // val v1f = v1Vals indexOf Regular(false) + // val v2t = v2Vals indexOf Regular(true) + // val v2f = v2Vals indexOf Regular(false) + // val v3t = 0 + // val v3f = 1 + // val v4t = 0 + // val v4f = 1 + // + // val factor = Factory.make(v2) + // val List(v2Factor) = Factory.combineFactors(factor, SumProductSemiring, true) + // + // v2Factor.get(List(v1t, v3t, v4t, v2t)) should equal(1.0) + // v2Factor.get(List(v1t, v3t, v4f, v2t)) should equal(1.0) + // v2Factor.get(List(v1t, v3f, v4t, v2t)) should equal(0.0) + // v2Factor.get(List(v1t, v3f, v4f, v2t)) should equal(0.0) + // v2Factor.get(List(v1t, v3t, v4t, v2f)) should equal(0.0) + // v2Factor.get(List(v1t, v3t, v4f, v2f)) should equal(0.0) + // v2Factor.get(List(v1t, v3f, v4t, v2f)) should equal(1.0) + // v2Factor.get(List(v1t, v3f, v4f, v2f)) should equal(1.0) + // v2Factor.get(List(v1f, v3t, v4t, v2t)) should equal(1.0) + // v2Factor.get(List(v1f, v3t, v4f, v2t)) should equal(0.0) + // v2Factor.get(List(v1f, v3f, v4t, v2t)) should equal(1.0) + // v2Factor.get(List(v1f, v3f, v4f, v2t)) should equal(0.0) + // v2Factor.get(List(v1f, v3t, v4t, v2f)) should equal(0.0) + // v2Factor.get(List(v1f, v3t, v4f, v2f)) should equal(1.0) + // v2Factor.get(List(v1f, v3f, v4t, v2f)) should equal(0.0) + // v2Factor.get(List(v1f, v3f, v4f, v2f)) should equal(1.0) + // } + // } "given an apply of one argument" should { "produce a factor that matches the argument to the result via the function" in { From b95b0485cee9188cabead4a2c5586cf32bbb46a3 Mon Sep 17 00:00:00 2001 From: Mike Reposa Date: Wed, 22 Jul 2015 14:21:58 -0400 Subject: [PATCH 18/18] Fix broken links --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fbb6e671..13abb06a 100644 --- a/README.md +++ b/README.md @@ -4,4 +4,4 @@ Figaro is a probabilistic programming language that supports development of very Figaro makes it possible to express probabilistic models using the power of programming languages, giving the modeler the expressive tools to create a wide variety of models. Figaro comes with a number of built-in reasoning algorithms that can be applied automatically to new models. In addition, Figaro models are data structures in the Scala programming language, which is interoperable with Java, and can be constructed, manipulated, and used directly within any Scala or Java program. -Figaro is free and is released under an [open-source license](https://github.com/p2t2/figaro/blob/master/LICENSE). The current, stable binary release of Figaro can be found [here](https://www.cra.com/work/case-studies/figaro). For more information please see the [Figaro Release Notes](https://www.cra.com/sites/default/files/pdf/Figaro-Release-Notes.pdf) and [Figaro Tutorial](https://www.cra.com/sites/default/files/pdf/Figaro-Tutorial.pdf). Documentation of the Figaro library interface can be found [here](https://www.cra.com/Figaro_Scaladoc/index.html#package). +Figaro is free and is released under an [open-source license](https://github.com/p2t2/figaro/blob/master/LICENSE). The current, stable binary release of Figaro can be found [here](https://www.cra.com/work/case-studies/figaro). For more information please see the [Figaro Release Notes](https://www.cra.com/sites/default/files/pdf/Figaro_Release_Notes.pdf) and [Figaro Tutorial](https://www.cra.com/sites/default/files/pdf/Figaro_Tutorial.pdf). Documentation of the Figaro library interface can be found [here](https://www.cra.com/Figaro_Scaladoc/index.html#package).