Skip to content

Commit

Permalink
Fix axes indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Mar 10, 2023
1 parent e5b73fc commit 8bdad76
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torchgeo/datasets/ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
fig, axs = plt.subplots(ncols=self.seasons, figsize=(4, 4))
fig, axes = plt.subplots(ncols=self.seasons, figsize=(4, 4))
if self.seasons == 1:
axes = [axes]

for i in range(self.seasons):
image = sample["image"][i * len(self.bands) : (i + 1) * len(self.bands)]
Expand Down Expand Up @@ -239,8 +241,8 @@ def plot(
image = image[[3, 2, 1]].permute(1, 2, 0)
image = torch.clamp(image / 10000, min=0, max=1)

axs[i].imshow(image)
axs[i].axis("off")
axes[i].imshow(image)
axes[i].axis("off")

if suptitle is not None:
plt.suptitle(suptitle)
Expand Down

0 comments on commit 8bdad76

Please sign in to comment.