Skip to content

Commit

Permalink
NumPy 2 support (#2151)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jul 10, 2024
1 parent 61635cd commit d27ee30
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ quote-style = "single"
skip-magic-trailing-comma = true

[tool.ruff.lint]
extend-select = ["D", "I", "UP"]
extend-select = ["D", "I", "NPY201", "UP"]

[tool.ruff.lint.per-file-ignores]
"docs/**" = ["D"]
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _load_target(self, filename: str) -> Tensor:
target mask
"""
with rasterio.open(os.path.join(self.root, 'train_agbm', filename), 'r') as src:
arr: np.typing.NDArray[np.float_] = src.read()
arr: np.typing.NDArray[np.float64] = src.read()

target = torch.from_numpy(arr).float()
return target
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/dfc2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor:
the image
"""
with rasterio.open(path) as f:
array: np.typing.NDArray[np.float_] = f.read(
array: np.typing.NDArray[np.float64] = f.read(
out_shape=shape, out_dtype='float32', resampling=Resampling.bilinear
)
tensor = torch.from_numpy(array)
Expand Down
6 changes: 4 additions & 2 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,12 @@ def plot(

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4))

def apply_cmap(arr: 'np.typing.NDArray[Any]') -> 'np.typing.NDArray[np.float_]':
def apply_cmap(
arr: 'np.typing.NDArray[Any]',
) -> 'np.typing.NDArray[np.float64]':
# Color 0 as black, while applying default color map for the class indices.
cmap = plt.get_cmap('viridis')
im: np.typing.NDArray[np.float_] = cmap(arr / len(self.class_map))
im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map))
im[arr == 0] = 0
return im

Expand Down

0 comments on commit d27ee30

Please sign in to comment.