diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index c5203f4c7..be1040a95 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -28,6 +28,10 @@ def is_colab() -> bool: return "google.colab" in sys.modules +def is_notebook() -> bool: + return "ipykernel" in sys.modules + + # In a notebook, force the Matplotlib backend to ngAgg in order for figures to update # every time render is called for environments that use Matplotlib # for rendering. Without this, only the last render result is shown per figure. @@ -39,8 +43,10 @@ def is_colab() -> bool: if is_colab(): backend = "inline" - else: + elif is_notebook(): backend = "notebook" + else: + backend = "" IPython.get_ipython().run_line_magic("matplotlib", backend) except ImportError as exc: