Commit 8030b2a: formatting

threewisemonkeys-as committed Oct 29, 2020
1 parent 8d5a8b6 commit 8030b2a

Showing 2 changed files with 19 additions and 22 deletions.

examples/distributed/offpolicy_distributed_primary.py (3 additions & 1 deletion)
@@ -42,7 +42,9 @@ def collect_experience(agent, parameter_server, experience_server, learner):


 class MyTrainer(DistributedTrainer):
-    def __init__(self, agent, train_steps, batch_size, init_buffer_size, log_interval=200):
+    def __init__(
+        self, agent, train_steps, batch_size, init_buffer_size, log_interval=200
+    ):
         super(MyTrainer, self).__init__(agent)
         self.train_steps = train_steps
         self.batch_size = batch_size
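
The constructor signature itself is unchanged by the reflow, so existing callers keep working. Below is a hypothetical instantiation, only to show the call shape; the agent object is assumed to be built elsewhere in the script and the numbers are invented:

trainer = MyTrainer(
    agent, train_steps=100, batch_size=32, init_buffer_size=1000, log_interval=200
)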

examples/distributed/onpolicy_distributed_primary.py (16 additions & 21 deletions)
@@ -33,15 +33,8 @@ def get_advantages_returns(rewards, dones, values, gamma=0.99, gae_lambda=1):
         else:
             next_non_terminal = 1.0 - dones[step + 1]
             next_value = values[step + 1]
-        delta = (
-            rewards[step]
-            + gamma * next_value * next_non_terminal
-            - values[step]
-        )
-        last_gae_lam = (
-            delta
-            + gamma * gae_lambda * next_non_terminal * last_gae_lam
-        )
+        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
+        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
         advantages[step] = last_gae_lam
     returns = advantages + values
     return advantages.detach(), returns.detach()
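
The two reflowed lines above are the standard generalized advantage estimation recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t), followed by A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}. A minimal standalone sketch of the same recursion, assuming torch, 1-D float tensors of equal length, and a zero bootstrap value after the final step (the lines that handle the final step are collapsed in this diff, so that part is an assumption):

import torch


def gae(rewards, dones, values, gamma=0.99, gae_lambda=1.0):
    # rewards, dones, values: 1-D float tensors of equal length; values[t] = V(s_t)
    advantages = torch.zeros_like(rewards)
    last_gae_lam = 0.0
    for step in reversed(range(len(rewards))):
        if step == len(rewards) - 1:
            next_non_terminal = 1.0 - dones[step]
            next_value = 0.0  # assumed: no bootstrap value beyond the last step
        else:
            next_non_terminal = 1.0 - dones[step + 1]
            next_value = values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    returns = advantages + values
    return advantages.detach(), returns.detach()
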
@@ -66,7 +59,9 @@ def unroll_trajs(trajectories):


 class A2C:
-    def __init__(self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5):
+    def __init__(
+        self, env, policy, value, policy_optim, value_optim, grad_norm_limit=0.5
+    ):
         self.env = env
         self.policy = policy
         self.value = value
@@ -102,16 +97,14 @@ def update_params(self, trajectories):
         self.value_optim.step()

     def get_weights(self):
-        return {
-            "policy": self.policy.state_dict(),
-            "value": self.value.state_dict()
-        }
+        return {"policy": self.policy.state_dict(), "value": self.value.state_dict()}

     def load_weights(self, weights):
         self.policy.load_state_dict(weights["policy"])
         self.value.load_state_dict(weights["value"])

-class Trajectory():
+
+class Trajectory:
     def __init__(self):
         self.states = []
         self.actions = []
@@ -129,7 +122,8 @@ def add(self, state, action, reward, done):
     def __len__(self):
         return self.__len

-class TrajBuffer():
+
+class TrajBuffer:
     def __init__(self, size):
         if size <= 0:
             raise ValueError("Size of buffer must be larger than 0")
@@ -139,20 +133,21 @@ def __init__(self, size):

     def is_full(self):
         return self._full

     def push(self, traj):
         if not self.is_full():
             self._memory.append(traj)
             if len(self._memory) >= self._size:
                 self._full = True

     def get(self, clear=True):
-        out = copy.deepcopy(self._memory)
+        out = copy.deepcopy(self._memory)
         if clear:
             self._memory = []
             self._full = False
         return out


 def collect_experience(agent, parameter_server, experience_server, learner):
     current_step = -1
     while not learner.is_completed():
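
For reference, a short hypothetical usage sketch of the Trajectory and TrajBuffer classes touched above; only the class names and method signatures come from this diff, and the pushed values are invented:

buffer = TrajBuffer(size=2)
for _ in range(2):
    traj = Trajectory()
    traj.add(state=0, action=1, reward=1.0, done=True)
    buffer.push(traj)
print(buffer.is_full())  # True once two trajectories have been pushed
trajs = buffer.get(clear=True)  # returns a deep copy and empties the buffer
print(buffer.is_full())  # False again after the buffer has been cleared
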
@@ -171,7 +166,7 @@ def collect_experience(agent, parameter_server, experience_server, learner):
             traj.add(obs, action, reward, done)
             obs = next_obs
             if done:
-                break
+                break
         experience_server.push(traj)
         print("pushed a traj")

@@ -206,7 +201,7 @@ def train(self, parameter_server, experience_server):


 master = Master(
-    world_size=N_ACTORS+4,
+    world_size=N_ACTORS + 4,
     address="localhost",
     port=29500,
     proc_start_method="fork",
