diff --git a/gym/decision_transformer/models/decision_transformer.py b/gym/decision_transformer/models/decision_transformer.py index d4985a56..f76f8f3b 100644 --- a/gym/decision_transformer/models/decision_transformer.py +++ b/gym/decision_transformer/models/decision_transformer.py @@ -90,7 +90,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_ x = transformer_outputs['last_hidden_state'] # reshape x so that the second dimension corresponds to the original - # returns (0), actions (1), or states (2); i.e. x[:,1,t] is the token for a_t + # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) # get predictions