diff --git a/train.py b/train.py index 789a9ab..df2b889 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,11 @@ +import os + +# Set flags before imports +os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' +os.environ['WANDB_DIR'] = '/tmp' + import argparse import functools -import os import time from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -14,7 +19,7 @@ import flax.linen as nn from stable_baselines3.common import type_aliases -from stable_baselines3.common.callbacks import EvalCallback, CallbackList, BaseCallback +from stable_baselines3.common.callbacks import CallbackList, BaseCallback from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped, sync_envs_normalization from sbx import SAC @@ -24,10 +29,6 @@ import gymnasium as gym from shimmy.registration import DM_CONTROL_SUITE_ENVS - -os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' -os.environ['WANDB_DIR'] = '/tmp' - parser = argparse.ArgumentParser() parser.add_argument("-env", type=str, required=False, default="HumanoidStandup-v4", help="Set Environment.") parser.add_argument("-algo", type=str, required=True, default='sac', choices=['crossq', 'sac', 'redq', 'droq', 'td3'], help="algorithm to use (essentially a named hyperparameter set for the base SAC algorithm)")