[Fix] PyTorch SAR_Resnet31 implementation (#920)
- fix for AttentionModule
- fix SARDecoder forward step
felixdittrich92 authored May 20, 2022
1 parent 1ab6fe1 commit 607aaaf
Showing 1 changed file with 97 additions and 77 deletions.
174 changes: 97 additions & 77 deletions doctr/models/recognition/sar/pytorch.py
@@ -21,10 +21,10 @@

default_cfgs: Dict[str, Dict[str, Any]] = {
'sar_resnet31': {
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (3, 32, 128),
'vocab': VOCABS['legacy_french'],
'vocab': VOCABS['french'],
'url': None,
},
}
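
For reference (not part of the diff): the old mean/std were placeholder values, while the new constants are per-channel image statistics. A minimal sketch of the normalization they drive, with illustrative tensor names:

    import torch

    # illustrative only: per-channel normalization with the updated constants
    mean = torch.tensor([0.694, 0.695, 0.693]).view(1, 3, 1, 1)
    std = torch.tensor([0.299, 0.296, 0.301]).view(1, 3, 1, 1)
    x = torch.rand(8, 3, 32, 128)   # batch of images in [0, 1], shaped like 'input_shape'
    x_norm = (x - mean) / std       # what the preprocessing stage effectively computes
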
@@ -50,28 +50,35 @@ class AttentionModule(nn.Module):

def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
super().__init__()
self.feat_conv = nn.Conv2d(feat_chans, attention_units, 3, padding=1)
self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
# No need to add another bias since both tensors are summed together
self.state_conv = nn.Linear(state_chans, attention_units, bias=False)
self.attention_projector = nn.Linear(attention_units, 1, bias=False)
self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)

def forward(self, features: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
# shape (N, C, H, W) -> (N, attention_units, H, W)
def forward(
self,
features: torch.Tensor, # (N, C, H, W)
hidden_state: torch.Tensor, # (N, C)
) -> torch.Tensor:

H_f, W_f = features.shape[2:]

# (N, feat_chans, H, W) --> (N, attention_units, H, W)
feat_projection = self.feat_conv(features)
# shape (N, L, rnn_units) -> (N, L, attention_units)
state_projection = self.state_conv(hidden_state).unsqueeze(-1).unsqueeze(-1)
# (N, L, attention_units, H, W)
projection = torch.tanh(feat_projection.unsqueeze(1) + state_projection)
# (N, L, H, W, 1)
attention = self.attention_projector(projection.permute(0, 1, 3, 4, 2))
# shape (N, L, H, W, 1) -> (N, L, H * W)
attention = torch.flatten(attention, 2)
attention = torch.softmax(attention, -1)
# shape (N, L, H * W) -> (N, L, 1, H, W)
attention = attention.reshape(-1, hidden_state.shape[1], features.shape[-2], features.shape[-1])

# (N, L, C)
return (features.unsqueeze(1) * attention.unsqueeze(2)).sum(dim=(3, 4))
# (N, state_chans, 1, 1) --> (N, attention_units, 1, 1)
hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
state_projection = self.state_conv(hidden_state)
state_projection = state_projection.expand(-1, -1, H_f, W_f)
# (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f)
attention_weights = torch.tanh(feat_projection + state_projection)
# (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f)
attention_weights = self.attention_projector(attention_weights)
B, C, H, W = attention_weights.size()

# (N, H, W) --> (N, 1, H, W)
attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
# fuse features and attention weights (N, C)
return (features * attention_weights).sum(dim=(2, 3))


class SARDecoder(nn.Module):
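
For reference (not part of the diff): a minimal shape check of the reworked AttentionModule, assuming the class definition above; the channel sizes mirror the sar_resnet31 defaults.

    import torch

    attn = AttentionModule(feat_chans=512, state_chans=512, attention_units=64)
    features = torch.rand(2, 512, 4, 32)    # (N, C, H, W) backbone feature map
    hidden_state = torch.rand(2, 512)       # (N, rnn_units) decoder hidden state
    glimpse = attn(features, hidden_state)  # attention-weighted feature vector
    print(glimpse.shape)                    # torch.Size([2, 512])
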
@@ -83,7 +90,6 @@ class SARDecoder(nn.Module):
vocab_size: number of classes in the model alphabet
embedding_units: number of hidden embedding units
attention_units: number of hidden attention units
num_decoder_layers: number of LSTM layers to stack
"""
def __init__(
@@ -93,65 +99,68 @@ def __init__(
vocab_size: int,
embedding_units: int,
attention_units: int,
num_decoder_layers: int = 2,
feat_chans: int = 512,
dropout_prob: float = 0.,
) -> None:

super().__init__()
self.vocab_size = vocab_size
self.rnn = nn.LSTM(rnn_units, rnn_units, 2, batch_first=True, dropout=dropout_prob)
self.embed = nn.Embedding(self.vocab_size + 1, embedding_units)
self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
self.output_dense = nn.Linear(2 * rnn_units, vocab_size + 1)
self.max_length = max_length

self.embed = nn.Linear(self.vocab_size + 1, embedding_units)
self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1)
self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units)
self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1)
self.dropout = nn.Dropout(dropout_prob)

def forward(
self,
features: torch.Tensor, # (N, C, H, W)
holistic: torch.Tensor, # (N, C)
gt: Optional[torch.Tensor] = None, # (N, L)
) -> torch.Tensor:

if gt is None:
# Initialize with the index of virtual START symbol (placed after <eos> so that the one-hot is only zeros)
symbol = torch.zeros(features.shape[0], device=features.device, dtype=torch.long)
# (N, embedding_units)
symbol = self.embed(symbol)
# (N, L, embedding_units)
symbol = symbol.unsqueeze(1).expand(-1, self.max_length + 1, -1)
else:
# (N, L) --> (N, L, embedding_units)
symbol = self.embed(gt)
# (N, L + 1, embedding_units)
symbol = torch.cat((holistic.unsqueeze(1), symbol), dim=1)

# (N, L, vocab_size + 1)
char_logits = torch.zeros(
(features.shape[0], self.max_length + 1, self.vocab_size + 1),
device=features.device,
dtype=features.dtype,
)

decoding_iter = self.max_length + 1 if gt is None else 1
for t in range(decoding_iter):
# (N, L + 1, rnn_units)
logits = self.rnn(symbol)[0]
# (N, L + 1, C)
glimpse = self.attention_module(features, logits)
# (N, L + 1, 2 * rnn_units)
logits = torch.cat((logits, glimpse), -1)
# (N, L + 1, vocab_size + 1)
decoded = self.output_dense(logits)
if gt is None:
char_logits[:, t] = decoded[:, t + 1]
if t < decoding_iter - 1:
# update symbol with predicted logits for t + 1 step
symbol[:, t + 2] = self.embed(decoded[:, t + 1].argmax(-1))
if gt is not None:
gt_embedding = self.embed_tgt(gt)

logits_list: List[torch.Tensor] = []

for t in range(self.max_length + 1): # 32
if t == 0:
# step to init the first states of the LSTMCell
hidden_state_init = cell_state_init = \
torch.zeros(features.size(0), features.size(1), device=features.device)
hidden_state, cell_state = hidden_state_init, cell_state_init
prev_symbol = holistic
elif t == 1:
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
# (N, vocab_size + 1) --> (N, embedding_units)
prev_symbol = torch.zeros(features.size(0), self.vocab_size + 1, device=features.device)
prev_symbol = self.embed(prev_symbol)
else:
char_logits = decoded[:, 1:]

return char_logits
if gt is not None:
#(N, embedding_units) -2 because of <bos> and <eos> (same)
prev_symbol = self.embed(gt_embedding[:, t - 2])
else:
# -1 to start at timestep where prev_symbol was initialized
index = logits_list[t - 1].argmax(-1)
# (1, embedding_units)
prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)

# (N, C), (N, C) take the last hidden state and cell state from current timestep
hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state))
# (N, C, H, W), (N, C) --> (N, C)
glimpse = self.attention_module(features, hidden_state)
# (N, C), (N, C) --> (N, 2 * C)
logits = torch.cat([hidden_state, glimpse], dim=1)
logits = self.dropout(logits)
# (N, vocab_size + 1)
logits_list.append(self.output_dense(logits))

# (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
return torch.stack(logits_list[1:]).permute(1, 0, 2)


class SAR(nn.Module, RecognitionModel):
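
For reference (not part of the diff): a minimal free-running call of the rewritten decoder, assuming the classes above; sizes are illustrative, and note that the num_decoder_layers argument is gone.

    import torch

    # positional args: rnn_units, max_length, vocab_size
    decoder = SARDecoder(512, 30, 110, embedding_units=512, attention_units=64)
    features = torch.rand(2, 512, 4, 32)   # (N, C, H, W) backbone features
    holistic = torch.rand(2, 512)          # (N, rnn_units) holistic encoding from the SAREncoder
    logits = decoder(features, holistic)   # no gt -> greedy step-by-step decoding
    # logits: (N, T, vocab_size + 1) class scores per timestep
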
@@ -165,7 +174,6 @@ class SAR(nn.Module, RecognitionModel):
embedding_units: number of embedding units
attention_units: number of hidden units in attention module
max_length: maximum word length handled by the model
num_decoders: number of LSTM to stack in decoder layer
dropout_prob: dropout probability of the encoder LSTM
cfg: default setup dict of the model
"""
@@ -176,9 +184,8 @@ def __init__(
vocab: str,
rnn_units: int = 512,
embedding_units: int = 512,
attention_units: int = 512,
attention_units: int = 64,
max_length: int = 30,
num_decoders: int = 2,
dropout_prob: float = 0.,
input_shape: Tuple[int, int, int] = (3, 32, 128),
cfg: Optional[Dict[str, Any]] = None,
@@ -200,10 +207,8 @@ def __init__(
self.feat_extractor.train()

self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob)

self.decoder = SARDecoder(
rnn_units, max_length, len(vocab), embedding_units, attention_units, num_decoders, out_shape[1],
)
self.decoder = SARDecoder(rnn_units, self.max_length, len(self.vocab),
embedding_units, attention_units, dropout_prob=dropout_prob)

self.postprocessor = SARPostProcessor(vocab=vocab)

@@ -217,9 +222,10 @@ def forward(

features = self.feat_extractor(x)['features']
# Vertical max pooling --> (N, C, W)
pooled_features = features.max(dim=-2).values
pooled_features = F.max_pool2d(features, kernel_size=(features.shape[2], 1), stride=(1, 1))
pooled_features = pooled_features.squeeze(2)
# (N, W, C)
pooled_features = pooled_features.permute((0, 2, 1))
pooled_features = pooled_features.permute(0, 2, 1).contiguous()
# (N, C)
encoded = self.encoder(pooled_features)
if target is not None:
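
For reference (not part of the diff): the new pooling is numerically equivalent to the previous features.max(dim=-2) reduction, i.e. a max over the height axis; a small check with illustrative shapes:

    import torch
    import torch.nn.functional as F

    features = torch.rand(2, 512, 4, 32)    # (N, C, H, W)
    pooled_a = features.max(dim=-2).values  # previous formulation, (N, C, W)
    pooled_b = F.max_pool2d(features, kernel_size=(features.shape[2], 1), stride=(1, 1)).squeeze(2)
    assert torch.equal(pooled_a, pooled_b)  # same vertical max pooling
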
@@ -301,6 +307,7 @@ def _sar(
backbone_fn: Callable[[bool], nn.Module],
layer: str,
pretrained_backbone: bool = True,
ignore_keys: Optional[List[str]] = None,
**kwargs: Any
) -> SAR:

@@ -323,7 +330,10 @@ def _sar(
model = SAR(feat_extractor, cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]['url'])
# The number of classes is not the same as the number of classes in the pretrained model =>
# remove the last layer weights
_ignore_keys = ignore_keys if _cfg['vocab'] != default_cfgs[arch]['vocab'] else None
load_pretrained_params(model, default_cfgs[arch]['url'], ignore_keys=_ignore_keys)

return model

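
For context (not part of the diff): when the requested vocab differs from the pretrained one, the classification-head weights no longer fit, so the listed keys are skipped at load time. A rough sketch of the idea, assuming load_pretrained_params simply drops those entries and loads the rest non-strictly:

    import torch.nn as nn

    def load_filtered_state_dict(model: nn.Module, state_dict: dict, ignore_keys: list) -> None:
        # illustrative helper, not the actual doctr implementation
        for key in ignore_keys:
            state_dict.pop(key, None)  # drop weights whose shape depends on the vocab size
        model.load_state_dict(state_dict, strict=False)
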
@@ -345,4 +355,14 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
text recognition architecture
"""

return _sar('sar_resnet31', pretrained, resnet31, '10', **kwargs)
return _sar(
'sar_resnet31',
pretrained,
resnet31,
'10',
ignore_keys=[
'decoder.embed.weight', 'decoder.embed_tgt.weight',
'decoder.output_dense.weight', 'decoder.output_dense.bias'
],
**kwargs,
)

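
For reference (not part of the diff): a minimal forward pass through the updated model, assuming the usual doctr recognition-model interface; no pretrained weights are published yet for this architecture ('url' is None), so the model below is randomly initialised.

    import torch
    from doctr.models import sar_resnet31

    model = sar_resnet31(pretrained=False)
    input_tensor = torch.rand(1, 3, 32, 128)  # (N, C, H, W) matching 'input_shape'
    out = model(input_tensor, target=["hello"], return_model_output=True)
    print(out['out_map'].shape, out['loss'])  # raw logits and training loss
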