Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JAX allocating too much memory and remove unnecessary import. #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os

# Set flags before imports
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['WANDB_DIR'] = '/tmp'

import argparse
import functools
import os
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
Expand All @@ -14,7 +19,7 @@
import flax.linen as nn

from stable_baselines3.common import type_aliases
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback
from stable_baselines3.common.callbacks import CallbackList, BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped, sync_envs_normalization
from sbx import SAC
Expand All @@ -24,10 +29,6 @@
import gymnasium as gym
from shimmy.registration import DM_CONTROL_SUITE_ENVS


os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['WANDB_DIR'] = '/tmp'

parser = argparse.ArgumentParser()
parser.add_argument("-env", type=str, required=False, default="HumanoidStandup-v4", help="Set Environment.")
parser.add_argument("-algo", type=str, required=True, default='sac', choices=['crossq', 'sac', 'redq', 'droq', 'td3'], help="algorithm to use (essentially a named hyperparameter set for the base SAC algorithm)")
Expand Down