POMCP [WIP]
Limmen committed Jan 20, 2024
1 parent b6aa05c commit ad4e9bd
Showing 9 changed files with 145 additions and 84 deletions.
42 changes: 0 additions & 42 deletions examples/manual_play/cyborg_test.py

This file was deleted.

@@ -51,7 +51,7 @@
descr="whether reinvigoration should be used"),
agents_constants.POMCP.INITIAL_BELIEF: HParam(value=b1, name=agents_constants.POMCP.INITIAL_BELIEF,
descr="the initial belief"),
agents_constants.POMCP.PLANNING_TIME: HParam(value=300, name=agents_constants.POMCP.PLANNING_TIME,
agents_constants.POMCP.PLANNING_TIME: HParam(value=2000, name=agents_constants.POMCP.PLANNING_TIME,
descr="the planning time"),
agents_constants.POMCP.MAX_PARTICLES: HParam(value=1000, name=agents_constants.POMCP.MAX_PARTICLES,
descr="the maximum number of belief particles"),
Expand All @@ -61,6 +61,9 @@
descr="the weighting factor for UCB exploration"),
agents_constants.POMCP.LOG_STEP_FREQUENCY: HParam(
value=1, name=agents_constants.POMCP.LOG_STEP_FREQUENCY, descr="frequency of logging time-steps"),
agents_constants.POMCP.MAX_NEGATIVE_SAMPLES: HParam(
value=20, name=agents_constants.POMCP.MAX_NEGATIVE_SAMPLES,
descr="maximum number of negative samples when filling belief particles"),
agents_constants.POMCP.DEFAULT_NODE_VALUE: HParam(
value=-2000, name=agents_constants.POMCP.DEFAULT_NODE_VALUE, descr="the default node value in "
"the search tree"),
Expand All @@ -72,7 +75,7 @@
value=0.95, name=agents_constants.COMMON.CONFIDENCE_INTERVAL,
descr="confidence interval"),
agents_constants.COMMON.MAX_ENV_STEPS: HParam(
value=500, name=agents_constants.COMMON.MAX_ENV_STEPS,
value=100, name=agents_constants.COMMON.MAX_ENV_STEPS,
descr="maximum number of steps in the environment (for envs with infinite horizon generally)"),
agents_constants.COMMON.RUNNING_AVERAGE: HParam(
value=100, name=agents_constants.COMMON.RUNNING_AVERAGE,
Expand Up @@ -94,6 +94,9 @@
descr="the maximum depth for planning"),
agents_constants.POMCP.C: HParam(value=0.35, name=agents_constants.POMCP.C,
descr="the weighting factor for UCB exploration"),
agents_constants.POMCP.MAX_NEGATIVE_SAMPLES: HParam(
value=200, name=agents_constants.POMCP.MAX_NEGATIVE_SAMPLES,
descr="maximum number of negative samples when filling belief particles"),
agents_constants.POMCP.DEFAULT_NODE_VALUE: HParam(
value=-2000, name=agents_constants.POMCP.DEFAULT_NODE_VALUE, descr="the default node value in "
"the search tree"),
Expand Up @@ -3,12 +3,12 @@
import numpy as np
from csle_common.dao.simulation_config.base_env import BaseEnv
from csle_common.dao.training.policy import Policy
from csle_common.logging.log import Logger
from csle_agents.agents.pomcp.belief_tree import BeliefTree
from csle_agents.agents.pomcp.belief_node import BeliefNode
from csle_agents.agents.pomcp.action_node import ActionNode
from csle_agents.agents.pomcp.pomcp_util import POMCPUtil
import csle_agents.constants.constants as constants
from csle_common.logging.log import Logger


class POMCP:
@@ -197,15 +197,23 @@ def get_action(self) -> int:
f"visit count: {a.visit_count}")
return int(max(action_vals)[1])

def update_tree_with_new_samples(self, action: int, observation: int) -> Dict[int, float]:
def update_tree_with_new_samples(self, action_sequence: List[int], observation: int,
max_negative_samples: int = 20) -> Dict[int, float]:
"""
Updates the tree after an action has been selected and a new observation has been received
:param action: the action that was executed
:param action_sequence: the action sequence that was executed
:param observation: the observation that was received
:param max_negative_samples: the maximum number of negative samples that can be collected before
falling back to trajectory simulation
:return: the updated belief state
"""
observation = self.env.get_observation_id_from_vector(
observation_vector=self.env.get_observation_from_history(history=[observation]))
root = self.tree.root
if len(action_sequence) == 0:
raise ValueError("Invalid action sequencee")
action = action_sequence[0]

# Since we executed an action we advance the tree and update the root to the node corresponding to the
# action that was selected
@@ -241,19 +249,28 @@ def update_tree_with_new_samples(self, action: int, observation: int) -> Dict[int, float]:
particle_slots = self.max_particles - len(new_root.particles)
else:
raise ValueError("Invalid root node")
negative_samples_count = 0
if particle_slots > 0:
# fill particles by Monte-Carlo simulation using rejection sampling
particles = []
while len(particles) < particle_slots:
if self.verbose:
Logger.__call__().get_logger().info(f"Filling particles {len(particles)}/{particle_slots}")
s = root.sample_state()
self.env.set_state(state=s)
_, r, _, _, info = self.env.step(action)
s_prime = info[constants.COMMON.STATE]
o = info[constants.COMMON.OBSERVATION]
if o == observation:
particles.append(s_prime)
if negative_samples_count >= max_negative_samples:
particles += POMCPUtil.trajectory_simulation_particles(
o=observation, env=self.env, action_sequence=action_sequence, verbose=self.verbose,
num_particles=(particle_slots - len(particles)))
else:
s = root.sample_state()
self.env.set_state(state=s)
_, r, _, _, info = self.env.step(action)
s_prime = info[constants.COMMON.STATE]
o = info[constants.COMMON.OBSERVATION]
if o == observation:
particles.append(s_prime)
negative_samples_count = 0
else:
negative_samples_count += 1
new_root.particles += particles

# We now prune the old root from the tree
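
For readability, here is a minimal standalone sketch of the particle-filling scheme introduced in this commit: rejection sampling of belief particles, with a fall-back to trajectory simulation once max_negative_samples consecutive samples have been rejected. The names fill_particles, sample_root_state, simulate_step and trajectory_simulation are hypothetical placeholders for illustration, not part of the csle API.

from typing import Callable, List, Tuple

def fill_particles(sample_root_state: Callable[[], int],
                   simulate_step: Callable[[int, int], Tuple[int, int]],
                   trajectory_simulation: Callable[[int], List[int]],
                   action: int, observation: int, particle_slots: int,
                   max_negative_samples: int = 20) -> List[int]:
    # Hypothetical sketch: rejection sampling with a trajectory-simulation fall-back
    particles: List[int] = []
    negative_samples_count = 0
    while len(particles) < particle_slots:
        if negative_samples_count >= max_negative_samples:
            # too many consecutive mismatches; generate the remaining particles via trajectory simulation
            particles += trajectory_simulation(particle_slots - len(particles))
        else:
            s = sample_root_state()  # sample a state from the root belief
            s_prime, o = simulate_step(s, action)  # one step of the black-box simulator
            if o == observation:
                particles.append(s_prime)  # accept: simulated observation matches the real one
                negative_samples_count = 0
            else:
                negative_samples_count += 1  # reject: count the mismatch
    return particles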
Expand Up @@ -166,7 +166,7 @@ def hparam_names(self) -> List[str]:
agents_constants.POMCP.A, agents_constants.POMCP.GAMMA,
agents_constants.POMCP.INITIAL_BELIEF, agents_constants.POMCP.PLANNING_TIME,
agents_constants.POMCP.LOG_STEP_FREQUENCY, agents_constants.POMCP.VERBOSE,
agents_constants.POMCP.DEFAULT_NODE_VALUE,
agents_constants.POMCP.DEFAULT_NODE_VALUE, agents_constants.POMCP.MAX_NEGATIVE_SAMPLES,
agents_constants.POMCP.MAX_PARTICLES, agents_constants.POMCP.C, agents_constants.POMCP.MAX_DEPTH,
agents_constants.COMMON.EVAL_BATCH_SIZE, agents_constants.COMMON.CONFIDENCE_INTERVAL,
agents_constants.COMMON.RUNNING_AVERAGE, agents_constants.COMMON.MAX_ENV_STEPS]
Expand All @@ -188,6 +188,7 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
log_steps_frequency = self.experiment_config.hparams[agents_constants.POMCP.LOG_STEP_FREQUENCY].value
verbose = self.experiment_config.hparams[agents_constants.POMCP.VERBOSE].value
default_node_value = self.experiment_config.hparams[agents_constants.POMCP.DEFAULT_NODE_VALUE].value
max_negative_samples = self.experiment_config.hparams[agents_constants.POMCP.MAX_NEGATIVE_SAMPLES].value
max_env_steps = self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value
N = self.experiment_config.hparams[agents_constants.POMCP.N].value
A = self.experiment_config.hparams[agents_constants.POMCP.A].value
Expand All @@ -203,9 +204,8 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,

# Run N episodes
for i in range(N):

# Setup environments
done = False
action_sequence = []
eval_env = gym.make(self.simulation_env_config.gym_env_name, config=config)
train_env: BaseEnv = gym.make(self.simulation_env_config.gym_env_name, config=config)
_, info = eval_env.reset()
Expand All @@ -225,16 +225,19 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
pomcp.solve(max_depth=max_depth)
action = pomcp.get_action()
_, r, done, _, info = eval_env.step(action)
action_sequence.append(action)
s_prime = info[agents_constants.COMMON.STATE]
o = info[agents_constants.COMMON.OBSERVATION]
belief = pomcp.update_tree_with_new_samples(action=action, observation=o)
belief = pomcp.update_tree_with_new_samples(action_sequence=action_sequence, observation=o,
max_negative_samples=max_negative_samples)
R += r
t += 1
if t % log_steps_frequency == 0:
b = list(map(lambda x: belief[x], random.sample(list(belief.keys()), min(10, len(belief.keys())))))
Logger.__call__().get_logger().info(f"[POMCP] t: {t}, a: {action}, r: {r}, o: {o}, "
f"s_prime: {s_prime}, b: {b}")
Logger.__call__().get_logger().info(f"action: {eval_env.action_id_to_type_and_host[action]}")
s = s_prime

if i % self.experiment_config.log_every == 0:
# Logging
@@ -1,7 +1,10 @@
from typing import List, Dict, Any
import numpy as np
from csle_agents.agents.pomcp.node import Node
from collections import Counter
from csle_common.logging.log import Logger
from csle_common.dao.simulation_config.base_env import BaseEnv
from csle_agents.agents.pomcp.node import Node
import csle_agents.constants.constants as constants


class POMCPUtil:
@@ -84,3 +87,35 @@ def ucb_acquisition_function(action: "Node", c: float) -> float:
:return: the acquisition value of the action
"""
return float(action.value + c * POMCPUtil.ucb(action.parent.visit_count, action.visit_count))

@staticmethod
def trajectory_simulation_particles(o: int, env: BaseEnv, action_sequence: List[int], num_particles: int,
verbose: bool = False) -> List[int]:
"""
Performs trajectory simulations to find possible states matching the given observation
:param o: the observation to match against
:param env: the black-box simulator to use for generating trajectories
:param action_sequence: the action sequence for the trajectory
:param num_particles: the number of particles to collect
:param verbose: boolean flag indicating whether logging should be verbose or not
:return: the list of particles matching the given observation
"""
particles: List[int] = []
while len(particles) < num_particles:
done = False
_, info = env.reset()
s = info[constants.COMMON.STATE]
t = 0
while not done and t < len(action_sequence):
_, r, done, _, info = env.step(action=action_sequence[t])
sampled_o = info[constants.COMMON.OBSERVATION]
if t == len(action_sequence) - 1 and sampled_o == o:
particles.append(s)
s = info[constants.COMMON.STATE]
t += 1
if verbose:
Logger.__call__().get_logger().info(f"Filling particles {len(particles)}/{num_particles} "
f"through trajectory simulations, "
f"action sequence: {action_sequence}, observation: {o}")
return particles
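
Note: the Counter import added at the top of this file is not referenced by the new helper in this diff. As a loosely related illustration (an assumption, not code from this commit), particle counts can be normalized into the Dict[int, float] belief that update_tree_with_new_samples returns:

from collections import Counter
from typing import Dict, List

def particles_to_belief(particles: List[int]) -> Dict[int, float]:
    # Hypothetical sketch: relative particle frequencies as the belief distribution
    counts = Counter(particles)
    total = sum(counts.values())
    return {state: count / total for state, count in counts.items()}

# Example: three particles in state 5 and one in state 7 give beliefs 0.75 and 0.25
assert particles_to_belief([5, 5, 5, 7]) == {5: 0.75, 7: 0.25}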
Expand Up @@ -559,6 +559,7 @@ class POMCP:
REINVIGORATION = "reinvigoration"
PLANNING_TIME = "planning_time"
MAX_PARTICLES = "max_particles"
MAX_NEGATIVE_SAMPLES = "max_negative_samples"
C = "c"
MAX_DEPTH = "max_depth"
LOG_STEP_FREQUENCY = "log_step_frequency"