Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] PARSeq tensorflow fixes #1228

Merged
merged 7 commits into from
Jun 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 79 additions & 82 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(self, vocab_size: int, d_model: int):
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.d_model = d_model

def call(self, x: tf.Tensor) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x)
def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x, **kwargs)


class PARSeqDecoder(layers.Layer):
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
max_length: int = 32, # different from paper
dropout_prob: float = 0.1,
dec_num_heads: int = 12,
dec_ff_dim: int = 2048,
dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
dec_ffd_ratio: int = 4,
input_shape: Tuple[int, int, int] = (32, 128, 3),
exportable: bool = False,
Expand Down Expand Up @@ -209,10 +209,7 @@ def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
combined = tf.tensor_scatter_nd_update(
combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
)
# we pad to max length with eos idx to fit the mask generation
return tf.pad(
combined, [[0, 0], [0, self.max_length + 1 - tf.shape(combined)[1]]], constant_values=max_num_chars + 2
) # (num_perms, self.max_length + 1)
return combined

@tf.function
def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
Expand All @@ -232,7 +229,6 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple
mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
)
target_mask = mask[1:, :-1]

return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)

@tf.function
Expand All @@ -246,110 +242,78 @@ def decode(
) -> tf.Tensor:
batch_size, sequence_length = target.shape
# apply positional information to the target sequence excluding the SOS token
null_ctx = self.embed(target[:, :1])
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:])
null_ctx = self.embed(target[:, :1], **kwargs)
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
if target_query is None:
target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
target_query = self.dropout(target_query, **kwargs)
return self.decoder(target_query, content, memory, target_mask, **kwargs)

@tf.function
def decode_autoregressive(self, features: tf.Tensor) -> tf.Tensor:
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
"""Generate predictions for the given features."""
# Padding symbol + SOS at the beginning
max_length = max_len if max_len is not None else self.max_length
max_length = min(max_length, self.max_length) + 1
b = tf.shape(features)[0]
ys = tf.fill(dims=(b, self.max_length), value=self.vocab_size + 2)
# Padding symbol + SOS at the beginning
ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
ys = tf.concat([start_vector, ys], axis=-1)
pos_queries = tf.tile(self.pos_queries[:, : self.max_length + 1], [b, 1, 1])
query_mask = tf.cast(
tf.linalg.band_part(tf.ones((self.max_length + 1, self.max_length + 1)), -1, 0), dtype=tf.bool
)
pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)

pos_logits = []
for i in range(self.max_length):
for i in range(max_length):
# Decode one token at a time without providing information about the future tokens
tgt_out = self.decode(
ys[:, : i + 1],
features,
query_mask[i : i + 1, : i + 1],
target_query=pos_queries[:, i : i + 1],
**kwargs,
)
pos_prob = self.head(tgt_out)
pos_logits.append(pos_prob)

if i + 1 < self.max_length:
if i + 1 < max_length:
# update ys with the next token
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(self.max_length), indexing="ij")
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
ys = tf.tensor_scatter_nd_update(
ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
)

# Stop decoding if all sequences have reached the EOS token
# We need to check it on True to be compatible with ONNX
if tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True:
if (
max_len is None
and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True
):
break

logits = tf.concat(pos_logits, axis=1) # (N, max_length, vocab_size + 1)

# One refine iteration
# Update query mask
query_mask = tf.cast(1 - tf.linalg.diag(tf.ones(self.max_length, dtype=tf.int32), k=-1), dtype=tf.bool)
diag_matrix = tf.eye(max_length)
diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)

sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
# Create padding mask for refined target input maskes all behind EOS token as False
# (N, 1, 1, max_length)
target_pad_mask = tf.cumsum(tf.cast(tf.equal(ys, self.vocab_size), dtype=tf.int32), axis=1, reverse=False)
target_pad_mask = tf.logical_not(tf.cast(target_pad_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool))
mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)

mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)

return logits # (N, max_length, vocab_size + 1)

