Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF: XLA-trainable DeBERTa v2 #18546

Merged
merged 5 commits into from
Aug 10, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add different code paths for gpu and tpu
gante committed Aug 9, 2022
commit 64aad9a950bdf621dce67eb9c3f6efcfc0cf7b12
21 changes: 15 additions & 6 deletions src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
Original file line number Diff line number Diff line change
@@ -519,12 +519,21 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
def take_along_axis(x, indices):
# Only a valid port of np.take_along_axis when the gather axis is -1

# [B, S, P] -> [B, S, P, D]
one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)

# if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
# grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)
# TPU + gathers and reshapes don't go along well -- see https://github.com/huggingface/transformers/issues/18239
if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
# [B, S, P] -> [B, S, P, D]
one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)

# if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
# grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
gathered = tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)

# GPUs, on the other hand, prefer gathers instead of large one-hot+matmuls
else:
flat_x = tf.reshape(x, (-1, x.shape[-1]))
flat_indices = tf.reshape(indices, (-1, indices.shape[-1]))
gathered = tf.gather(flat_x, flat_indices, batch_dims=1)
gathered = tf.reshape(gathered, shape_list(indices))

return gathered