forked from danijar/dreamerv3
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexample.py
60 lines (53 loc) · 2.13 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def main():
    """Train a DreamerV3 agent on Crafter via the Gym compatibility wrapper.

    Builds the configuration from the 'defaults' and 'medium' presets plus
    a set of local overrides, lets ``embodied.Flags`` parse further
    settings, wires up the logger, the wrapped environment, the agent and
    the replay buffer, and finally hands everything to
    ``embodied.run.train``.
    """
    import warnings
    import dreamerv3
    from dreamerv3 import embodied

    # Installed after importing dreamerv3, so only later warnings are muted.
    warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

    # Local overrides applied on top of the presets.
    # See configs.yaml for all options.
    overrides = {
        'logdir': '~/logdir/run1',
        'run.train_ratio': 64,
        'run.log_every': 30,  # Seconds
        'batch_size': 16,
        'jax.prealloc': False,
        # NOTE(review): '$^' looks like a pattern meant to match no key, so
        # that only the image input is used -- confirm against configs.yaml.
        'encoder.mlp_keys': '$^',
        'decoder.mlp_keys': '$^',
        'encoder.cnn_keys': 'image',
        'decoder.cnn_keys': 'image',
        'jax.platform': 'cpu',
    }
    config = embodied.Config(dreamerv3.configs['defaults'])
    config = config.update(dreamerv3.configs['medium'])
    config = config.update(overrides)
    config = embodied.Flags(config).parse()

    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()

    # Active logging sinks. Optional alternatives, disabled by default:
    #   embodied.logger.WandBOutput(wandb_init_kwargs={
    #       'project': 'dreamerv3-compat',
    #       'name': logdir.name,
    #       'config': dict(config),
    #   })
    #   embodied.logger.MLFlowOutput(logdir.name)
    sinks = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
        embodied.logger.TensorBoardOutput(logdir),
    ]
    logger = embodied.Logger(step, sinks)

    import crafter
    # Alternative import when using gymnasium instead of gym:
    #   from gymnasium.wrappers.compatibility import EnvCompatibility
    from gym.wrappers.compatibility import EnvCompatibility
    from dreamerv3.embodied.envs import from_gym

    # Replace this with your Gym env.
    base_env = crafter.Env()
    # Apply EnvCompatibility wrapper because crafter is still at the
    # gym==0.19.0 API.
    compat_env = EnvCompatibility(base_env, render_mode='rgb_array')
    # Or obs_key='vector'.
    embodied_env = from_gym.FromGym(compat_env, obs_key='image')
    wrapped_env = dreamerv3.wrap_env(embodied_env, config)
    env = embodied.BatchEnv([wrapped_env], parallel=False)

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    replay = embodied.replay.Uniform(
        config.batch_length,
        config.replay_size,
        logdir / 'replay',
    )

    batch_steps = config.batch_size * config.batch_length
    args = embodied.Config(
        **config.run,
        logdir=config.logdir,
        batch_steps=batch_steps,
    )  # type: ignore

    embodied.run.train(agent, env, replay, logger, args)
    # To evaluate instead of train:
    #   embodied.run.eval_only(agent, env, logger, args)
# Run the training example only when this file is executed as a script.
if __name__ == '__main__':
    main()