Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style: uniformize imports #356

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/demo_agents/demo_SAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time

import gymnasium as gym

from rlberry.agents.torch.sac import SACAgent
from rlberry.envs import Pendulum
from rlberry.manager import AgentManager
Expand Down
4 changes: 2 additions & 2 deletions examples/demo_agents/video_plot_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_a2c.jpg'

from rlberry.agents.torch import A2CAgent
from rlberry.envs.benchmarks.ball_exploration import PBall2D
from gymnasium.wrappers import TimeLimit

from rlberry.agents.torch import A2CAgent
from rlberry.envs.benchmarks.ball_exploration import PBall2D

env = PBall2D()
env = TimeLimit(env, max_episode_steps=256)
Expand Down
11 changes: 5 additions & 6 deletions examples/demo_agents/video_plot_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@

# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_dqn.jpg'

from rlberry.envs import gym_make
import os
import shutil

from gymnasium.wrappers.record_video import RecordVideo
from torch.utils.tensorboard import SummaryWriter

from rlberry.agents.torch.dqn import DQNAgent
from rlberry.envs import gym_make
from rlberry.utils.logging import configure_logging

from gymnasium.wrappers.record_video import RecordVideo
import shutil
import os


configure_logging(level="INFO")

env = gym_make("CartPole-v1", render_mode="rgb_array")
Expand Down
11 changes: 5 additions & 6 deletions examples/demo_agents/video_plot_mdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@

# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_dqn.jpg'

from rlberry.envs import gym_make
import os
import shutil

from gymnasium.wrappers.record_video import RecordVideo
from torch.utils.tensorboard import SummaryWriter

from rlberry.agents.torch.dqn import MunchausenDQNAgent
from rlberry.envs import gym_make
from rlberry.utils.logging import configure_logging

from gymnasium.wrappers.record_video import RecordVideo
import shutil
import os


configure_logging(level="INFO")

env = gym_make("CartPole-v1", render_mode="rgb_array")
Expand Down
1 change: 0 additions & 1 deletion examples/demo_agents/video_plot_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from rlberry.agents.torch import PPOAgent
from rlberry.envs.benchmarks.ball_exploration import PBall2D


env = PBall2D()
n_steps = 3e3

Expand Down
2 changes: 1 addition & 1 deletion examples/demo_agents/video_plot_rs_kernel_ucbvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rs_kernel_ucbvi.jpg'

from rlberry.envs import Acrobot
from rlberry.agents import RSKernelUCBVIAgent
from rlberry.envs import Acrobot
from rlberry.wrappers import RescaleRewardWrapper

env = Acrobot()
Expand Down
8 changes: 4 additions & 4 deletions examples/demo_bandits/plot_TS_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
"""

import numpy as np
from rlberry.envs.bandits import BernoulliBandit, NormalBandit

from rlberry.agents.bandits import (
IndexAgent,
TSAgent,
makeBoundedUCBIndex,
makeSubgaussianUCBIndex,
makeBetaPrior,
makeBoundedUCBIndex,
makeGaussianPrior,
makeSubgaussianUCBIndex,
)
from rlberry.envs.bandits import BernoulliBandit, NormalBandit
from rlberry.manager import ExperimentManager, plot_writer_data
from rlberry.wrappers import WriterWrapper


# Bernoulli

# Agents definition
Expand Down
9 changes: 5 additions & 4 deletions examples/demo_bandits/plot_compare_index_bandits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
This script compares several bandit agents and, as a by-product, also shows
how to use subplots with `plot_writer_data`
"""
import numpy as np
import matplotlib.pyplot as plt
from rlberry.envs.bandits import BernoulliBandit
from rlberry.manager import ExperimentManager, plot_writer_data
from rlberry.wrappers import WriterWrapper
import numpy as np

from rlberry.agents.bandits import (
IndexAgent,
RandomizedAgent,
Expand All @@ -22,6 +20,9 @@
makeETCIndex,
makeEXP3Index,
)
from rlberry.envs.bandits import BernoulliBandit
from rlberry.manager import ExperimentManager, plot_writer_data
from rlberry.wrappers import WriterWrapper

