Skip to content

Commit

Permalink
pomcp [wip]
Browse files (browse the repository at this point in the history)
  • Loading branch information
Limmen committed Jan 23, 2024
1 parent 4d32235 commit da199a3
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions examples/manual_play/learn_model.py
Original file line number | Diff line number | Diff line change
@@ -38,30 +38,30 @@
  o, r, done, _, info = csle_cyborg_env.step(action=a)
  s_prime = info[constants.ENV_METRICS.STATE]
  oid = info[constants.ENV_METRICS.OBSERVATION]
- if (s, s_prime, a) not in transition_probabilities:
- transition_probabilities[(s, s_prime, a)] = 1
+ if ",".join([str(s), str(s_prime), str(a)]) not in transition_probabilities:
+ transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = 1
  new_transitions += 1
  else:
- transition_probabilities[(s, s_prime, a)] = transition_probabilities[(s, s_prime, a)] + 1
- if (s, s_prime, a) not in reward_function:
- reward_function[(s, s_prime, a)] = r
- if (s_prime, oid) not in observation_probabilities:
- observation_probabilities[(s_prime, oid)] = 1
+ transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = transition_probabilities[",".join([str(s), str(s_prime), str(a)])] + 1
+ if ",".join([str(s), str(s_prime), str(a)]) not in reward_function:
+ reward_function[",".join([str(s), str(s_prime), str(a)])] = r
+ if ",".join([str(s_prime), str(oid)]) not in observation_probabilities:
+ observation_probabilities[",".join([str(s_prime), str(oid)])] = 1
  else:
- observation_probabilities[(s_prime, oid)] = observation_probabilities[(s_prime, oid)] + 1
+ observation_probabilities[",".join([str(s_prime), str(oid)])] = observation_probabilities[",".join([str(s_prime), str(oid)])] + 1
  t_count += 1
  print(f"new transitions: {new_transitions}")

- # if i % save_every == 0:
- # model = {}
- # model["transitions"] = transition_probabilities
- # model["rewards"] = reward_function
- # model["episodes"] = i
- # model["steps"] = t_count
- # model["observations"] = observation_probabilities
- # model["initial_state"] = initial_state_distribution
- # json_str = json.dumps(model, indent=4, sort_keys=True)
- # with io.open(f"/home/kim/cyborg_model_{i}.json", 'w', encoding='utf-8') as f:
- # f.write(json_str)
+ if i % save_every == 0:
+ model = {}
+ model["transitions"] = transition_probabilities
+ model["rewards"] = reward_function
+ model["episodes"] = i
+ model["steps"] = t_count
+ model["observations"] = observation_probabilities
+ model["initial_state"] = initial_state_distribution
+ json_str = json.dumps(model, indent=4, sort_keys=True)
+ with io.open(f"/home/kim/cyborg_model_{i}.json", 'w', encoding='utf-8') as f:
+ f.write(json_str)


0 comments on commit da199a3

Please sign in to comment.