From ea0a11cadcb0f7a077b2789d8763910c0773a304 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 13 Jun 2023 16:01:29 +0100 Subject: [PATCH] init (#2464) --- intermediate_source/mario_rl_tutorial.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/intermediate_source/mario_rl_tutorial.py b/intermediate_source/mario_rl_tutorial.py index eb46feb2ad..e4bfd86916 100755 --- a/intermediate_source/mario_rl_tutorial.py +++ b/intermediate_source/mario_rl_tutorial.py @@ -350,7 +350,7 @@ def act(self, state): class Mario(Mario): # subclassing for continuity def __init__(self, state_dim, action_dim, save_dir): super().__init__(state_dim, action_dim, save_dir) - self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000)) + self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device("cpu"))) self.batch_size = 32 def cache(self, state, next_state, action, reward, done): @@ -369,11 +369,11 @@ def first_if_tuple(x): state = first_if_tuple(state).__array__() next_state = first_if_tuple(next_state).__array__() - state = torch.tensor(state, device=self.device) - next_state = torch.tensor(next_state, device=self.device) - action = torch.tensor([action], device=self.device) - reward = torch.tensor([reward], device=self.device) - done = torch.tensor([done], device=self.device) + state = torch.tensor(state) + next_state = torch.tensor(next_state) + action = torch.tensor([action]) + reward = torch.tensor([reward]) + done = torch.tensor([done]) # self.memory.append((state, next_state, action, reward, done,)) self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[])) @@ -382,7 +382,7 @@ def recall(self): """ Retrieve a batch of experiences from memory """ - batch = self.memory.sample(self.batch_size) + batch = self.memory.sample(self.batch_size).to(self.device) state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done")) return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()