From 116c23feb686eea00e74a4756389d70352443ad3 Mon Sep 17 00:00:00 2001 From: Edouard Leurent Date: Tue, 12 Nov 2024 23:37:54 +0000 Subject: [PATCH] Fix https://github.com/Farama-Foundation/HighwayEnv/issues/631 --- highway_env/envs/common/abstract.py | 5 +++-- scripts/sb3_highway_dqn.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/highway_env/envs/common/abstract.py b/highway_env/envs/common/abstract.py index 5ef8ad697..cb2723fa0 100644 --- a/highway_env/envs/common/abstract.py +++ b/highway_env/envs/common/abstract.py @@ -325,6 +325,7 @@ def get_available_actions(self) -> list[int]: def set_record_video_wrapper(self, wrapper: RecordVideo): self._record_video_wrapper = wrapper self.update_metadata() + self._record_video_wrapper.frames_per_sec = self.metadata["render_fps"] def _automatic_rendering(self) -> None: """ @@ -334,8 +335,8 @@ def _automatic_rendering(self) -> None: If a RecordVideo wrapper has been set, use it to capture intermediate frames. """ if self.viewer is not None and self.enable_auto_render: - if self._record_video_wrapper and self._record_video_wrapper.video_recorder: - self._record_video_wrapper.video_recorder.capture_frame() + if self._record_video_wrapper: + self._record_video_wrapper._capture_frame() else: self.render() diff --git a/scripts/sb3_highway_dqn.py b/scripts/sb3_highway_dqn.py index 1cec4d290..05604d57e 100644 --- a/scripts/sb3_highway_dqn.py +++ b/scripts/sb3_highway_dqn.py @@ -40,8 +40,8 @@ env = RecordVideo( env, video_folder="highway_dqn/videos", episode_trigger=lambda e: True ) - env.unwrapped.set_record_video_wrapper(env) env.unwrapped.config["simulation_frequency"] = 15 # Higher FPS for rendering + env.unwrapped.set_record_video_wrapper(env) for videos in range(10): done = truncated = False