Skip to content

Commit

Permalink
Merge pull request #57 from dityas/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
dityas authored Jan 26, 2020
2 parents 0ee6196 + dc08993 commit c91524f
Show file tree
Hide file tree
Showing 18 changed files with 1,344 additions and 524 deletions.
1 change: 1 addition & 0 deletions Protos/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ bin/
*.pomdp
target/
/test.log
/test.obj
Binary file modified Protos/build/Protos.jar
Binary file not shown.
81 changes: 81 additions & 0 deletions Protos/src/thinclab/belief/IBeliefOps.java
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,87 @@ public DD beliefUpdate(
}


/**
 * Alternative level 1 belief update: computes the next belief from
 * {@code belief} after taking {@code action} and receiving {@code observations}.
 *
 * <pre>
 * P(S, Mj | Oi'=o) =
 *     norm x Sumout[S, Mj, Thetaj, Aj]
 *         f(S, Mj) x f(Thetaj, Mj) x f(Aj, Mj)
 *         x f(S', S, Aj) x f(Oi'=o, S', Aj)
 *         x Sumout[Oj']
 *             f(Oj', Aj, Thetaj, S') x f(Mj', Mj, Aj, Oj')
 * </pre>
 *
 * NOTE: an observation name that cannot be resolved is logged and terminates
 * the JVM through {@code System.exit(-1)}.
 *
 * @param belief       belief to start from
 * @param action       action name used to look up the transition and observation DDs
 * @param observations one observation value name per agent-i observation variable
 *                     (the first {@code Omega.size() - OmegaJNames.size()} entries are read)
 * @return the updated belief divided by its normalization factor
 * @throws ZeroProbabilityObsException if the normalization factor is below 1e-8
 */
public DD differentBeliefUpdate(
        DD belief, String action, String[] observations) throws ZeroProbabilityObsException {

    IPOMDP DPRef = this.getIPOMDP();

    /* First reduce Oi based on observations */
    int[] obsVals =
            new int[DPRef.Omega.size() - DPRef.OmegaJNames.size()];

    for (int o = 0; o < obsVals.length; o++) {

        /*
         * Validate the raw lookup result BEFORE applying the +1 offset. The
         * previous code tested (index + 1) < 0, so a "not found" result of -1
         * became 0, passed the check and was stored as an observation value.
         * NOTE(review): assumes findObservationByName reports an unknown name
         * with a negative index - confirm against its implementation.
         */
        int obsIndex = DPRef.findObservationByName(o, observations[o]);

        if (obsIndex < 0) {
            LOGGER.error(
                    "Obs Variable " + DPRef.Omega.get(o).name
                    + " does not take value " + observations[o]);
            System.exit(-1);
        }

        else obsVals[o] = obsIndex + 1;
    }

    /* Restrict Oi to the observed values of the primed observation variables */
    DD[] restrictedOi =
            OP.restrictN(
                    DPRef.currentOi.get(action),
                    IPOMDP.stackArray(
                            DPRef.obsIVarPrimeIndices, obsVals));

    /* Collect f1 = belief x currentTau */
    DD f1 = OP.mult(belief, DPRef.currentTau);

    /* Collect f2 = P(Aj | Mj) x P(Thetaj| Mj) x P(Oi'=o, S', Aj) x P (S', Aj, S) */
    DD[] f2 =
            ArrayUtils.addAll(
                    ArrayUtils.addAll(
                            DPRef.currentTi.get(action),
                            new DD[] {DPRef.currentAjGivenMj, DPRef.currentThetajGivenMj}),
                    restrictedOi);

    /* Multiply all factors and sum out the variables in stateVarIndices */
    DD nextBelief =
            OP.addMultVarElim(
                    ArrayUtils.add(f2, f1),
                    DPRef.stateVarIndices);

    /*
     * Shift variable indices by -(|S| + |Omega|); presumably this maps the
     * primed variables back onto their unprimed counterparts - TODO confirm.
     */
    nextBelief = OP.primeVars(nextBelief, -(DPRef.S.size() + DPRef.Omega.size()));

    /* Normalization factor: sum out stateVarIndices[0, thetaVarPosition) */
    DD norm =
            OP.addMultVarElim(
                    nextBelief,
                    ArrayUtils.subarray(
                            DPRef.stateVarIndices,
                            0,
                            DPRef.thetaVarPosition));

    /* A (near) zero normalizer means the observation cannot occur here */
    if (norm.getVal() < 1e-8)
        throw new ZeroProbabilityObsException(
                "Observation " + Arrays.toString(observations)
                + " not possible at belief " + belief);

    return OP.div(nextBelief, norm);
}


public DD beliefUpdate(
DD belief, String action, int[][] obsVals) throws ZeroProbabilityObsException {

Expand Down
16 changes: 14 additions & 2 deletions Protos/src/thinclab/belief/SSGABeliefExpansion.java
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,20 @@ public void expand() {
}

/* Add belief point if it doesn't already exist */
if (!this.exploredBeliefs.contains(nextBelief))
this.exploredBeliefs.add(nextBelief);
if (!this.exploredBeliefs.contains(nextBelief)) {

double minDist = Double.MAX_VALUE;
for (DD bel : this.exploredBeliefs) {

double dist = OP.maxAll(OP.abs(OP.sub(bel, nextBelief)));

if (dist < minDist) minDist = dist;

}

if (minDist > 0.01)
this.exploredBeliefs.add(nextBelief);
}

belief = nextBelief;
}
Expand Down
2 changes: 1 addition & 1 deletion Protos/src/thinclab/decisionprocesses/IPOMDP.java
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ public HashMap<String, DD> makeRi() {

LOGGER.debug("For Ai=" + Ai + " R(S,Mj) has vars "
+ Arrays.toString(RSMj.getVarSet()));
LOGGER.debug("Ri is: " + RSMj.toDDTree());
// LOGGER.debug("Ri is: " + RSMj.toDDTree());

Ri.put(Ai, RSMj);
}
Expand Down
25 changes: 24 additions & 1 deletion Protos/src/thinclab/executables/RunSimulations.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.apache.commons.cli.Options;
import org.apache.log4j.Logger;

