Skip to content

Commit

Permalink
Cast uint16 tensors to int32, leave float32 tensors as they are (micr…
Browse files Browse the repository at this point in the history
  • Loading branch information
khdlr authored Apr 11, 2022
1 parent 110c919 commit ba820e3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchgeo/datasets/sen12ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def _load_raster(self, filename: str, source: str) -> Tensor:
"{}_{}_{}_{}_{}".format(*parts),
)
) as f:
array = f.read().astype(np.int32)
array = f.read()
if array.dtype == np.uint16:
array = array.astype(np.int32)
tensor = torch.from_numpy(array)
return tensor

Expand Down

0 comments on commit ba820e3

Please sign in to comment.