Skip to content

Commit

Permalink
Ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
robmarkcole committed Jan 2, 2025
1 parent 59ba3c8 commit 1184647
Showing 1 changed file with 3 additions and 21 deletions.
24 changes: 3 additions & 21 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,7 @@ def training_step(
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log(
'train_loss',
loss,
batch_size=batch_size,
on_step=True,
on_epoch=True,
)
self.log('train_loss', loss, batch_size=batch_size, on_step=True, on_epoch=True)
self.train_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.train_metrics.compute().items()},
Expand All @@ -313,13 +307,7 @@ def validation_step(
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log(
'val_loss',
loss,
batch_size=batch_size,
on_step=False,
on_epoch=True,
)
self.log('val_loss', loss, batch_size=batch_size, on_step=False, on_epoch=True)
self.val_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.val_metrics.compute().items()},
Expand Down Expand Up @@ -368,13 +356,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log(
'test_loss',
loss,
batch_size=batch_size,
on_step=False,
on_epoch=True,
)
self.log('test_loss', loss, batch_size=batch_size, on_step=False, on_epoch=True)
self.test_metrics(y_hat, y)
self.log_dict(
{f'{k}': v for k, v in self.test_metrics.compute().items()},
Expand Down

0 comments on commit 1184647

Please sign in to comment.