Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 9, 2024
1 parent 067b560 commit c010e39
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_module = make_loss_module(cfg, model)

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy.eval())
collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)

# Create replay buffer
replay_buffer = make_replay_buffer(
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/crossq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def make_environment(cfg):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(cfg, train_env, actor_model_explore, device):
"""Make collector."""
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down

0 comments on commit c010e39

Please sign in to comment.