Skip to content

Commit

Permalink
Update detection XAI algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
GalyaZalesskaya committed Nov 8, 2023
1 parent a5d12b7 commit 3fc39a2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,9 @@ def func(
else:
cls_scores = self._get_cls_scores_from_feature_map(feature_map)

# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
# it would highlight one of the class maps as a background class
if self.use_cls_softmax and self._num_cls_out_channels > 1:
cls_scores = [torch.softmax(t, dim=1) for t in cls_scores]

batch_size, _, height, width = cls_scores[-1].size()
middle_idx = len(cls_scores) // 2
# resize to the middle feature map
batch_size, _, height, width = cls_scores[middle_idx].size()
saliency_maps = torch.empty(batch_size, self._num_cls_out_channels, height, width)
for batch_idx in range(batch_size):
cls_scores_anchorless = []
Expand All @@ -82,6 +79,11 @@ def func(
)
saliency_maps[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0)

# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
# it would highlight one of the class maps as a background class
if self.use_cls_softmax:
saliency_maps[0] = torch.stack([torch.softmax(t, dim=1) for t in saliency_maps[0]])

if self._norm_saliency_maps:
saliency_maps = saliency_maps.reshape((batch_size, self._num_cls_out_channels, -1))
saliency_maps = self._normalize_map(saliency_maps)
Expand Down
36 changes: 18 additions & 18 deletions tests/unit/algorithms/detection/test_xai_detection_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,31 @@

class TestExplainMethods:
ref_saliency_shapes = {
"MobileNetV2-ATSS": (2, 4, 4),
"ResNeXt101-ATSS": (2, 4, 4),
"MobileNetV2-ATSS": (2, 13, 13),
"ResNeXt101-ATSS": (2, 13, 13),
"SSD": (81, 13, 13),
"YOLOX-TINY": (80, 13, 13),
"YOLOX-S": (80, 13, 13),
"YOLOX-L": (80, 13, 13),
"YOLOX-X": (80, 13, 13),
"YOLOX-TINY": (80, 26, 26),
"YOLOX-S": (80, 26, 26),
"YOLOX-L": (80, 26, 26),
"YOLOX-X": (80, 26, 26),
}

ref_saliency_vals_det = {
"MobileNetV2-ATSS": np.array([67, 216, 255, 57], dtype=np.uint8),
"ResNeXt101-ATSS": np.array([75, 214, 229, 173], dtype=np.uint8),
"YOLOX-TINY": np.array([80, 28, 42, 53, 49, 68, 72, 75, 69, 57, 65, 6, 157], dtype=np.uint8),
"YOLOX-S": np.array([75, 178, 151, 159, 150, 148, 144, 144, 147, 144, 147, 142, 189], dtype=np.uint8),
"YOLOX-L": np.array([43, 28, 0, 6, 7, 19, 22, 17, 14, 18, 25, 7, 34], dtype=np.uint8),
"YOLOX-X": np.array([255, 144, 83, 76, 83, 86, 82, 90, 91, 93, 110, 104, 83], dtype=np.uint8),
"SSD": np.array([119, 72, 118, 35, 39, 30, 31, 31, 36, 27, 44, 23, 61], dtype=np.uint8),
"MobileNetV2-ATSS": np.array([34, 67, 148, 132, 172, 147, 146, 155, 167, 159], dtype=np.uint8),
"ResNeXt101-ATSS": np.array([52, 75, 68, 76, 89, 94, 101, 111, 125, 123], dtype=np.uint8),
"YOLOX-TINY": np.array([177, 94, 147, 147, 161, 162, 164, 164, 163, 166], dtype=np.uint8),
"YOLOX-S": np.array([158, 170, 180, 158, 152, 148, 153, 153, 148, 145], dtype=np.uint8),
"YOLOX-L": np.array([255, 80, 97, 88, 73, 71, 72, 76, 75, 76], dtype=np.uint8),
"YOLOX-X": np.array([185, 218, 189, 103, 83, 70, 62, 66, 66, 67], dtype=np.uint8),
"SSD": np.array([255, 178, 212, 90, 93, 79, 79, 80, 87, 83], dtype=np.uint8),
}

ref_saliency_vals_det_wo_postprocess = {
"MobileNetV2-ATSS": -0.10465062,
"ResNeXt101-ATSS": -0.073549636,
"MobileNetV2-ATSS": -0.014513552,
"ResNeXt101-ATSS": -0.055565584,
"YOLOX-TINY": 0.04948914,
"YOLOX-S": 0.01133332,
"YOLOX-L": 0.01870133,
"YOLOX-S": 0.011557617,
"YOLOX-L": 0.020231,
"YOLOX-X": 0.0043506604,
"SSD": 0.6629989,
}
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_saliency_map_det(self, template):
assert saliency_maps[0].ndim == 3
assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name]
# convert to int16 in case of negative value difference
actual_sal_vals = saliency_maps[0][0][0].astype(np.int16)
actual_sal_vals = saliency_maps[0][0][0][:10].astype(np.int16)
ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.uint8)
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

Expand Down

0 comments on commit 3fc39a2

Please sign in to comment.