
Commit

bug fix huggingface#1
KobeKnowles committed Jun 8, 2022
1 parent e633c27 commit cff5483
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -593,7 +593,7 @@ def call(
 #    input_hidden_states = tf.math.sigmoid(input_hidden_states) * hidden_states
 #else:
 #    input_hidden_states = input_hidden_states  # no sigmoid is performed here as no gating occurs. Treats it as if additional layers have been performed.
-if gate and config.nm_gating: input_hidden_states = tf.math.sigmoid(input_hidden_states) * hidden_states
+if gate and self.config.nm_gating: input_hidden_states = tf.math.sigmoid(input_hidden_states) * hidden_states

 layer_outputs = layer_module(
     hidden_states=input_hidden_states,
@@ -610,7 +610,7 @@
 # after, say, the 3rd layer, we apply neuromodulation gating to the output hidden state.

 # in the config class there is a clause where the positions can't be equal, so if/elif is correct.
-if config.gating_block_start_position == i+1:  # layers start at 1, not 0; hence the +1.
+if self.config.gating_block_start_position == i+1:  # layers start at 1, not 0; hence the +1.
     dict_start = self.gating_block_iterate(type_="start", gating_block=self.gating_block_start,
                                            hidden_states=hidden_states, attention_mask=attention_mask,
                                            head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states,
@@ -622,7 +622,7 @@
raise Exception(f"dict_start should contain the key last_hidden_state_gating_block_start"
f" but doesn't")
hidden_states_after_gating_start = dict_start["last_hidden_state_gating_block_start"]
elif config.gating_block_middle_position == i+1:
elif self.config.gating_block_middle_position == i+1:
dict_middle = self.gating_block_iterate(type_="middle", gating_block=self.gating_block_middle,
hidden_states=hidden_states, attention_mask=attention_mask,
head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states,
@@ -634,7 +634,7 @@
raise Exception(f"dict_start should contain the key last_hidden_state_gating_block_middle"
f" but doesn't")
hidden_states_after_gating_middle = dict_middle["last_hidden_state_gating_block_middle"]
elif config.gating_block_end_position == i+1:
elif self.config.gating_block_end_position == i+1:
dict_end = self.gating_block_iterate(type_="end", gating_block=self.gating_block_end,
hidden_states=hidden_states, attention_mask=attention_mask,
head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states,
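
The fix is a scoping correction rather than a logic change: `config` is an argument to the encoder's `__init__` and is stored on the instance, so inside `call` the bare name is unbound and each gating branch would raise a `NameError` the first time it executed. A minimal sketch of the pattern, using a stand-in config class and an illustrative layer name rather than the fork's real `TFBertEncoder`:

import tensorflow as tf

class Cfg:
    # Stand-in for the fork's BertConfig; only the one flag the sketch needs.
    nm_gating = True

class GatedLayer(tf.keras.layers.Layer):
    # Hypothetical layer illustrating the scoping fix, not the real encoder.
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config  # `config` exists only as a local name inside __init__

    def call(self, hidden_states, input_hidden_states, gate=True):
        # Reading the bare name `config` here would raise NameError;
        # the attribute has to come from `self`.
        if gate and self.config.nm_gating:
            input_hidden_states = tf.math.sigmoid(input_hidden_states) * hidden_states
        return input_hidden_states

layer = GatedLayer(Cfg())
out = layer(tf.ones([2, 4, 8]), tf.zeros([2, 4, 8]))  # sigmoid(0) = 0.5, so out == 0.5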

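For reference, the expression the fixed branch guards is an elementwise sigmoid gate: the neuromodulation signal is squashed into (0, 1) and scales each unit of the hidden state, so a gate near 0 suppresses a unit while a gate near 1 passes it through almost unchanged. A self-contained sketch with made-up shapes (the real tensors are (batch, seq_len, hidden_size)):

import tensorflow as tf

# Illustrative shapes only.
hidden_states = tf.random.normal([2, 4, 8])  # layer output being gated
nm_signal = tf.random.normal([2, 4, 8])      # plays the role of input_hidden_states

# Each unit of hidden_states is scaled by a gate value in (0, 1).
gated = tf.math.sigmoid(nm_signal) * hidden_states
print(gated.shape)  # (2, 4, 8)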
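The `if/elif` dispatch on `i + 1` (positions are 1-based while `enumerate` is 0-based) is only exclusive because, per the comment in the diff, the config forbids equal positions. A hypothetical version of that config-side check; the actual clause lives in the fork's config class and may differ:

def check_gating_positions(start: int, middle: int, end: int) -> None:
    # With pairwise-distinct positions, at most one of the if/elif
    # branches can match any given layer index.
    if len({start, middle, end}) != 3:
        raise ValueError("gating block positions must be pairwise distinct")

check_gating_positions(1, 6, 12)  # e.g. for a 12-layer encoder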