# Agents definition
# sphinx_gallery_thumbnail_number = 2
Expand Down
6 changes: 3 additions & 3 deletions examples/demo_bandits/plot_exp3_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
"""

import numpy as np
from rlberry.envs.bandits import AdversarialBandit

from rlberry.agents.bandits import (
RandomizedAgent,
TSAgent,
makeEXP3Index,
makeBetaPrior,
makeEXP3Index,
)
from rlberry.envs.bandits import AdversarialBandit
from rlberry.manager import ExperimentManager, plot_writer_data
from rlberry.wrappers import WriterWrapper


# Agents definition


Expand Down
15 changes: 6 additions & 9 deletions examples/demo_bandits/plot_mirror_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
The code is in three parts: definition of environment, definition of agent,
and finally definition of the experiment.
"""
import matplotlib.pyplot as plt
import numpy as np

from rlberry.manager import ExperimentManager, read_writer_data
from rlberry.envs.interface import Model
from rlberry.agents.bandits import BanditWithSimplePolicy
from rlberry.wrappers import WriterWrapper
import rlberry.spaces as spaces

import requests
import matplotlib.pyplot as plt


import rlberry
import rlberry.spaces as spaces
from rlberry.agents.bandits import BanditWithSimplePolicy
from rlberry.envs.interface import Model
from rlberry.manager import ExperimentManager, read_writer_data
from rlberry.wrappers import WriterWrapper

logger = rlberry.logger

Expand Down
6 changes: 3 additions & 3 deletions examples/demo_bandits/plot_ucb_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
This script shows how to define a bandit environment and a UCB index-based algorithm.
"""

import matplotlib.pyplot as plt
import numpy as np
from rlberry.envs.bandits import NormalBandit

from rlberry.agents.bandits import IndexAgent, makeSubgaussianUCBIndex
from rlberry.envs.bandits import NormalBandit
from rlberry.manager import ExperimentManager, plot_writer_data
import matplotlib.pyplot as plt
from rlberry.wrappers import WriterWrapper


# Agents definition


Expand Down
13 changes: 7 additions & 6 deletions examples/demo_env/example_atari_atlantis_vectorized_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/example_plot_atari_atlantis_vectorized_ppo.jpg'


from rlberry.manager import ExperimentManager
import os
import shutil
from datetime import datetime
from rlberry.agents.torch import PPOAgent

from gymnasium.wrappers.record_video import RecordVideo
import shutil
import os
from rlberry.envs.gym_make import atari_make
from rlberry.agents.torch.utils.training import model_factory_from_env

from rlberry.agents.torch import PPOAgent
from rlberry.agents.torch.utils.training import model_factory_from_env
from rlberry.envs.gym_make import atari_make
from rlberry.manager import ExperimentManager

initial_time = datetime.now()
print("-------- init agent --------")
Expand Down
13 changes: 7 additions & 6 deletions examples/demo_env/example_atari_breakout_vectorized_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/example_plot_atari_breakout_vectorized_ppo.jpg'


from rlberry.manager import ExperimentManager
import os
import shutil
from datetime import datetime
from rlberry.agents.torch import PPOAgent

from gymnasium.wrappers.record_video import RecordVideo
import shutil
import os
from rlberry.envs.gym_make import atari_make
from rlberry.agents.torch.utils.training import model_factory_from_env

from rlberry.agents.torch import PPOAgent
from rlberry.agents.torch.utils.training import model_factory_from_env
from rlberry.envs.gym_make import atari_make
from rlberry.manager import ExperimentManager

initial_time = datetime.now()
print("-------- init agent --------")
Expand Down
2 changes: 1 addition & 1 deletion examples/demo_env/video_plot_acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_acrobot.jpg'

from rlberry.envs import Acrobot
from rlberry.agents import RSUCBVIAgent
from rlberry.envs import Acrobot
from rlberry.wrappers import RescaleRewardWrapper

env = Acrobot()
Expand Down
3 changes: 2 additions & 1 deletion examples/demo_env/video_plot_apple_gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
:width: 600

"""
from rlberry.agents.dynprog import ValueIterationAgent

# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_apple_gold.jpg'
from rlberry.envs.benchmarks.grid_exploration.apple_gold import AppleGold
from rlberry.agents.dynprog import ValueIterationAgent

env = AppleGold(reward_free=False, array_observation=False)

Expand Down
11 changes: 6 additions & 5 deletions examples/demo_env/video_plot_atari_freeway.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_atari_freeway.jpg'


from rlberry.manager import ExperimentManager
import os
import shutil
from datetime import datetime
from rlberry.agents.torch.dqn.dqn import DQNAgent

from gymnasium.wrappers.record_video import RecordVideo
import shutil
import os
from rlberry.envs.gym_make import atari_make

from rlberry.agents.torch.dqn.dqn import DQNAgent
from rlberry.envs.gym_make import atari_make
from rlberry.manager import ExperimentManager

initial_time = datetime.now()
print("-------- init agent --------")
Expand Down
1 change: 0 additions & 1 deletion examples/demo_env/video_plot_gridworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from rlberry.agents.dynprog import ValueIterationAgent
from rlberry.envs.finite import GridWorld


env = GridWorld(7, 10, walls=((2, 2), (3, 3)))

agent = ValueIterationAgent(env, gamma=0.95)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_old_gym_acrobot.jpg'


from rlberry.wrappers.tests.old_env.old_acrobot import Old_Acrobot
from rlberry.agents import RSUCBVIAgent
from rlberry.wrappers import RescaleRewardWrapper
from rlberry.wrappers.gym_utils import OldGymCompatibilityWrapper
from rlberry.wrappers.tests.old_env.old_acrobot import Old_Acrobot

env = Old_Acrobot()
env = OldGymCompatibilityWrapper(env)
Expand Down
1 change: 1 addition & 0 deletions examples/demo_env/video_plot_pball.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_pball.jpg'

import numpy as np

from rlberry.envs.benchmarks.ball_exploration import PBall2D

p = 5
Expand Down
2 changes: 1 addition & 1 deletion examples/demo_env/video_plot_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rooms.jpg'

from rlberry.envs.benchmarks.grid_exploration.nroom import NRoom
from rlberry.agents.dynprog import ValueIterationAgent
from rlberry.envs.benchmarks.grid_exploration.nroom import NRoom

env = NRoom(
nrooms=9,
Expand Down
5 changes: 3 additions & 2 deletions examples/demo_env/video_plot_springcartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_springcartpole.jpg'

from rlberry.envs.classic_control import SpringCartPole
from rlberry.agents.torch import DQNAgent
from gymnasium.wrappers.time_limit import TimeLimit

from rlberry.agents.torch import DQNAgent
from rlberry.envs.classic_control import SpringCartPole

model_configs = {
"type": "MultiLayerPerceptron",
"layer_sizes": (256, 256),
Expand Down
4 changes: 2 additions & 2 deletions examples/demo_env/video_plot_twinrooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_twinrooms.jpg'

from rlberry.envs.benchmarks.generalization.twinrooms import TwinRooms
from rlberry.agents.mbqvi import MBQVIAgent
from rlberry.wrappers.discretize_state import DiscretizeStateWrapper
from rlberry.envs.benchmarks.generalization.twinrooms import TwinRooms
from rlberry.seeding import Seeder
from rlberry.wrappers.discretize_state import DiscretizeStateWrapper

seeder = Seeder(123)

Expand Down
4 changes: 1 addition & 3 deletions examples/demo_experiment/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
$ python examples/demo_examples/demo_experiment/run.py
"""

from rlberry.experiment import load_experiment_results
from rlberry.experiment import experiment_generator
from rlberry.experiment import experiment_generator, load_experiment_results
from rlberry.manager.multiple_managers import MultipleManagers


if __name__ == "__main__":
multimanagers = MultipleManagers(parallelization="thread")

Expand Down
6 changes: 3 additions & 3 deletions examples/demo_network/run_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
Demo: run_client
=====================
"""
from rlberry.network.client import BerryClient
from rlberry.network import interface
from rlberry.network.interface import Message, ResourceRequest
import numpy as np

from rlberry.network import interface
from rlberry.network.client import BerryClient
from rlberry.network.interface import Message, ResourceRequest

port = int(input("Select server port: "))
client = BerryClient(port=port)
Expand Down
9 changes: 3 additions & 6 deletions examples/demo_network/run_remote_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
Demo: run_remote_manager
=====================
"""
from rlberry.envs.gym_make import gym_make
from rlberry.network.client import BerryClient
from rlberry.network.interface import ResourceRequest

from rlberry.agents.torch import REINFORCEAgent

from rlberry.envs.gym_make import gym_make
from rlberry.manager import ExperimentManager, MultipleManagers, RemoteExperimentManager
from rlberry.manager.evaluation import evaluate_agents, plot_writer_data

from rlberry.network.client import BerryClient
from rlberry.network.interface import ResourceRequest

if __name__ == "__main__":
port = int(input("Select server port: "))
Expand Down
Loading