Add DeDoDe (clean version) (#2835)
* Add dedode

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ducha-aiki and pre-commit-ci[bot] authored Mar 13, 2024
1 parent a14319c commit b6ca41c
Showing 27 changed files with 1,991 additions and 12 deletions.
3 changes: 3 additions & 0 deletions docs/source/feature.rst
@@ -58,6 +58,9 @@ Local Features (Detector and Descriptors together)
.. autoclass:: SOLD2_detector
:members: forward

.. autoclass:: DeDoDe
:members: forward, from_pretrained, describe, detect

.. autoclass:: DISK
:members: forward, from_pretrained, heatmap_and_dense_descriptors

20 changes: 14 additions & 6 deletions docs/source/references.bib
@@ -325,6 +325,13 @@ @article{tyszkiewicz2020disk
year={2020}
}

@inproceedings{edstedt2024dedode,
title={{DeDoDe: Detect, Don't Describe --- Describe, Don't Detect for Local Feature Matching}},
author = {Johan Edstedt and Georg Bökman and Mårten Wadenbäck and Michael Felsberg},
booktitle={2024 International Conference on 3D Vision (3DV)},
year={2024}
}

@inproceedings{he2010guided,
title = {Guided Image Filtering},
booktitle = {Proceedings of the 11th European Conference on Computer Vision: Part I},
@@ -368,12 +375,6 @@ @inproceedings{barath2020magsac++
year={2020}
}

@inproceedings{wei2023generalized,
author = {Wei, Tong and Patel, Yash and Shekhovtsov, Alexander and Matas, Jiri and Barath, Daniel},
title = {Generalized Differentiable RANSAC},
booktitle = {ICCV},
year = {2023}
}

@inproceedings{shin2017,
title={JPEG-resistant Adversarial Images},
@@ -390,3 +391,10 @@ @inproceedings{reich2024
booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
year={2024}
}

@inproceedings{wei2023generalized,
author = {Wei, Tong and Patel, Yash and Shekhovtsov, Alexander and Matas, Jiri and Barath, Daniel},
title = {Generalized Differentiable RANSAC},
booktitle = {ICCV},
year = {2023}
}
2 changes: 2 additions & 0 deletions kornia/feature/__init__.py
@@ -1,4 +1,5 @@
from .affine_shape import LAFAffineShapeEstimator, LAFAffNetShapeEstimator, PatchAffineShapeEstimator
from .dedode import DeDoDe
from .defmo import DeFMO
from .disk import DISK, DISKFeatures
from .hardnet import HardNet, HardNet8
@@ -156,6 +157,7 @@
"perspective_transform_lafs",
"SOLD2_detector",
"SOLD2",
"DeDoDe",
"DISK",
"DISKFeatures",
"LightGlue",
3 changes: 3 additions & 0 deletions kornia/feature/dedode/__init__.py
@@ -0,0 +1,3 @@
from .dedode import DeDoDe

__all__ = ["DeDoDe"]
99 changes: 99 additions & 0 deletions kornia/feature/dedode/decoder.py
@@ -0,0 +1,99 @@
from typing import Any, Optional, Tuple

import torch
from torch import nn

from kornia.core import Tensor


class Decoder(nn.Module):
def __init__(self, layers: Any, *args, super_resolution: bool = False, num_prototypes: int = 1, **kwargs) -> None: # type: ignore[no-untyped-def]
super().__init__(*args, **kwargs)
self.layers = layers
self.scales = self.layers.keys()
self.super_resolution = super_resolution
self.num_prototypes = num_prototypes

def forward(
self, features: Tensor, context: Optional[Tensor] = None, scale: Optional[int] = None
) -> Tuple[Tensor, Optional[Tensor]]:
if context is not None:
features = torch.cat((features, context), dim=1)
stuff = self.layers[scale](features)
logits, context = stuff[:, : self.num_prototypes], stuff[:, self.num_prototypes :]
return logits, context


class ConvRefiner(nn.Module):
def __init__( # type: ignore[no-untyped-def]
self,
in_dim=6,
hidden_dim=16,
out_dim=2,
dw=True,
kernel_size=5,
hidden_blocks=5,
amp=True,
residual=False,
amp_dtype=torch.float16,
):
super().__init__()
self.block1 = self.create_block(
in_dim,
hidden_dim,
dw=False,
kernel_size=1,
)
self.hidden_blocks = nn.Sequential(
*[
self.create_block(
hidden_dim,
hidden_dim,
dw=dw,
kernel_size=kernel_size,
)
for hb in range(hidden_blocks)
]
)
self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
self.amp = amp
self.amp_dtype = amp_dtype
self.residual = residual

def create_block( # type: ignore[no-untyped-def]
self,
in_dim,
out_dim,
dw=True,
kernel_size=5,
bias=True,
norm_type=nn.BatchNorm2d,
):
num_groups = 1 if not dw else in_dim
if dw:
if out_dim % in_dim != 0:
raise Exception("outdim must be divisible by indim for depthwise")
conv1 = nn.Conv2d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=num_groups,
bias=bias,
)
norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
relu = nn.ReLU(inplace=True)
conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
return nn.Sequential(conv1, norm, relu, conv2)

def forward(self, feats: Tensor) -> Tensor:
b, c, hs, ws = feats.shape
with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
x0 = self.block1(feats)
x = self.hidden_blocks(x0)
if self.residual:
x = (x + x0) / 1.4
x = self.out_conv(x)
return x
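The Decoder above simply routes features through one ConvRefiner per scale and splits the refiner output into keypoint logits and a context tensor that can be carried to the next scale. Below is a minimal sketch of how the two classes compose; the channel sizes, the "8" scale key, and amp=False are illustrative assumptions, not the configuration used by get_detector:

import torch
from torch import nn

from kornia.feature.dedode.decoder import ConvRefiner, Decoder

# Hypothetical single-scale setup: 512 feature channels, 1 keypoint prototype + 64 context channels.
layers = nn.ModuleDict({"8": ConvRefiner(in_dim=512, hidden_dim=256, out_dim=1 + 64, amp=False)})
decoder = Decoder(layers, num_prototypes=1)

feats = torch.randn(1, 512, 32, 32)  # stand-in for backbone features at stride 8
logits, context = decoder(feats, scale="8")
print(logits.shape, context.shape)  # torch.Size([1, 1, 32, 32]) torch.Size([1, 64, 32, 32])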
194 changes: 194 additions & 0 deletions kornia/feature/dedode/dedode.py
@@ -0,0 +1,194 @@
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.enhance.normalize import Normalize
from kornia.geometry.conversions import denormalize_pixel_coordinates
from kornia.utils.helpers import map_location_to_cpu

from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
from .utils import sample_keypoints

urls: Dict[str, Dict[str, str]] = {
"detector": {
"L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
"L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
"L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
},
"descriptor": {
"B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
"B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
"B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
"G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
"G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
},
}


class DeDoDe(Module):
r"""Module which detects and/or describes local features in an image using the DeDode method.
See :cite:`edstedt2024dedode` for details.
.. note:: DeDode takes ImageNet normalized images as input (not in range [0, 1]).
Example:
>>> dedode = DeDoDe.from_pretrained(detector_weights="L-upright", descriptor_weights="B-upright")
>>> images = torch.randn(1, 3, 256, 256)
>>> keypoints, scores = dedode.detect(images)
>>> descriptions = dedode.describe(images, keypoints = keypoints)
>>> keypoints, scores, features = dedode(images) # alternatively do both
"""

# TODO: implement steerers and mnn matchers
def __init__(
self, detector_model: str = "L", descriptor_model: str = "G", amp_dtype: torch.dtype = torch.float16
) -> None:
super().__init__()
self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))

def forward(
self,
images: Tensor,
n: Optional[int] = 10_000,
apply_imagenet_normalization: bool = True,
pad_if_not_divisible: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Detects and describes keypoints in the input images.
Args:
images: A tensor of shape :math:`(B, 3, H, W)` containing the ImageNet-Normalized input images.
n: The number of keypoints to detect.
apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
Returns:
keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in the image range,
unlike `.detect()` function
scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
DIM is 256 for B and 512 for G.
"""
if apply_imagenet_normalization:
images = self.normalizer(images)
B, C, H, W = images.shape
if pad_if_not_divisible:
h, w = images.shape[2:]
pd_h = 14 - h % 14 if h % 14 > 0 else 0
pd_w = 14 - w % 14 if w % 14 > 0 else 0
images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
keypoints, scores = self.detect(images, n=n, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
descriptions = self.describe(images, keypoints, apply_imagenet_normalization=False)
return denormalize_pixel_coordinates(keypoints, H, W), scores, descriptions

@torch.inference_mode()
def detect(
self,
images: Tensor,
n: Optional[int] = 10_000,
apply_imagenet_normalization: bool = True,
pad_if_not_divisible: bool = True,
crop_h: Optional[int] = None,
crop_w: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
"""Detects keypoints in the input images.
Args:
images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
n: The number of keypoints to detect.
apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
crop_h: The height of the crop to be used for detection. If None, the full image is used.
crop_w: The width of the crop to be used for detection. If None, the full image is used.
Returns:
keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
normalized to the range :math:`[-1, 1]`.
scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
"""
KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
self.train(False)
if pad_if_not_divisible:
h, w = images.shape[2:]
pd_h = 14 - h % 14 if h % 14 > 0 else 0
pd_w = 14 - w % 14 if w % 14 > 0 else 0
images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
if apply_imagenet_normalization:
images = self.normalizer(images)
B, C, H, W = images.shape
logits = self.detector.forward(images)
if crop_h is not None and crop_w is not None:
logits = logits[..., :crop_h, :crop_w]
H, W = crop_h, crop_w
scoremap = logits.reshape(B, H * W).softmax(dim=-1).reshape(B, H, W)
keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
return keypoints, confidence

@torch.inference_mode()
def describe(
self, images: Tensor, keypoints: Optional[Tensor] = None, apply_imagenet_normalization: bool = True
) -> Tensor:
"""Describes keypoints in the input images. If keypoints are not provided, returns the dense descriptors.
Args:
images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the detected keypoints.
apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
Returns:
descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
If the dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
"""
KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
B, C, H, W = images.shape
if keypoints is not None:
KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
if apply_imagenet_normalization:
images = self.normalizer(images)
self.train(False)
descriptions = self.descriptor.forward(images)
if keypoints is not None:
described_keypoints = F.grid_sample(
descriptions.float(), keypoints[:, None], mode="bilinear", align_corners=False
)[:, :, 0].mT
return described_keypoints
return descriptions

@classmethod
def from_pretrained(
cls,
detector_weights: str = "L-upright",
descriptor_weights: str = "G-upright",
amp_dtype: torch.dtype = torch.float16,
) -> Module:
r"""Loads a pretrained model.
Depth model was trained using depth map supervision and is slightly more precise but biased to detect keypoints
only where SfM depth is available. Epipolar model was trained using epipolar geometry supervision and
is less precise but detects keypoints everywhere where they are matchable. The difference is especially
pronounced on thin structures and on edges of objects.
Args:
detector_weights: The weights to load for the detector. One of 'L-upright', 'L-C4', 'L-SO2'.
descriptor_weights: The weights to load for the descriptor.
One of 'B-upright', 'B-C4', 'B-SO2', 'G-upright', 'G-C4'.
checkpoint: The checkpoint to load. One of 'depth' or 'epipolar'.
amp_dtype: the dtype to use for the model. One of torch.float16 or torch.float32.
Default is torch.float16, suitable for CUDA. Use torch.float32 for CPU or MPS
Returns:
The pretrained model.
"""
model: DeDoDe = cls(
detector_model=detector_weights[0], descriptor_model=descriptor_weights[0], amp_dtype=amp_dtype
)
model.detector.load_state_dict(
torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=map_location_to_cpu)
)
model.descriptor.load_state_dict(
torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=map_location_to_cpu)
)
model.eval()
return model
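Taken together, from_pretrained, detect, describe and forward give the workflow sketched below. This is a usage sketch rather than part of the commit: it assumes the weight URLs above are reachable, picks arbitrary image and keypoint counts, and pairs the descriptors with kornia's existing match_mnn mutual-nearest-neighbour matcher, since a DeDoDe-specific matcher is still on the TODO list above.

import torch
import kornia.feature as KF

# float32 keeps the autocast path CPU-friendly, per the from_pretrained docstring.
dedode = KF.DeDoDe.from_pretrained(
    detector_weights="L-upright", descriptor_weights="B-upright", amp_dtype=torch.float32
)

images = torch.rand(2, 3, 256, 256)  # [0, 1] images; ImageNet normalization is applied internally by default
keypoints, scores, descriptions = dedode(images, n=512)  # pixel coordinates; shapes (2, 512, 2), (2, 512), (2, 512, 256)

# Mutual-nearest-neighbour matching between the two images with kornia's generic matcher.
dists, idxs = KF.match_mnn(descriptions[0], descriptions[1])
mkpts0, mkpts1 = keypoints[0][idxs[:, 0]], keypoints[1][idxs[:, 1]]
print(mkpts0.shape, mkpts1.shape)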