Skip to content

Commit fe663ab

Browse files
committed
codeformat fixes
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent a2e0a9e commit fe663ab

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

monai/transforms/spatial/array.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def __init__(
462462
self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode)
463463
self.align_corners = align_corners
464464
self.dtype = dtype
465-
self._rotation_matrix: Optional[np.ndarray] = None
465+
self._rotation_matrix: Optional[NdarrayOrTensor] = None
466466

467467
def __call__(
468468
self,
@@ -511,7 +511,7 @@ def __call__(
511511
corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape(
512512
(len(im_shape), -1)
513513
)
514-
corners = transform[:-1, :-1] @ corners
514+
corners = transform[:-1, :-1] @ corners # type: ignore
515515
output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int)
516516
shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist())
517517
transform = shift @ transform @ shift_1
@@ -532,7 +532,7 @@ def __call__(
532532
out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype)
533533
return out
534534

535-
def get_rotation_matrix(self) -> Optional[np.ndarray]:
535+
def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]:
536536
"""
537537
Get the most recently applied rotation matrix
538538
This is not thread-safe.

monai/transforms/utils.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -602,8 +602,8 @@ def create_rotate(
602602
return _create_rotate(
603603
spatial_dims=spatial_dims,
604604
radians=radians,
605-
sin_func=lambda th: torch.sin(torch.as_tensor(th)),
606-
cos_func=lambda th: torch.cos(torch.as_tensor(th)),
605+
sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32)),
606+
cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32)),
607607
array_func=torch.as_tensor,
608608
)
609609
raise ValueError("backend {} is not supported".format(backend))
@@ -676,7 +676,6 @@ def create_shear(
676676
677677
"""
678678
if look_up_option(backend, TransformBackends) == TransformBackends.NUMPY:
679-
array_func = np.array
680679
return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=np.array)
681680
if look_up_option(backend, TransformBackends) == TransformBackends.TORCH:
682681
return _create_shear(spatial_dims=spatial_dims, coefs=coefs, array_func=torch.as_tensor)
@@ -759,7 +758,7 @@ def _create_translate(
759758
affine = eye_func(spatial_dims + 1)
760759
for i, a in enumerate(shift[:spatial_dims]):
761760
affine[i, spatial_dims] = a
762-
return array_func(affine)
761+
return array_func(affine) # type: ignore
763762

764763

765764
def generate_spatial_bounding_box(

0 commit comments

Comments
 (0)