diff --git a/openrl/modules/networks/MAT_network.py b/openrl/modules/networks/MAT_network.py index e1f79e41..48958c28 100644 --- a/openrl/modules/networks/MAT_network.py +++ b/openrl/modules/networks/MAT_network.py @@ -420,7 +420,7 @@ def get_actions( ): obs = obs.reshape(-1, self.n_agent, self.obs_dim) if action_masks is not None: - action_masks = action_masks.reshape(-1, self.num_agents, self.action_dim) + action_masks = action_masks.reshape(-1, self.n_agent, self.action_dim) # state unused ori_shape = np.shape(obs)