import thinclab.belief.BeliefRegionExpansionStrategy;
import thinclab.belief.SSGABeliefExpansion;
import thinclab.belief.SparseFullBeliefExpansion;
import thinclab.decisionprocesses.IPOMDP;
import thinclab.decisionprocesses.POMDP;
Expand Down Expand Up @@ -65,6 +67,13 @@ public static void main(String[] args) {
/* simulation rounds */
opt.addOption("y", true, "number of simulation rounds");

/* Use SSGA? */
opt.addOption(
"e",
"ssga",
false,
"use SSGA expansion? (5 perseus rounds and 10 iterations of exploration)");

/* simulation switch */
opt.addOption(
"x",
Expand Down Expand Up @@ -154,6 +163,7 @@ else if (line.hasOption("i")) {

/* set NextBelState Caching */
NextBelStateCache.useCache();
NextBelStateCache.setDB("/tmp/nz_cache.db");

LOGGER.info("Simulating IPOMDP...");

Expand Down Expand Up @@ -217,10 +227,23 @@ else if (line.hasOption("i")) {
/* set context back to IPOMDP */
ipomdp.setGlobals();

BeliefRegionExpansionStrategy BE;
int numRounds = 1;

if (line.hasOption("e")) {
BE = new SSGABeliefExpansion(ipomdp, 10);
numRounds = 5;
}

else {
BE = new SparseFullBeliefExpansion(ipomdp, 10);
numRounds = 1;
}

OnlineInteractiveSymbolicPerseus solver =
new OnlineInteractiveSymbolicPerseus(
ipomdp,
new SparseFullBeliefExpansion(ipomdp, 10), 1, backups);
BE, numRounds, backups);

StochasticSimulation ss = new StochasticSimulation(solver, simLength);
ss.runSimulation();
Expand Down
3 changes: 2 additions & 1 deletion Protos/src/thinclab/legacy/AlphaVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import thinclab.exceptions.VariableNotFoundException;
import thinclab.exceptions.ZeroProbabilityObsException;
import thinclab.utils.Diagnostics;
import thinclab.utils.NextBelStateCache;

public class AlphaVector implements Serializable {
/**
Expand Down Expand Up @@ -62,7 +63,7 @@ public static AlphaVector dpBackup(

smallestProb = ipomdp.tolerance / maxAbsVal;
nextBelStates =
NextBelState.oneStepNZPrimeBelStates(
NextBelState.oneStepNZPrimeBelStatesCached(
ipomdp,
belState,
true,
Expand Down
2 changes: 2 additions & 0 deletions Protos/src/thinclab/legacy/Global.java
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ public static void clearHashtables() {
Global.nNodesHashtable.clear();
Global.leafHashtable.put(DD.zero, new WeakReference<DD>(DD.zero));
Global.leafHashtable.put(DD.one, new WeakReference<DD>(DD.one));

System.gc();
}

public static void newHashtables() {
Expand Down
128 changes: 81 additions & 47 deletions Protos/src/thinclab/legacy/NextBelState.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
package thinclab.legacy;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -26,7 +27,9 @@
* @author adityas
*
*/
public class NextBelState {
public class NextBelState implements Serializable {

private static final long serialVersionUID = 8843718772195892772L;

public DD[][] nextBelStates;
public int[] nzObsIds;
Expand Down Expand Up @@ -231,6 +234,83 @@ public void getObsStrat() {
* Functions to get single step NextBelStates
*/

/**
 * Cache-aware front end for {@link #oneStepNZPrimeBelStates2}.
 *
 * When {@code NextBelStateCache} allows caching, a previously stored entry for
 * {@code belState} is returned as is. Otherwise the next belief states are
 * computed and, if caching is allowed, stored under {@code belState} before
 * being returned.
 *
 * @param ipomdp       model forwarded to the computation
 * @param belState     belief used as the cache key and forwarded to the computation
 * @param normalize    forwarded unchanged to the computation
 * @param smallestProb forwarded unchanged to the computation
 * @return per-action map of next belief states, either cached or freshly computed
 * @throws ZeroProbabilityObsException propagated from the computation
 * @throws VariableNotFoundException   propagated from the computation
 */
public static HashMap<String, NextBelState> oneStepNZPrimeBelStatesCached(
        IPOMDP ipomdp,
        DD belState,
        boolean normalize,
        double smallestProb) throws ZeroProbabilityObsException, VariableNotFoundException {

    /* Cache lookup is only attempted when caching is switched on */
    HashMap<String, NextBelState> hit =
            NextBelStateCache.cachingAllowed()
                    ? NextBelStateCache.getCachedEntry(belState)
                    : null;

    if (hit != null) {
        return hit;
    }

    /* Miss (or caching disabled): compute from scratch */
    HashMap<String, NextBelState> computed =
            oneStepNZPrimeBelStates2(ipomdp, belState, normalize, smallestProb);

    if (!NextBelStateCache.cachingAllowed()) {
        return computed;
    }

    /* Remember the result for later requests on the same belief */
    NextBelStateCache.cacheNZPrime(belState, computed);
    return computed;
}

/**
 * Computes, for every action of {@code ipomdp}, the next belief states reachable
 * from {@code belState} through the IPOMDP's own belief update (the IBeliefOps
 * path) instead of the original NextBelState implementation.
 *
 * For each action and each entry of {@code getAllPossibleObservations()} the
 * updated belief is factored, its variables are primed by
 * {@code S.size() + Omega.size()}, and a leaf holding that observation's
 * probability is appended as the last factor.
 *
 * The computation starts from the {@code belState} argument. It previously
 * started from {@code ipomdp.getCurrentBelief()}, so callers (and the cache,
 * which is keyed on {@code belState}) received results for the wrong belief
 * whenever the two differed.
 *
 * @param ipomdp       model providing actions, observations and the belief update
 * @param belState     belief from which the next belief states are computed
 * @param normalize    currently unused by this implementation
 * @param smallestProb currently unused by this implementation; the NextBelState
 *                     is constructed with a fixed 1e-8 instead
 * @return map from action name to the NextBelState for that action
 * @throws ZeroProbabilityObsException propagated from the belief update
 *                                     (NOTE(review): likely raised for any observation
 *                                     with zero probability at belState - confirm)
 * @throws VariableNotFoundException   declared for callers; origin not visible here
 */
public static HashMap<String, NextBelState> oneStepNZPrimeBelStates2(
        IPOMDP ipomdp,
        DD belState,
        boolean normalize,
        double smallestProb) throws ZeroProbabilityObsException, VariableNotFoundException {

    HashMap<String, NextBelState> nextBelStates = new HashMap<String, NextBelState>();

    for (String act: ipomdp.getActions()) {

        List<DD[]> nextBelStatesForAct = new ArrayList<DD[]>();

        List<List<String>> allObs = ipomdp.getAllPossibleObservations();

        /* Observation probabilities at the requested belief, not the current one */
        DD obsDist = ipomdp.getObsDist(belState, act);
        double[] obsProbs = OP.convert2array(obsDist, ipomdp.obsIVarPrimeIndices);

        for (int o = 0; o < allObs.size(); o++) {

            /* Belief reached from belState after act and the o-th observation */
            DD nextBelief =
                    ipomdp.beliefUpdate(
                            belState,
                            act,
                            allObs.get(o).toArray(new String[0]));

            /* Factor the belief and move the factors onto primed variables */
            DD[] factoredNextBel = ipomdp.factorBelief(nextBelief);
            factoredNextBel =
                    OP.primeVarsN(factoredNextBel, ipomdp.S.size() + ipomdp.Omega.size());

            /* Last factor carries the probability of this observation */
            factoredNextBel =
                    ArrayUtils.add(
                            factoredNextBel,
                            DDleaf.myNew(obsProbs[o]));

            nextBelStatesForAct.add(factoredNextBel);
        }

        DD[][] nextBelStatesFactors = nextBelStatesForAct.toArray(new DD[0][]);

        /* Overwrite the factors held by the freshly constructed NextBelState */
        NextBelState nbState = new NextBelState(ipomdp, obsProbs, 1e-8);
        nbState.nextBelStates = nextBelStatesFactors;
        nextBelStates.put(act, nbState);
    }

    return nextBelStates;
}

public static HashMap<String, NextBelState> oneStepNZPrimeBelStates(
IPOMDP ipomdp,
DD belState,
Expand All @@ -240,11 +320,6 @@ public static HashMap<String, NextBelState> oneStepNZPrimeBelStates(
* Computes the next belief states and the observation probabilities that lead to them
*/

if (NextBelStateCache.cachingAllowed()
&& NextBelStateCache.NEXT_BELSTATE_CACHE.containsKey(belState))
return NextBelStateCache.NEXT_BELSTATE_CACHE.get(belState);

/* get the precomputed possible observation value indices from the IPOMDP */
int[][] obsConfig = ipomdp.obsCombinationsIndices;

double[] obsProbs;
Expand Down Expand Up @@ -283,49 +358,8 @@ public static HashMap<String, NextBelState> oneStepNZPrimeBelStates(
nextBelStates.get(Ai).restrictN(marginals, obsConfig);
// logger.debug("After computing marginals " + nextBelStates[actId]);

// if (!cacheHit && Global.USE_NEXT_BELSTATE_CACHES) {
//// LOGGER.debug("Building cache for action " + Ai + " at belief"
//// + ipomdp.toMapWithTheta(belState));
// nextBelStateCache.put(Ai, nextBelStates.get(Ai).nextBelStates);
// }
//
// else {
//
// /* verify cache */
//// LOGGER.debug("Verifying cache Hit");
//// LOGGER.debug("For action " + Ai);
//// LOGGER.debug("Computed bel states are " + nextBelStates.get(Ai).nextBelStates.length);
//// LOGGER.debug("cached bel states are " + Global.NEXT_BELSTATES_CACHE.get(belState).get(Ai).length);
// for (int i = 0; i < obsConfig.length; i++) {
//
// DD[] computedNextBelStates = nextBelStates.get(Ai).nextBelStates[i];
// DD[] cachedNextBelStates =
// Global.NEXT_BELSTATES_CACHE.get(belState).get(Ai)[i];
//
// for (int s = 0; s < computedNextBelStates.length; s++) {
//// LOGGER.debug("Original " + computedNextBelStates[s].toDDTree());
//// LOGGER.debug("Cached" + cachedNextBelStates[s].toDDTree());
//
// double val =
// OP.maxAll(OP.abs(OP.sub(
// computedNextBelStates[s],
// cachedNextBelStates[s])));
//
// if (val > 0.001) {
// LOGGER.error("Holy shit!");
// LOGGER.error(val);
// }
// }
// }
// }
}

// if (!cacheHit && Global.USE_NEXT_BELSTATE_CACHES) {
//// LOGGER.debug("Storing in global cache");
// HashMap<String, DD[][]> tempCache = new HashMap<String, DD[][]>();
// tempCache.putAll(nextBelStateCache);
// Global.NEXT_BELSTATES_CACHE.put(belState, tempCache);
// }
}

catch (Exception e) {
Expand Down
Loading

0 comments on commit c91524f

Please sign in to comment.