# Patch summary for examples/manual_play/learn_model.py (single hunk, old/new lines 38-67).
#
# What the hunk changes:
#   * The dictionaries transition_probabilities, reward_function and
#     observation_probabilities are no longer keyed by tuples
#     ((s, s_prime, a) and (s_prime, oid)); they are keyed by the
#     comma-joined string form of the same values, e.g.
#     ",".join([str(s), str(s_prime), str(a)]).
#   * The previously commented-out save block is re-enabled: whenever
#     i % save_every == 0, a dict with the keys "transitions", "rewards",
#     "episodes", "steps", "observations" and "initial_state" is serialised
#     with json.dumps(indent=4, sort_keys=True) and written to
#     /home/kim/cyborg_model_{i}.json.
#
# NOTE(review): the switch to string keys is presumably what makes the
#   re-enabled json.dumps call work (json.dumps rejects tuple keys) - confirm
#   with the author.
# NOTE(review): i is stored under "episodes", so it looks like the episode
#   counter - the loop header is outside this hunk, verify in the full file.
# NOTE(review): the same join expression is rebuilt several times per step in
#   the added lines, and the output path is hard-coded to /home/kim.
# NOTE(review): as stored here the hunk's line breaks and indentation appear
#   collapsed; check against the original commit before trying to apply it.
diff --git a/examples/manual_play/learn_model.py b/examples/manual_play/learn_model.py index 0743fb2a2..952cc5b52 100644 --- a/examples/manual_play/learn_model.py +++ b/examples/manual_play/learn_model.py @@ -38,30 +38,30 @@ o, r, done, _, info = csle_cyborg_env.step(action=a) s_prime = info[constants.ENV_METRICS.STATE] oid = info[constants.ENV_METRICS.OBSERVATION] - if (s, s_prime, a) not in transition_probabilities: - transition_probabilities[(s, s_prime, a)] = 1 + if ",".join([str(s), str(s_prime), str(a)]) not in transition_probabilities: + transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = 1 new_transitions += 1 else: - transition_probabilities[(s, s_prime, a)] = transition_probabilities[(s, s_prime, a)] + 1 - if (s, s_prime, a) not in reward_function: - reward_function[(s, s_prime, a)] = r - if (s_prime, oid) not in observation_probabilities: - observation_probabilities[(s_prime, oid)] = 1 + transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = transition_probabilities[",".join([str(s), str(s_prime), str(a)])] + 1 + if ",".join([str(s), str(s_prime), str(a)]) not in reward_function: + reward_function[",".join([str(s), str(s_prime), str(a)])] = r + if ",".join([str(s_prime), str(oid)]) not in observation_probabilities: + observation_probabilities[",".join([str(s_prime), str(oid)])] = 1 else: - observation_probabilities[(s_prime, oid)] = observation_probabilities[(s_prime, oid)] + 1 + observation_probabilities[",".join([str(s_prime), str(oid)])] = observation_probabilities[",".join([str(s_prime), str(oid)])] + 1 t_count += 1 print(f"new transitions: {new_transitions}") - # if i % save_every == 0: - # model = {} - # model["transitions"] = transition_probabilities - # model["rewards"] = reward_function - # model["episodes"] = i - # model["steps"] = t_count - # model["observations"] = observation_probabilities - # model["initial_state"] = initial_state_distribution - # json_str = json.dumps(model, indent=4, sort_keys=True) - # with 
io.open(f"/home/kim/cyborg_model_{i}.json", 'w', encoding='utf-8') as f: - # f.write(json_str) + if i % save_every == 0: + model = {} + model["transitions"] = transition_probabilities + model["rewards"] = reward_function + model["episodes"] = i + model["steps"] = t_count + model["observations"] = observation_probabilities + model["initial_state"] = initial_state_distribution + json_str = json.dumps(model, indent=4, sort_keys=True) + with io.open(f"/home/kim/cyborg_model_{i}.json", 'w', encoding='utf-8') as f: + f.write(json_str)