@tf.function
def decode_non_autoregressive(self, features: tf.Tensor) -> tf.Tensor:
"""Decode the given features at once"""
pos_queries = tf.tile(self.pos_queries[:, : self.max_length + 1], [tf.shape(features)[0], 1, 1])
ys = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
return self.head(self.decode(ys, features, target_query=pos_queries))[:, : self.max_length]

@staticmethod
def compute_loss(
model_output: tf.Tensor,
gt: tf.Tensor,
seq_len: List[int],
) -> tf.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.

Args:
model_output: predicted logits of the model
gt: the encoded tensor with gt labels
seq_len: lengths of each gt word inside the batch

Returns:
The loss of the model on the batch
"""
# Input length : number of steps
input_len = tf.shape(model_output)[1]
# Add one for additional <eos> token (sos disappear in shift!)
seq_len = tf.cast(seq_len, tf.int32) + 1
# One-hot gt labels
oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
# The "masked" first gt char is <sos>. Delete last logit of the model output.
cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :])
# Compute mask
mask_values = tf.zeros_like(cce)
mask_2d = tf.sequence_mask(seq_len, input_len - 1) # delete the last mask timestep as well
masked_loss = tf.where(mask_2d, cce, mask_values)
ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))

return tf.expand_dims(ce_loss, axis=1)

def call(
self,
x: tf.Tensor,
Expand All @@ -362,36 +326,69 @@ def call(
# remove cls token
features = features[:, 1:, :]

if target is not None:
gt, seq_len = self.build_target(target)
seq_len = tf.cast(seq_len, tf.int32)

if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training")

if target is not None:
gt, seq_len = self.build_target(target)
seq_len = tf.cast(seq_len, tf.int32)
gt = gt[:, : int(tf.reduce_max(seq_len)) + 2] # slice up to the max length of the batch + 2 (SOS + EOS)

if kwargs.get("training", False):
# Generate permutations of the target sequences
tgt_perms = self.generate_permutations(seq_len)

gt_in = gt[:, :-1] # remove EOS token from longest target sequence
gt_out = gt[:, 1:] # remove SOS token

# Create padding mask for target input
# [True, True, True, ..., False, False, False] -> False is masked
padding_mask = ((gt != self.vocab_size + 2) | (gt != self.vocab_size))[:, tf.newaxis, tf.newaxis, :]

for perm in tgt_perms:
# Generate attention masks for the permutations
_, target_mask = self.generate_permutations_attention_masks(perm)
# combine target padding mask and query mask
mask = tf.math.logical_and(target_mask, padding_mask)
logits = self.head(self.decode(gt, features, mask))
padding_mask = tf.math.logical_and(
tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
)
padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :] # (N, 1, 1, seq_len)

loss = tf.constant(0.0)
loss_numel = tf.constant(0.0)
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
for i, perm in enumerate(tgt_perms):
_, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
# combine both masks to (N, 1, seq_len, seq_len)
mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))

logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
targets_flat = tf.reshape(gt_out, (-1,))
mask = tf.not_equal(targets_flat, self.vocab_size + 2)
loss += n * tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
)
)
loss_numel += n

# After the second iteration (i.e. done with canonical and reverse orderings),
# remove the [EOS] tokens for the succeeding perms
if i == 1:
gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))

loss /= loss_numel

else:
# eval step - use non-autoregressive decoding while training evaluation
logits = self.decode_non_autoregressive(features)
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved
gt = gt[:, 1:] # remove SOS token
max_len = gt.shape[1] - 1 # exclude EOS token
logits = self.decode_autoregressive(features, max_len, **kwargs)
logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
targets_flat = tf.reshape(gt, (-1,))
mask = tf.not_equal(targets_flat, self.vocab_size + 2)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
)
)
else:
logits = self.decode_autoregressive(features)
logits = self.decode_autoregressive(features, **kwargs)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand All @@ -406,7 +403,7 @@ def call(
out["preds"] = self.postprocessor(logits)

if target is not None:
out["loss"] = self.compute_loss(logits, gt, seq_len)
out["loss"] = loss

return out

Expand Down