fix warning trigger for embed_positions when loading xglm (huggingface#25798)

* fix warning triggering for xglm.embed_positions

* Make TF variable a tf.constant to match (and fix some spelling)

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
2 people authored and parambharat committed Sep 26, 2023
1 parent 7edd694 commit f7cdb2f
Showing 2 changed files with 4 additions and 4 deletions.
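For context: before this change, simply loading an XGLM checkpoint could emit a warning about the sinusoidal position table (embed_positions), even though that table is computed deterministically and never needs to come from the checkpoint. A rough reproduction sketch; the checkpoint name is only an illustrative example and is not taken from this commit:

from transformers import TFXGLMModel, XGLMModel

# Before this fix, loading either framework's XGLM model could warn about
# embed_positions being missing from, or unexpected in, the checkpoint.
tf_model = TFXGLMModel.from_pretrained("facebook/xglm-564M")
pt_model = XGLMModel.from_pretrained("facebook/xglm-564M")
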
6 changes: 3 additions & 3 deletions src/transformers/models/xglm/modeling_tf_xglm.py
@@ -63,7 +63,7 @@
 LARGE_NEGATIVE = -1e8


-def create_sinusiodal_positions(num_positions: int, embedding_dim: int, padding_idx: Optional[int]) -> tf.Tensor:
+def create_sinusoidal_positions(num_positions: int, embedding_dim: int, padding_idx: Optional[int]) -> tf.Tensor:
     half_dim = embedding_dim // 2
     emb = math.log(10000) / (half_dim - 1)
     emb = tf.exp(tf.range(half_dim, dtype=tf.float32) * -emb)
@@ -83,7 +83,7 @@ def create_sinusiodal_positions(num_positions: int, embedding_dim: int, padding_
         )
         emb *= _padding_mask

-    return tf.Variable(emb, trainable=False, name="model.embed_positions.weights")
+    return tf.constant(emb, name="embed_positions")


 def _create_position_ids_from_input_ids(
@@ -438,7 +438,7 @@ def __init__(
         )

         self.offset = 2
-        self._embed_positions_weights = create_sinusiodal_positions(
+        self._embed_positions_weights = create_sinusoidal_positions(
             num_positions=config.max_position_embeddings + self.offset,
             embedding_dim=config.d_model,
             padding_idx=config.pad_token_id,
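
The TF-side fix swaps the precomputed table from a tf.Variable to a tf.constant. A tf.Variable attached to a Keras layer is auto-tracked and shows up in the layer's weights, so checkpoint loading expects a matching entry for it and can warn when none exists; a plain constant is not tracked at all. A minimal sketch of that distinction, independent of the XGLM code:

import tensorflow as tf


class WithVariableTable(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # A tf.Variable is auto-tracked by Keras, so it appears in self.weights
        # and checkpoint loading expects to find a value for it.
        self.table = tf.Variable(tf.zeros((4, 8)), trainable=False)


class WithConstantTable(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # A tf.constant is just a tensor attribute; it is not tracked as a weight,
        # which suits a table that is always recomputed from scratch.
        self.table = tf.constant(tf.zeros((4, 8)))


print(len(WithVariableTable().weights))  # 1 -> tracked, so weight loading cares about it
print(len(WithConstantTable().weights))  # 0 -> invisible to weight loading
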
2 changes: 1 addition & 1 deletion src/transformers/models/xglm/modeling_xglm.py
@@ -176,7 +176,7 @@ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Opt
             # in forward put the weights on the correct dtype and device of the param
             emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

-        self.register_buffer("weights", emb_weights)
+        self.register_buffer("weights", emb_weights, persistent=False)

     @staticmethod
     def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
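
On the PyTorch side, registering the table as a non-persistent buffer keeps it out of state_dict, so loading a checkpoint that does not contain it no longer produces a warning. A minimal illustration of the persistent=False behaviour with a generic module (not the actual XGLM positional-embedding code):

import torch
from torch import nn


class PositionTable(nn.Module):
    def __init__(self, num_positions: int = 8, dim: int = 4):
        super().__init__()
        table = torch.zeros(num_positions, dim)  # stand-in for the real sin/cos table
        # persistent=False keeps the buffer out of state_dict(), so checkpoints
        # neither save it nor complain that it is missing at load time.
        self.register_buffer("weights", table, persistent=False)


module = PositionTable()
print(list(module.state_dict().keys()))  # [] -> nothing for load_state_dict to miss
module.load_state_dict({}, strict=True)  # succeeds even though the dict is empty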
