Skip to content

Commit

Permalink
hotfix tensor dimensionality assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-donike committed May 16, 2024
1 parent 3c7970d commit bfb4dbe
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ old/*
ablation.py
images/
images/*
images_comparison
images_comparison/*
comp_s2_s2naip.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
1 change: 0 additions & 1 deletion opensr_model/srmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def _tensor_encode(self,X: torch.Tensor):
if self.encode_conditioning==True and self.sr_type=="SISR":
# try to upsample->encode conditioning
X_int = torch.nn.functional.interpolate(X, size=(X.shape[-1]*4,X.shape[-1]*4), mode='bilinear', align_corners=False)
print(X_int.shape)
# encode conditioning
X_enc = self.model.first_stage_model.encode(X_int).sample()
# move to same device as the model
Expand Down
13 changes: 6 additions & 7 deletions opensr_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def linear_transform_6b(t_input,stage="norm"):

return t_output

def assert_tensor_validity(self,tensor):
def assert_tensor_validity(tensor):

# ASSERT BATCH DIMENSION
# if unbatched, add batch dimension
Expand All @@ -142,30 +142,29 @@ def assert_tensor_validity(self,tensor):
# Padding for height and width needs to be added to both sides of the dimension
# The pad has the format (left, right, top, bottom)
padding = (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2)
self.padding = padding
padding = padding

# Apply symmetric padding
tensor = torch.nn.functional.pad(tensor, padding, mode='reflect')

else: # save padding with 0s
padding = (0,0,0,0)
self.padding = padding
padding = padding

return tensor,padding



def revert_padding(self,tensor,padding):
def revert_padding(tensor,padding):
left, right, top, bottom = padding
# account for 4x upsampling Factor
left, right, top, bottom = left*4, right*4, top*4, bottom*4
# Calculate the indices to slice from the padded tensor
start_height = top
end_height = tensor.size(-2) - bottom
start_width = left
end_width = tensor.size(-1) - right

# account for 4x upsampling Factor
start_height,end_height,start_width,end_width = start_height*4,end_height*4,start_width*4,end_width*4

# Slice the tensor to remove padding
unpadded_tensor = tensor[:,:, start_height:end_height, start_width:end_width]
return unpadded_tensor
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='opensr-model',
version='0.2.11',
version='0.2.12',
author = "Simon Donike, Cesar Aybar, Luis Gomez Chova, Freddie Kalaitzis",
author_email = "accounts@donike.net",
description = "ESA OpenSR Diffusion model package for Super-Resolution of Senintel-2 Imagery",
Expand Down

0 comments on commit bfb4dbe

Please sign in to comment.