Skip to content

Commit

Permalink
feat: add dice loss in linknet (#816)
Browse files Browse the repository at this point in the history
* feat: add dice loss in linknet

* fix: typing

* fix: requested changes
  • Loading branch information
charlesmindee authored Feb 14, 2022
1 parent 166001c commit 612e4f8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 38 deletions.
21 changes: 8 additions & 13 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def build_target(
self,
target: List[np.ndarray],
output_shape: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray]:

if any(t.dtype != np.float32 for t in target):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
Expand All @@ -119,7 +119,6 @@ def build_target(

if self.assume_straight_pages:
seg_target = np.zeros(target_shape, dtype=bool)
edge_mask = np.zeros(target_shape, dtype=bool)
else:
seg_target = np.zeros(target_shape, dtype=np.uint8)

Expand All @@ -144,7 +143,12 @@ def build_target(
abs_boxes[:, [0, 2]] *= w
abs_boxes[:, [1, 3]] *= h
abs_boxes = abs_boxes.round().astype(np.int32)
polys = [None] * abs_boxes.shape[0] # Unused
polys = np.stack([
abs_boxes[:, [0, 1]],
abs_boxes[:, [0, 3]],
abs_boxes[:, [2, 3]],
abs_boxes[:, [2, 1]],
], axis=1)
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
Expand All @@ -159,19 +163,10 @@ def build_target(
if box.shape == (4, 2):
box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])]
seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True
# top edge
edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True
# bot edge
edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True
# left edge
edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True
# right edge
edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True

# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
seg_target = seg_target.transpose(0, 3, 1, 2)
seg_mask = seg_mask.transpose(0, 3, 1, 2)
edge_mask = edge_mask.transpose(0, 3, 1, 2)

return seg_target, seg_mask, edge_mask
return seg_target, seg_mask
21 changes: 10 additions & 11 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,33 @@ def compute_loss(
self,
out_map: torch.Tensor,
target: List[np.ndarray],
edge_factor: float = 2.,
) -> torch.Tensor:
"""Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
`<https://github.com/tensorflow/addons/>`_.
Args:
out_map: output feature map of the model of shape (N, 1, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
edge_factor: boost factor for box edges (in case of BCE)
Returns:
A loss tensor
"""
seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]
seg_target, seg_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask)
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
if edge_factor > 0:
edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device)

# Get the cross_entropy for each entry
loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')
# BCE loss
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')

# Dice loss
prob_map = torch.nn.functional.sigmoid(out_map)
inter = (prob_map[seg_mask] * seg_target[seg_mask]).sum()
cardinality = (prob_map[seg_mask] + seg_target[seg_mask]).sum()
dice_loss = 1 - 2 * inter / (cardinality + 1e-8)

# Compute BCE loss with highlighted edges
if edge_factor > 0:
loss = ((1 + (edge_factor - 1) * edge_mask) * loss)
# Only consider contributions overlapping the mask
return loss[seg_mask].mean()
return bce_loss[seg_mask].mean() + dice_loss


def _linknet(
Expand Down
23 changes: 9 additions & 14 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,37 +139,32 @@ def compute_loss(
self,
out_map: tf.Tensor,
target: List[np.ndarray],
edge_factor: float = 2.,
) -> tf.Tensor:
"""Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
`<https://github.com/tensorflow/addons/>`_.
Args:
out_map: output feature map of the model of shape N x H x W x 1
target: list of dictionary where each dict has a `boxes` and a `flags` entry
edge_factor: boost factor for box edges (in case of BCE)
Returns:
A loss tensor
"""
seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3])
seg_target, seg_mask = self.build_target(target, out_map.shape[1:3])

seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
if edge_factor > 0:
edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool)

# Get the cross_entropy for each entry
loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]
# BCE loss
bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]

# Compute BCE loss with highlighted edges
if edge_factor > 0:
loss = tf.math.multiply(
1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype),
loss
)
# Dice loss
prob_map = tf.math.sigmoid(out_map)
inter = tf.math.reduce_sum(prob_map[seg_mask] * seg_target[seg_mask])
cardinality = tf.math.reduce_sum(prob_map[seg_mask] + seg_target[seg_mask])
dice_loss = 1 - 2 * inter / (cardinality + 1e-8)

return tf.reduce_mean(loss[seg_mask])
return tf.math.reduce_mean(bce_loss[seg_mask]) + dice_loss

def call(
self,
Expand Down

0 comments on commit 612e4f8

Please sign in to comment.