diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py
index b747985406..83763315d3 100644
--- a/doctr/models/recognition/sar/pytorch.py
+++ b/doctr/models/recognition/sar/pytorch.py
@@ -21,10 +21,10 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
     'sar_resnet31': {
-        'mean': (.5, .5, .5),
-        'std': (1., 1., 1.),
+        'mean': (0.694, 0.695, 0.693),
+        'std': (0.299, 0.296, 0.301),
         'input_shape': (3, 32, 128),
-        'vocab': VOCABS['legacy_french'],
+        'vocab': VOCABS['french'],
         'url': None,
     },
 }
@@ -50,28 +50,35 @@ class AttentionModule(nn.Module):
     def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
         super().__init__()
-        self.feat_conv = nn.Conv2d(feat_chans, attention_units, 3, padding=1)
+        self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
         # No need to add another bias since both tensors are summed together
-        self.state_conv = nn.Linear(state_chans, attention_units, bias=False)
-        self.attention_projector = nn.Linear(attention_units, 1, bias=False)
+        self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
+        self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)
 
-    def forward(self, features: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
-        # shape (N, C, H, W) -> (N, attention_units, H, W)
+    def forward(
+        self,
+        features: torch.Tensor,  # (N, C, H, W)
+        hidden_state: torch.Tensor,  # (N, C)
+    ) -> torch.Tensor:
+
+        H_f, W_f = features.shape[2:]
+
+        # (N, feat_chans, H, W) --> (N, attention_units, H, W)
         feat_projection = self.feat_conv(features)
-        # shape (N, L, rnn_units) -> (N, L, attention_units)
-        state_projection = self.state_conv(hidden_state).unsqueeze(-1).unsqueeze(-1)
-        # (N, L, attention_units, H, W)
-        projection = torch.tanh(feat_projection.unsqueeze(1) + state_projection)
-        # (N, L, H, W, 1)
-        attention = self.attention_projector(projection.permute(0, 1, 3, 4, 2))
-        # shape (N, L, H, W, 1) -> (N, L, H * W)
-        attention = torch.flatten(attention, 2)
-        attention = torch.softmax(attention, -1)
-        # shape (N, L, H * W) -> (N, L, 1, H, W)
-        attention = attention.reshape(-1, hidden_state.shape[1], features.shape[-2], features.shape[-1])
-
-        # (N, L, C)
-        return (features.unsqueeze(1) * attention.unsqueeze(2)).sum(dim=(3, 4))
+        # (N, state_chans, 1, 1) --> (N, attention_units, 1, 1)
+        hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
+        state_projection = self.state_conv(hidden_state)
+        state_projection = state_projection.expand(-1, -1, H_f, W_f)
+        # (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f)
+        attention_weights = torch.tanh(feat_projection + state_projection)
+        # (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f)
+        attention_weights = self.attention_projector(attention_weights)
+        B, C, H, W = attention_weights.size()
+
+        # (N, H, W) --> (N, 1, H, W)
+        attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
+        # fuse features and attention weights (N, C)
+        return (features * attention_weights).sum(dim=(2, 3))
 
 
 class SARDecoder(nn.Module):
@@ -83,7 +90,6 @@ class SARDecoder(nn.Module):
         vocab_size: number of classes in the model alphabet
         embedding_units: number of hidden embedding units
         attention_units: number of hidden attention units
-        num_decoder_layers: number of LSTM layers to stack
     """
 
     def __init__(
@@ -93,19 +99,21 @@ def __init__(
         rnn_units: int,
         max_length: int,
         vocab_size: int,
         embedding_units: int,
         attention_units: int,
-        num_decoder_layers: int = 2,
         feat_chans: int = 512,
         dropout_prob: float = 0.,
     ) -> None:
 
         super().__init__()
         self.vocab_size = vocab_size
-        self.rnn = nn.LSTM(rnn_units, rnn_units, 2, batch_first=True, dropout=dropout_prob)
-        self.embed = nn.Embedding(self.vocab_size + 1, embedding_units)
-        self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
-        self.output_dense = nn.Linear(2 * rnn_units, vocab_size + 1)
         self.max_length = max_length
+        self.embed = nn.Linear(self.vocab_size + 1, embedding_units)
+        self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1)
+        self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
+        self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units)
+        self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1)
+        self.dropout = nn.Dropout(dropout_prob)
+
     def forward(
         self,
         features: torch.Tensor,  # (N, C, H, W)
@@ -113,45 +121,46 @@
         gt: Optional[torch.Tensor] = None,  # (N, L)
     ) -> torch.Tensor:
 
-        if gt is None:
-            # Initialize with the index of virtual START symbol (placed after <eos> so that the one-hot is only zeros)
-            symbol = torch.zeros(features.shape[0], device=features.device, dtype=torch.long)
-            # (N, embedding_units)
-            symbol = self.embed(symbol)
-            # (N, L, embedding_units)
-            symbol = symbol.unsqueeze(1).expand(-1, self.max_length + 1, -1)
-        else:
-            # (N, L) --> (N, L, embedding_units)
-            symbol = self.embed(gt)
-            # (N, L + 1, embedding_units)
-            symbol = torch.cat((holistic.unsqueeze(1), symbol), dim=1)
-
-        # (N, L, vocab_size + 1)
-        char_logits = torch.zeros(
-            (features.shape[0], self.max_length + 1, self.vocab_size + 1),
-            device=features.device,
-            dtype=features.dtype,
-        )
-
-        decoding_iter = self.max_length + 1 if gt is None else 1
-        for t in range(decoding_iter):
-            # (N, L + 1, rnn_units)
-            logits = self.rnn(symbol)[0]
-            # (N, L + 1, C)
-            glimpse = self.attention_module(features, logits)
-            # (N, L + 1, 2 * rnn_units)
-            logits = torch.cat((logits, glimpse), -1)
-            # (N, L + 1, vocab_size + 1)
-            decoded = self.output_dense(logits)
-            if gt is None:
-                char_logits[:, t] = decoded[:, t + 1]
-                if t < decoding_iter - 1:
-                    # update symbol with predicted logits for t + 1 step
-                    symbol[:, t + 2] = self.embed(decoded[:, t + 1].argmax(-1))
+        if gt is not None:
+            gt_embedding = self.embed_tgt(gt)
+
+        logits_list: List[torch.Tensor] = []
+
+        for t in range(self.max_length + 1):  # 32
+            if t == 0:
+                # step to init the first states of the LSTMCell
+                hidden_state_init = cell_state_init = \
+                    torch.zeros(features.size(0), features.size(1), device=features.device)
+                hidden_state, cell_state = hidden_state_init, cell_state_init
+                prev_symbol = holistic
+            elif t == 1:
+                # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
+                # (N, vocab_size + 1) --> (N, embedding_units)
+                prev_symbol = torch.zeros(features.size(0), self.vocab_size + 1, device=features.device)
+                prev_symbol = self.embed(prev_symbol)
             else:
-                char_logits = decoded[:, 1:]
-
-        return char_logits
+                if gt is not None:
+                    # (N, embedding_units) -2 because of <bos> and <eos> (same)
+                    prev_symbol = self.embed(gt_embedding[:, t - 2])
+                else:
+                    # -1 to start at the timestep where prev_symbol was initialized
+                    index = logits_list[t - 1].argmax(-1)
+                    # (1, embedding_units)
+                    prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
+
+            # (N, C), (N, C) take the last hidden state and cell state from current timestep
+            hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
+            hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state))
+            # (N, C, H, W), (N, C) --> (N, C)
+            glimpse = self.attention_module(features, hidden_state)
+            # (N, C), (N, C) --> (N, 2 * C)
+            logits = torch.cat([hidden_state, glimpse], dim=1)
+            logits = self.dropout(logits)
+            # (N, vocab_size + 1)
+            logits_list.append(self.output_dense(logits))
+
+        # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
+        return torch.stack(logits_list[1:]).permute(1, 0, 2)
 
 
 class SAR(nn.Module, RecognitionModel):
@@ -165,7 +174,6 @@ class SAR(nn.Module, RecognitionModel):
         embedding_units: number of embedding units
         attention_units: number of hidden units in attention module
         max_length: maximum word length handled by the model
-        num_decoders: number of LSTM to stack in decoder layer
         dropout_prob: dropout probability of the encoder LSTM
         cfg: default setup dict of the model
     """
@@ -176,9 +184,8 @@ def __init__(
         vocab: str,
         rnn_units: int = 512,
         embedding_units: int = 512,
-        attention_units: int = 512,
+        attention_units: int = 64,
         max_length: int = 30,
-        num_decoders: int = 2,
         dropout_prob: float = 0.,
         input_shape: Tuple[int, int, int] = (3, 32, 128),
         cfg: Optional[Dict[str, Any]] = None,
@@ -200,10 +207,8 @@ def __init__(
         self.feat_extractor.train()
 
         self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob)
-
-        self.decoder = SARDecoder(
-            rnn_units, max_length, len(vocab), embedding_units, attention_units, num_decoders, out_shape[1],
-        )
+        self.decoder = SARDecoder(rnn_units, self.max_length, len(self.vocab),
+                                  embedding_units, attention_units, dropout_prob=dropout_prob)
 
         self.postprocessor = SARPostProcessor(vocab=vocab)
@@ -217,9 +222,10 @@ def forward(
         features = self.feat_extractor(x)['features']
         # Vertical max pooling --> (N, C, W)
-        pooled_features = features.max(dim=-2).values
+        pooled_features = F.max_pool2d(features, kernel_size=(features.shape[2], 1), stride=(1, 1))
+        pooled_features = pooled_features.squeeze(2)
         # (N, W, C)
-        pooled_features = pooled_features.permute((0, 2, 1))
+        pooled_features = pooled_features.permute(0, 2, 1).contiguous()
         # (N, C)
         encoded = self.encoder(pooled_features)
         if target is not None:
@@ -301,6 +307,7 @@ def _sar(
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
     pretrained_backbone: bool = True,
+    ignore_keys: Optional[List[str]] = None,
     **kwargs: Any
 ) -> SAR:
@@ -323,7 +330,10 @@ def _sar(
     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]['url'])
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if _cfg['vocab'] != default_cfgs[arch]['vocab'] else None
+        load_pretrained_params(model, default_cfgs[arch]['url'], ignore_keys=_ignore_keys)
 
     return model
@@ -345,4 +355,14 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
     text recognition architecture
     """
 
-    return _sar('sar_resnet31', pretrained, resnet31, '10', **kwargs)
+    return _sar(
+        'sar_resnet31',
+        pretrained,
+        resnet31,
+        '10',
+        ignore_keys=[
+            'decoder.embed.weight', 'decoder.embed_tgt.weight',
+            'decoder.output_dense.weight', 'decoder.output_dense.bias'
+        ],
+        **kwargs,
+    )
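
Reviewer note (not part of the diff): below is a minimal, self-contained sketch that mirrors the refactored AttentionModule introduced above, so the expected tensor shapes can be sanity-checked in isolation. The sizes used here (512 feature/state channels, 64 attention units, a 4x32 feature map, batch of 2) are illustrative assumptions based on the defaults in this patch, not values taken from an existing doctr test.

# --- shape sanity check, not part of the patch --------------------------------
import torch
from torch import nn


class AttentionModuleSketch(nn.Module):
    """Condensed copy of the refactored AttentionModule from the diff above."""

    def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
        super().__init__()
        self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
        # No extra bias: both projections are summed before the tanh
        self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
        self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)

    def forward(self, features: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
        H_f, W_f = features.shape[2:]
        # (N, feat_chans, H, W) --> (N, attention_units, H, W)
        feat_projection = self.feat_conv(features)
        # (N, state_chans) --> (N, attention_units, H, W), broadcast over the spatial grid
        hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
        state_projection = self.state_conv(hidden_state).expand(-1, -1, H_f, W_f)
        # (N, attention_units, H, W) --> (N, 1, H, W) attention map
        attention_weights = self.attention_projector(torch.tanh(feat_projection + state_projection))
        B, C, H, W = attention_weights.size()
        # softmax jointly over all H * W positions, then weighted spatial sum --> (N, feat_chans)
        attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
        return (features * attention_weights).sum(dim=(2, 3))


if __name__ == "__main__":
    module = AttentionModuleSketch(feat_chans=512, state_chans=512, attention_units=64)
    features = torch.rand(2, 512, 4, 32)   # (N, C, H, W) backbone feature map (illustrative size)
    hidden_state = torch.rand(2, 512)      # (N, rnn_units) decoder hidden state
    glimpse = module(features, hidden_state)
    assert glimpse.shape == (2, 512)       # one glimpse vector per sample

Each decoding step therefore produces a single (N, feat_chans) glimpse per sample, because the softmax is taken jointly over the H * W positions before the weighted spatial sum, matching how the new SARDecoder concatenates the glimpse with the LSTMCell hidden state.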