diff --git a/beginner_source/nn_tutorial.py b/beginner_source/nn_tutorial.py index bc32131b93..7ee7df3b43 100644 --- a/beginner_source/nn_tutorial.py +++ b/beginner_source/nn_tutorial.py @@ -75,6 +75,11 @@ import numpy as np pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray") +# ``pyplot.show()`` only if not on Colab +try: + import google.colab +except ImportError: + pyplot.show() print(x_train.shape) ###############################################################################