diff --git a/reinforcement-learning/breakout/DQN.py b/reinforcement-learning/breakout/DQN.py index fff4ff0..ee5c6b1 100644 --- a/reinforcement-learning/breakout/DQN.py +++ b/reinforcement-learning/breakout/DQN.py @@ -275,7 +275,7 @@ def preprocess(self, state): return state - def observe(self, environment, states, *args): # noqa + def observe(self, environment, states, skip=None): """ Observe the environment for n frames. @@ -285,7 +285,7 @@ def observe(self, environment, states, *args): # noqa The environment to observe. states : torch.Tensor The states of the environment from the previous step. - args + skip : int, optional To be compatible with the other DQN agents. Added here instead of using ABC. Returns @@ -299,6 +299,8 @@ def observe(self, environment, states, *args): # noqa done : bool Whether the game is terminated. """ + print("Warning: `skip` is not used in `VisionDeepQ.observe`.") if skip is not None else None + action = self.action(states) done = False diff --git a/reinforcement-learning/utilities/visualisation/movie.py b/reinforcement-learning/utilities/visualisation/movie.py index ae05770..9e45626 100644 --- a/reinforcement-learning/utilities/visualisation/movie.py +++ b/reinforcement-learning/utilities/visualisation/movie.py @@ -1,8 +1,10 @@ +"""Create a movie of an agent interacting with an environment.""" + import cv2 import torch -def create_movie(environment, agent, path, fps=60): +def create_movie(environment, agent, path="./live-preview.gif", skip=4, fps=50): """Created by Mistral Large.""" initial = agent.preprocess(environment.reset()[0]) try: @@ -10,20 +12,14 @@ def create_movie(environment, agent, path, fps=60): except AttributeError: states = initial - try: - done = False + done = False - # Get the dimensions of the first image - height, width, channels = environment.render().shape + height, width, _ = environment.render().shape + fourcc = cv2.VideoWriter_fourcc(*"MJPG") # noqa + movie = cv2.VideoWriter(path, fourcc, fps, (width, height)) - # Create the VideoWriter object - fourcc = cv2.VideoWriter_fourcc(*"MJPG") # You can change the codec if needed - video_writer = cv2.VideoWriter(path, fourcc, fps, (width, height)) - while not done: - _, states, _, done = agent.observe(environment, states) - video_writer.write(environment.render()) - except Exception as e: - print(f"Error during image generation or writing: {e}") - return + while not done: + _, states, _, done = agent.observe(environment, states, skip) + movie.write(environment.render()) cv2.destroyAllWindows() diff --git a/requirements.txt b/requirements.txt index 85038d8..c382e43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,12 @@ torch numpy pandas scipy + +# For visualisation +# *-----------------------------------------* matplotlib imageio +opencv-python # Library for utilising Apple M chips # *-----------------------------------------*