Skip to content

Commit

Permalink
segresnet_ds lower peak GPU mem (#7066)
Browse files Browse the repository at this point in the history
Reduces peak GPU mem usage of segresnet_ds(), by releasing buffers
earlier.

Signed-off-by: myron <amyronenko@nvidia.com>
  • Loading branch information
myron authored Sep 29, 2023
1 parent 14fcf72 commit 317ef1f
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions monai/networks/nets/segresnet_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ def __init__(

def forward(self, x):
identity = x
x = self.conv1(self.act1(self.norm1(x)))
x = self.conv2(self.act2(self.norm2(x)))
x = self.conv2(self.act2(self.norm2(self.conv1(self.act1(self.norm1(x))))))
x += identity
return x

Expand Down Expand Up @@ -408,7 +407,7 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens
i = 0
for level in self.up_layers:
x = level["upsample"](x)
x = x + x_down[i]
x += x_down.pop(0)
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
Expand Down

0 comments on commit 317ef1f

Please sign in to comment.