Skip to content

Commit

Permalink
Merge branch 'main' into add_coding_ddpg
Browse files Browse the repository at this point in the history
  • Loading branch information
Svetlana Karslioglu authored Jun 13, 2023
2 parents 1263c60 + ea0a11c commit f4a5e4b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions intermediate_source/mario_rl_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def act(self, state):
class Mario(Mario): # subclassing for continuity
def __init__(self, state_dim, action_dim, save_dir):
super().__init__(state_dim, action_dim, save_dir)
- self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
+ self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device("cpu")))
self.batch_size = 32

def cache(self, state, next_state, action, reward, done):
Expand All @@ -369,11 +369,11 @@ def first_if_tuple(x):
state = first_if_tuple(state).__array__()
next_state = first_if_tuple(next_state).__array__()

- state = torch.tensor(state, device=self.device)
- next_state = torch.tensor(next_state, device=self.device)
- action = torch.tensor([action], device=self.device)
- reward = torch.tensor([reward], device=self.device)
- done = torch.tensor([done], device=self.device)
+ state = torch.tensor(state)
+ next_state = torch.tensor(next_state)
+ action = torch.tensor([action])
+ reward = torch.tensor([reward])
+ done = torch.tensor([done])

# self.memory.append((state, next_state, action, reward, done,))
self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))
Expand All @@ -382,7 +382,7 @@ def recall(self):
"""
Retrieve a batch of experiences from memory
"""
- batch = self.memory.sample(self.batch_size)
+ batch = self.memory.sample(self.batch_size).to(self.device)
state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

Expand Down

0 comments on commit f4a5e4b

Please sign in to comment.