VM scheduling with RL #375

Merged
merged 52 commits into from Sep 26, 2021

Changes from 28 commits

Commits (52)
fbcaa10
added part of vm scheduling RL code
Jul 9, 2021
ff7d1ac
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 9, 2021
b7530c1
refined vm env_wrapper code style
Jul 9, 2021
c240210
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 9, 2021
c9bb66f
added DQN
Jul 9, 2021
b690376
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 9, 2021
f3645bd
updated exploration for VM
Jul 11, 2021
0235f89
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 13, 2021
4e1a8b1
added get_experiences func for ac in vm scheduling
Jul 13, 2021
095854c
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 15, 2021
9155297
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 15, 2021
9553344
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 15, 2021
41bf039
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 15, 2021
c52085a
added post_step callback to env wrapper
Jul 15, 2021
f324d91
moved Aiming's tracking and plotting logic into callbacks
Jul 15, 2021
3a5e5b7
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 15, 2021
272d2cc
added eval env wrapper
Jul 16, 2021
7aa6d14
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 19, 2021
c0a0817
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 19, 2021
c51a7b6
renamed AC config variable name for VM
Jul 19, 2021
a76b0e6
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 19, 2021
74fc932
vm scheduling RL code finished
Jul 19, 2021
9e9da67
updated README
Jul 19, 2021
453ec15
fixed various bugs and hard coding for vm_scheduling
Jul 22, 2021
e815b1a
fixed merge conflicts
Jul 22, 2021
2847f6a
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Jul 22, 2021
ed4da3a
uncommented callbacks for VM scheduling
Jul 26, 2021
03b6b78
Minor revision for better code style
lihuoran Aug 5, 2021
1b4d1fc
added part of vm scheduling RL code
Jul 9, 2021
610f681
refined vm env_wrapper code style
Jul 9, 2021
6ba958f
vm scheduling RL code finished
Jul 19, 2021
5cb38f2
added config.py for vm scheduing
Sep 7, 2021
4654110
vm example refactoring
Sep 7, 2021
c51b9e6
fixed bugs in vm_scheduling
Sep 12, 2021
faf0084
fixed conflicts with remote
Sep 12, 2021
0880528
removed unwanted files from cim dir
Sep 12, 2021
41c79b4
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Sep 12, 2021
639c33f
reverted to simple policy manager as default
Sep 12, 2021
9c708ae
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Sep 12, 2021
bc348d7
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Sep 15, 2021
c4e70ec
added part of vm scheduling RL code
Jul 9, 2021
975119e
refined vm env_wrapper code style
Jul 9, 2021
66c32e9
vm scheduling RL code finished
Jul 19, 2021
f14ec81
added config.py for vm scheduing
Sep 7, 2021
c105be0
resolved rebase conflicts
Sep 22, 2021
2982ea4
fixed bugs in vm_scheduling
Sep 12, 2021
db17d70
added get_state and set_state to vm_scheduling policy models
Sep 22, 2021
aa5a357
conflict fix
Sep 22, 2021
b3f0cf3
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Sep 22, 2021
694a6d8
updated README for vm_scheduling with RL
Sep 23, 2021
4481d5f
1. fixed NAN bug in VM scheduling with AC; 2. updated README
Sep 24, 2021
104d2f3
Merge branch 'v0.2_rl_refinement' into v0.2_rl_refinement_vm
Sep 24, 2021
2 changes: 1 addition & 1 deletion examples/rl/cim/ac.py
@@ -11,7 +11,7 @@
from maro.rl.model import DiscreteACNet, FullyConnectedBlock, OptimOption
from maro.rl.policy.algorithms import ActorCritic, ActorCriticConfig

-cim_path = os.path.dirname(os.path.dirname(__file__))
+cim_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, cim_path)
Contributor

Is there any special purpose for inserting cim_path into the first position of sys.path?

Contributor Author

Need to import from files under the cim dir

from env_wrapper import STATE_DIM, env_config

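To illustrate the point discussed above: prepending the scenario directory to `sys.path` makes bare-module imports such as `from env_wrapper import ...` resolve against that directory, regardless of the working directory the script is launched from. A minimal sketch of the idea (the directory layout here is hypothetical, not part of the PR):

```python
import os
import sys

# Directory containing this script, with symlinks resolved.
scenario_dir = os.path.dirname(os.path.realpath(__file__))

# Index 0 makes this directory take precedence over any identically
# named module elsewhere on sys.path.
sys.path.insert(0, scenario_dir)

# This now loads <scenario_dir>/env_wrapper.py, assuming that file exists.
import env_wrapper  # noqa: E402
```
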
10 changes: 10 additions & 0 deletions examples/rl/vm_scheduling/README.md
@@ -0,0 +1,10 @@
# Virtual Machine Scheduling

Virtual Machine (VM) scheduling is a scenario where reinforcement learning (RL) can help the virtual machine allocator allocate compute resources intelligently. In this folder you can find:
* ``env_wrapper.py``, which contains a function to generate an environment wrapper to interact
with our "agent" (see below);
* ``agent_wrapper.py``, which contains a function to generate an agent wrapper to interact
with the environment wrapper;
* ``policy_index.py``, which maps policy names to the functions that create them; the functions that create the DQN and Actor-Critic policies are defined in ``dqn.py`` and ``ac.py``, respectively.

The code for the actual learning workflows (e.g., learner, roll-out worker and trainer) can be found under ``examples/rl/workflows``. The reason for putting it in a separate folder is that these workflows apply to any scenario, so long as the necessary component generators, such as the ones listed above, are provided. See ``README`` under ``examples/rl`` for details. We recommend that you follow this example to write your own scenarios.
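As a rough illustration of how these generators fit together (a minimal sketch, not the actual workflow code under ``examples/rl/workflows``; the agent name ``"AGENT"`` and the exact registry contents are assumptions based on the files added in this PR):

```python
# Toy version of the registry pattern that policy_index follows.
from ac import get_ac_policy
from dqn import get_dqn_policy

# Map each policy name to a function that builds it ...
rl_policy_func_index = {"ac": get_ac_policy, "dqn": get_dqn_policy}
# ... and map each agent to the name of the policy it uses.
agent2policy = {"AGENT": "ac"}

# A generic workflow can then build exactly the policies it needs:
policies = {name: create(mode="update") for name, create in rl_policy_func_index.items()}
policy_for_agent = {agent: policies[name] for agent, name in agent2policy.items()}
```
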
11 changes: 11 additions & 0 deletions examples/rl/vm_scheduling/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .callbacks import post_collect, post_evaluate
from .env_wrapper import get_env_wrapper, get_eval_env_wrapper
from .policy_index import agent2policy, rl_policy_func_index, update_trigger, warmup

__all__ = [
"agent2policy", "post_collect", "post_evaluate", "get_env_wrapper", "get_eval_env_wrapper",
"rl_policy_func_index", "update_trigger", "warmup"
]
103 changes: 103 additions & 0 deletions examples/rl/vm_scheduling/ac.py
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np
import torch

from env_wrapper import NUM_PMS, STATE_DIM
from maro.rl.experience import ExperienceStore, UniformSampler
from maro.rl.model import DiscreteACNet, FullyConnectedBlock, OptimOption
from maro.rl.policy.algorithms import ActorCritic, ActorCriticConfig

vm_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, vm_path)

config = {
"model": {
"network": {
"actor": {
"input_dim": STATE_DIM,
"output_dim": NUM_PMS + 1, # action could be any PM or postponement, hence the plus 1
"hidden_dims": [64, 32, 32],
"activation": "leaky_relu",
"softmax": True,
"batch_norm": False,
"head": True
},
"critic": {
"input_dim": STATE_DIM,
"output_dim": 1,
"hidden_dims": [256, 128, 64],
"activation": "leaky_relu",
"softmax": False,
"batch_norm": False,
"head": True
}
},
"optimization": {
"actor": {
"optim_cls": "adam",
"optim_params": {"lr": 0.0001}
},
"critic": {
"optim_cls": "sgd",
"optim_params": {"lr": 0.001}
}
}
},
"algorithm": {
"reward_discount": 0.9,
"train_epochs": 100,
"critic_loss_cls": "mse",
"critic_loss_coeff": 0.1
},
"experience_store": {
"rollout": {"capacity": 10000, "overwrite_type": "rolling"},
"update": {"capacity": 50000, "overwrite_type": "rolling"}
},
"sampler": {
"rollout": {"batch_size": -1, "replace": False},
"update": {"batch_size": 128, "replace": True}
}
}


def get_ac_policy(mode="update"):
class MyACNet(DiscreteACNet):
def forward(self, states, actor: bool = True, critic: bool = True):
if isinstance(states, dict):
states = [states]
inputs = torch.from_numpy(np.asarray([st["model"] for st in states])).to(self.device)
masks = torch.from_numpy(np.asarray([st["mask"] for st in states])).to(self.device)
if len(inputs.shape) == 1:
inputs = inputs.unsqueeze(dim=0)
return (
self.component["actor"](inputs) * masks if actor else None,
self.component["critic"](inputs) if critic else None
)

ac_net = MyACNet(
component={
"actor": FullyConnectedBlock(**config["model"]["network"]["actor"]),
"critic": FullyConnectedBlock(**config["model"]["network"]["critic"])
},
optim_option={
"actor": OptimOption(**config["model"]["optimization"]["actor"]),
"critic": OptimOption(**config["model"]["optimization"]["critic"])
} if mode != "inference" else None
)
if mode == "update":
exp_store = ExperienceStore(**config["experience_store"]["update"])
exp_sampler_kwargs = config["sampler"]["update"]
else:
exp_store = ExperienceStore(**config["experience_store"]["rollout" if mode == "inference" else "update"])
exp_sampler_kwargs = config["sampler"]["rollout" if mode == "inference" else "update"]

return ActorCritic(
ac_net, ActorCriticConfig(**config["algorithm"]), exp_store,
experience_sampler_cls=UniformSampler,
experience_sampler_kwargs=exp_sampler_kwargs
)
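The mask multiplication in ``MyACNet.forward`` is what keeps the actor from proposing PMs that cannot host the incoming VM. A self-contained toy example of the idea (toy dimensions and values, not the MARO API):

```python
import torch

# Pretend NUM_PMS = 4; the extra slot is the "postpone" action.
probs = torch.softmax(torch.randn(1, 5), dim=1)
mask = torch.tensor([[1., 0., 1., 1., 1.]])  # PM 1 is full, so it is masked out

masked = probs * mask
print(masked)             # the illegal action now has probability 0
print(masked.argmax(1))   # a greedy choice is guaranteed to be legal
```

Note that the masked vector is no longer normalized; sampling it through ``torch.distributions.Categorical(probs=masked)`` renormalizes it implicitly, which is presumably what the ``ActorCritic`` algorithm relies on.
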
88 changes: 88 additions & 0 deletions examples/rl/vm_scheduling/callbacks.py
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from os import makedirs
from os.path import dirname, join, realpath

import matplotlib.pyplot as plt

from maro.utils import Logger

timestamp = str(time.time())

log_dir = join(dirname(realpath(__file__)), "log", timestamp)
makedirs(log_dir, exist_ok=True)

plt_path = join(dirname(realpath(__file__)), "plots", timestamp)
makedirs(plt_path, exist_ok=True)

simulation_logger = Logger("SIMUALTION", dump_folder=log_dir)


def post_collect(trackers, ep, segment):
# print the env metric from each rollout worker
for tracker in trackers:
simulation_logger.info(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}")

# print the average env metric
if len(trackers) > 1:
metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
simulation_logger.info(f"average env metric (episode {ep}, segment {segment}): {avg_metric}")


def post_evaluate(trackers, ep):
# print the env metric from each rollout worker
for tracker in trackers:
simulation_logger.info(f"env summary (evaluation episode {ep}): {tracker['env_metric']}")

# print the average env metric
if len(trackers) > 1:
metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
simulation_logger.info(f"average env metric (evaluation episode {ep}): {avg_metric}")

for i, tracker in enumerate(trackers):
core_requirement = tracker["vm_core_requirement"]
action_sequence = tracker["action_sequence"]
# plot action sequence
fig = plt.figure(figsize=(40, 32))
ax = fig.add_subplot(1, 1, 1)
ax.plot(action_sequence)
fig.savefig(f"{plt_path}/action_sequence_{ep}")
plt.cla()
plt.close("all")

# plot with legal action mask
fig = plt.figure(figsize=(40, 32))
for idx, key in enumerate(core_requirement.keys()):
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
for i in range(len(core_requirement[key])):
if i == 0:
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1], label=str(key))
ax.legend()
else:
ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1])

fig.savefig(f"{plt_path}/values_with_legal_action_{ep}")

plt.cla()
plt.close("all")

# plot without legal action mask
fig = plt.figure(figsize=(40, 32))

for idx, key in enumerate(core_requirement.keys()):
ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1)
for i in range(len(core_requirement[key])):
if i == 0:
ax.plot(core_requirement[key][i][0], label=str(key))
ax.legend()
else:
ax.plot(core_requirement[key][i][0])

fig.savefig(f"{plt_path}/values_without_legal_action_{ep}")

plt.cla()
plt.close("all")
123 changes: 123 additions & 0 deletions examples/rl/vm_scheduling/dqn.py
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np
import torch

from env_wrapper import NUM_PMS, STATE_DIM
from maro.rl.experience import ExperienceStore, UniformSampler
from maro.rl.exploration import DiscreteSpaceExploration, MultiPhaseLinearExplorationScheduler
from maro.rl.model import DiscreteQNet, FullyConnectedBlock, OptimOption
from maro.rl.policy.algorithms import DQN, DQNConfig

vm_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, vm_path)

config = {
"model": {
"network": {
"input_dim": STATE_DIM,
"hidden_dims": [64, 128, 256],
"output_dim": NUM_PMS + 1, # action could be any PM or postponement, hence the plus 1
"activation": "leaky_relu",
"softmax": False,
"batch_norm": False,
"skip_connection": False,
"head": True,
"dropout_p": 0.0
},
"optimization": {
"optim_cls": "sgd",
"optim_params": {"lr": 0.0005},
"scheduler_cls": "cosine_annealing_warm_restarts",
"scheduler_params": {"T_0": 500, "T_mult": 2}
}
},
"algorithm": {
"reward_discount": 0.9,
"update_target_every": 5,
"train_epochs": 100,
"soft_update_coeff": 0.1,
"double": False
},
"experience_store": {
"rollout": {"capacity": 10000, "overwrite_type": "rolling"},
"update": {"capacity": 50000, "overwrite_type": "rolling"}
},
"sampler": {
"rollout": {"batch_size": -1, "replace": False},
"update": {"batch_size": 256, "replace": True}
},
"exploration": {
"last_ep": 400,
"initial_value": 0.4,
"final_value": 0.0,
"splits": [(100, 0.32)]
}
}


class MyQNet(DiscreteQNet):
def __init__(self, component, optim_option, device: str = None):
super().__init__(component, optim_option=optim_option, device=device)
for mdl in self.modules():
if isinstance(mdl, torch.nn.Linear):
torch.nn.init.xavier_uniform_(mdl.weight, gain=torch.nn.init.calculate_gain('leaky_relu'))

def forward(self, states):
if isinstance(states, dict):
states = [states]
inputs = torch.from_numpy(np.asarray([st["model"] for st in states])).to(self.device)
masks = torch.from_numpy(np.asarray([st["mask"] for st in states])).to(self.device)
if len(inputs.shape) == 1:
inputs = inputs.unsqueeze(dim=0)
q_for_all_actions = self.component(inputs)
return q_for_all_actions + (masks - 1) * 1e8


class MaskedEpsilonGreedy(DiscreteSpaceExploration):
def __init__(self, epsilon: float = .0):
super().__init__()
self.epsilon = epsilon

def __call__(self, action, state):
if isinstance(state, dict):
state = [state]
mask = [st["mask"] for st in state]
return np.array([
act if np.random.random() > self.epsilon else np.random.choice(np.where(mk == 1)[0])
for act, mk in zip(action, mask)
])


def get_dqn_policy(mode="update"):
assert mode in {"inference", "update", "inference-update"}
q_net = MyQNet(
FullyConnectedBlock(**config["model"]["network"]),
optim_option=OptimOption(**config["model"]["optimization"]) if mode != "inference" else None
)

if mode == "update":
exp_store = ExperienceStore(**config["experience_store"]["update"])
exploration = None
exp_sampler_kwargs = config["sampler"]["update"]
else:
exp_store = ExperienceStore(**config["experience_store"]["rollout"])
exploration = MaskedEpsilonGreedy()
exploration.register_schedule(
scheduler_cls=MultiPhaseLinearExplorationScheduler,
param_name="epsilon",
**config["exploration"]
)
exp_store = ExperienceStore(**config["experience_store"]["rollout" if mode == "inference" else "update"])
exp_sampler_kwargs = config["sampler"]["rollout" if mode == "inference" else "update"]

return DQN(
q_net, DQNConfig(**config["algorithm"]), exp_store,
experience_sampler_cls=UniformSampler,
experience_sampler_kwargs=exp_sampler_kwargs,
exploration=exploration
)
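Two forms of action masking appear in this file: ``MyQNet.forward`` pushes the Q-values of illegal actions down by a huge constant so a greedy argmax never selects them, and ``MaskedEpsilonGreedy`` restricts random exploration to legal actions. A self-contained sketch of both ideas with toy numbers (not the MARO API):

```python
import numpy as np

q = np.array([0.3, 0.9, 0.1, 0.5, 0.2])       # Q-values for 4 PMs + postponement
mask = np.array([1, 0, 1, 1, 1], dtype=float)  # PM 1 cannot host the VM

# MyQNet.forward: illegal actions get a -1e8 penalty, so argmax never picks them.
masked_q = q + (mask - 1) * 1e8
greedy = int(masked_q.argmax())

# MaskedEpsilonGreedy: with probability epsilon, explore uniformly over legal actions only.
epsilon = 0.4
if np.random.random() > epsilon:
    action = greedy
else:
    action = int(np.random.choice(np.where(mask == 1)[0]))
```

The ``exploration`` section of the config describes the epsilon schedule; assuming the ``splits`` entries are ``(episode, value)`` breakpoints, epsilon decays linearly from 0.4 to 0.32 over the first 100 episodes and then to 0.0 by episode 400.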