Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix mistake in FASTConvLayer and tf reparameterization #1506

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions doctr/models/classification/textnet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0",
},
}

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-9e605bd8.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-4784b292.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-2c3f3265.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
},
}

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class _FAST(BaseModel):

min_size_box: int = 3
assume_straight_pages: bool = True
shrink_ratio = 0.1
shrink_ratio = 0.4

def build_target(
self,
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-7bee86e2.pt&src=0",
"url": None,
},
"fast_small": {
"input_shape": (3, 1024, 1024),
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
dropout_prob: float = 0.1,
pooling_size: int = 9,
pooling_size: int = 4, # different from paper performs better on close text-rich images
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = {},
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
dropout_prob: float = 0.1,
pooling_size: int = 9,
pooling_size: int = 4, # different from paper performs better on close text-rich images
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = {},
Expand Down
3 changes: 1 addition & 2 deletions doctr/models/modules/layers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
horizontal_outputs = (
self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
)
id_out = self.rbr_identity(x) if self.rbr_identity is not None and self.ver_bn is not None else 0
id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0

return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)

Expand Down Expand Up @@ -155,7 +155,6 @@ def reparameterize_layer(self):
)
self.fused_conv.weight.data = kernel
self.fused_conv.bias.data = bias # type: ignore[union-attr]
self.deploy = True
for para in self.parameters():
para.detach_()
for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
Expand Down
14 changes: 6 additions & 8 deletions doctr/models/modules/layers/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
if self.hor_bn is not None and self.hor_conv is not None
else 0
)
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None and self.ver_bn is not None else 0
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0

return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)

Expand All @@ -110,14 +110,14 @@ def _identity_to_conv(
return 0, 0
if not hasattr(self, "id_tensor"):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 0, 0] = 1
kernel_value[0, 0, i % input_dim, i] = 1
id_tensor = tf.constant(kernel_value, dtype=tf.float32)
self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
kernel = self.id_tensor
std = tf.sqrt(identity.moving_variance + identity.epsilon)
t = tf.reshape(identity.gamma / std, (-1, 1, 1, 1))
t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std

def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
Expand All @@ -138,18 +138,16 @@ def _get_equivalent_kernel_bias(self):
else:
kernel_1xn, bias_1xn = 0, 0
kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
if not isinstance(kernel_id, int):
kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
return kernel_mxn, bias_mxn

# NOTE(review): this span is rendered diff text, not runnable code. Each pair of
# near-duplicate lines below is the removed line followed by its replacement.
# Purpose: zero-pad `kernel` symmetrically so its spatial size reaches
# `self.converted_ks`.
def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
# Target (height, width) the kernel is padded up to.
kernel_height, kernel_width = self.converted_ks
# Removed line: read the kernel's height/width from its last two axes.
height, width = kernel.shape[2:]
# Replacement: read the kernel's height/width from its first two axes
# (presumably the TF conv kernel layout, spatial axes first; confirm).
height, width = kernel.shape[:2]
# Per-side padding amounts; clamped at 0 so a kernel that is already large
# enough is left unpadded.
pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
# Removed line: applied the padding to the last two axes.
return tf.pad(kernel, [[0, 0], [0, 0], [pad_top_down, pad_top_down], [pad_left_right, pad_left_right]])
# Replacement: applies the padding to the first two axes, matching the axes
# that height/width are now read from.
return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])

def reparameterize_layer(self):
kernel, bias = self._get_equivalent_kernel_bias()
Expand Down
10 changes: 10 additions & 0 deletions tests/pytorch/test_models_detection_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ def test_detection_zoo(arch_name):
assert all((seq_map >= 0).all() and (seq_map <= 1).all() for seq_map in seq_maps)


def test_fast_reparameterization():
    """Check that reparameterizing ``fast_tiny`` keeps its logits numerically close.

    The model is run on a random batch before and after ``reparameterize`` and
    the element-wise difference of the ``"logits"`` outputs is bounded.
    """
    dummy_input = torch.rand((2, 3, 1024, 1024), dtype=torch.float32)
    base_model = detection.fast_tiny(pretrained=True, exportable=True).eval()
    # Capture the reference logits before the model is handed to reparameterize.
    base_out = base_model(dummy_input)["logits"]
    rep_model = reparameterize(base_model)
    rep_out = rep_model(dummy_input)["logits"]
    diff = base_out - rep_out
    # Bound both the mean and the spread of the discrepancy: a signed mean alone
    # can stay small while individual errors are large and cancel out.
    assert diff.mean() < 5e-2 and diff.std() < 5e-2


def test_erode():
x = torch.zeros((1, 1, 3, 3))
x[..., 1, 1] = 1
Expand Down
10 changes: 10 additions & 0 deletions tests/tensorflow/test_models_detection_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def test_detection_zoo_error():
_ = detection.zoo.detection_predictor("my_fancy_model", pretrained=False)


def test_fast_reparameterization():
    """Check that reparameterizing ``fast_tiny`` keeps its logits numerically close.

    Runs the model in inference mode on a random batch before and after
    ``reparameterize`` and bounds the mean and standard deviation of the
    difference between the two ``"logits"`` outputs.
    """
    images = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1)
    original = detection.fast_tiny(pretrained=True, exportable=True)
    logits_before = original(images, training=False)["logits"]
    fused = reparameterize(original)
    logits_after = fused(images, training=False)["logits"]
    delta = logits_before - logits_after
    # Same two bounds as a single `and` assertion, checked in the same order.
    assert tf.math.reduce_mean(delta) < 5e-2
    assert tf.math.reduce_std(delta) < 5e-2


def test_erode():
x = np.zeros((1, 3, 3, 1), dtype=np.float32)
x[:, 1, 1] = 1
Expand Down
Loading