Fix alibi #222
@@ -303,11 +303,27 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
             query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
 
         # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0 if alibi is None else 1.0, alpha=(1.0/self.norm_factor))
+        if alibi is None:
+            matmul_result = torch.baddbmm(
+                matmul_result,
+                query_layer.transpose(0, 1),   # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                beta=0.0, alpha=(1.0/self.norm_factor))
+        else:
+            if not hasattr(self, "logged_alibi"):
+                logger.debug("Using Alibi.")
+                self.logged_alibi = True
+
+            if self.apply_query_key_layer_scaling:
+                beta = 1.0 / self.layer_number
+            else:
+                beta = 1.0
+
+            matmul_result = torch.baddbmm(
+                matmul_result,
+                query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                beta=beta, alpha=(1.0 / self.norm_factor))
 
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
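For readers skimming the diff, a minimal standalone sketch (not part of the PR) of the trick both branches rely on: torch.baddbmm computes alpha * (batch1 @ batch2) + beta * input, so with beta=1.0 a buffer pre-filled with the ALiBi bias is folded into the raw attention scores in a single kernel, while beta=0.0 simply discards the buffer. Shapes and values below are made up for illustration.

```python
import torch

# Toy shapes: [b * np, sq, hn] queries and [b * np, sk, hn] keys, as in the comments above.
b_np, sq, sk, hn = 4, 16, 16, 8
norm_factor = hn ** 0.5

query = torch.randn(b_np, sq, hn)
key = torch.randn(b_np, sk, hn)
alibi = torch.randn(b_np, 1, sk)   # per-head bias, broadcast over the query dimension

# beta=1.0 keeps the bias carried in the first argument; beta=0.0 (the alibi-is-None branch) drops it.
scores = torch.baddbmm(alibi, query, key.transpose(1, 2),
                       beta=1.0, alpha=1.0 / norm_factor)

expected = torch.bmm(query, key.transpose(1, 2)) / norm_factor + alibi
assert torch.allclose(scores, expected, atol=1e-5)
```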
@@ -470,9 +486,19 @@ def __init__(self, init_method, output_layer_init_method,
         self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)
 
+        # Alibi
+        if args.position_embedding_type == PositionEmbeddingType.alibi:
+            self.alibi = self._build_alibi_tensor(args.seq_length, args.num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device())
Review comment on the line above: micro batch size doesn't increase during batch size ramp-up, it's constant. We should care, since all our experiments run with batch size ramp-up, but I would expect it to crash badly if it doesn't match.
+            if args.params_dtype == torch.float16:
+                self.alibi = self.alibi.to(torch.float16)
+            elif args.params_dtype == torch.bfloat16:
+                self.alibi = self.alibi.to(torch.bfloat16)
+        else:
+            self.alibi = None
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False, alibi=None):
+                layer_past=None, get_key_value=False):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
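To make the ramp-up concern in the review comment above concrete, here is a standalone illustration (not from the PR, and it assumes the pre-built bias would reach torch.baddbmm without being re-sliced): a bias built for one micro batch size cannot broadcast against activations of a different batch size, so a mismatch fails loudly rather than silently.

```python
import torch

num_heads, seq_len, head_dim = 8, 16, 4
bias_batch, runtime_batch = 2, 3   # bias built for micro_batch_size=2, activations arrive with batch 3

alibi = torch.zeros(bias_batch * num_heads, 1, seq_len)          # [16, 1, 16]
q = torch.randn(runtime_batch * num_heads, seq_len, head_dim)    # [24, 16, 4]
k = torch.randn(runtime_batch * num_heads, seq_len, head_dim)    # [24, 16, 4]

try:
    torch.baddbmm(alibi, q, k.transpose(1, 2), beta=1.0, alpha=1.0)
except RuntimeError as err:
    # The bias [16, 1, 16] cannot broadcast to the [24, 16, 16] result shape.
    print("shape mismatch:", err)
```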
@@ -483,7 +509,7 @@ def forward(self, hidden_states, attention_mask,
                                   attention_mask,
                                   layer_past=layer_past,
                                   get_key_value=get_key_value,
-                                  alibi=alibi)
+                                  alibi=self.alibi)
 
         if get_key_value:
             attention_output, presents = attention_output
@@ -561,6 +587,30 @@ def forward(self, hidden_states, attention_mask,
 
         return output
 
+    @staticmethod
+    def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size):
+        # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+        """Returns tensor shaped (batch_size * num_attention_heads, 1, max_seq_len)"""
+
+        def get_slopes(n):
+            def get_slopes_power_of_2(n):
+                start = (2 ** (-2 ** -(math.log2(n) - 3)))
+                ratio = start
+                return [start * ratio ** i for i in range(n)]
+
+            if math.log2(n).is_integer():
+                return get_slopes_power_of_2(n)
+            else:
+                closest_power_of_2 = 2 ** math.floor(math.log2(n))
+                return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
+                    :n - closest_power_of_2]
+
+        slopes = torch.Tensor(get_slopes(num_attention_heads))
+        alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
+            num_attention_heads, -1, -1)
+        alibi = alibi.repeat(batch_size, 1, 1)
+        return alibi
+
 
 class ParallelTransformerLayerPipe(ParallelTransformerLayer):
     """Extends ParallelTransformerLayer to forward attention_mask through the pipeline.
@@ -600,27 +650,6 @@ def forward(self, inputs, **kwargs):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
 
-    @staticmethod
-    def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size):
-        # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-        """Returns tensor shaped (batch_size * num_attention_heads, 1, max_seq_len)"""
-        def get_slopes(n):
-            def get_slopes_power_of_2(n):
-                start = (2 ** (-2 ** -(math.log2(n) - 3)))
-                ratio = start
-                return [start * ratio ** i for i in range(n)]
-
-            if math.log2(n).is_integer():
-                return get_slopes_power_of_2(n)
-            else:
-                closest_power_of_2 = 2 ** math.floor(math.log2(n))
-                return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
-                    :n - closest_power_of_2]
-        slopes = torch.Tensor(get_slopes(num_attention_heads))
-        alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(num_attention_heads, -1, -1)
-        alibi = alibi.repeat(batch_size, 1, 1)
-        return alibi
-
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
@@ -687,20 +716,11 @@ def build_layer(layer_number):
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
             checkpoint = deepspeed.checkpointing.checkpoint
 
-        if args.position_embedding_type == PositionEmbeddingType.alibi:
-            self.alibi = self._build_alibi_tensor(args.seq_length, args.num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device())
-            if args.params_dtype == torch.float16:
-                self.alibi = self.alibi.to(torch.float16)
-            elif args.params_dtype == torch.bfloat16:
-                self.alibi = self.alibi.to(torch.bfloat16)
-        else:
-            self.alibi = None
-
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
 
     def _checkpointed_forward(self, hidden_states, attention_mask,
-                              encoder_output, enc_dec_attn_mask, alibi=None):
+                              encoder_output, enc_dec_attn_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*inputs):
@@ -710,7 +730,7 @@ def custom_forward(*inputs):
                 enc_dec_attn_mask = inputs[3]
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask, alibi=alibi)
+                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                 return x_
             return custom_forward
 
@@ -767,8 +787,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
-                                                       enc_dec_attn_mask,
-                                                       alibi=self.alibi)
+                                                       enc_dec_attn_mask)
         else:
             if get_key_value:
                 presents = []
@@ -782,8 +801,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
                                       encoder_output=encoder_output,
                                       enc_dec_attn_mask=enc_dec_attn_mask,
                                       layer_past=past,
-                                      get_key_value=get_key_value,
-                                      alibi=self.alibi)
+                                      get_key_value=get_key_value)
                 if get_key_value:
                     hidden_states, present = hidden_states
                     presents.append(present)
Testing purposes.
Just for information: on my side I was going to propose isolating the three ways of computing the score according to the positional embedding, by creating a method for each positional-embedding method and wrapping these methods in log_debug_usage so we get a log (as for the activation functions) to detect in the tests. The minor advantage is that it also allows testing rotary and absolute (in all cases), but I don't mind if you think it's easier to keep it the way you did.
Yeah, it's true that your version looks nice. I just think it's an incorrect abstraction: there's no reason to group all the positional embeddings together (they get applied in different places, they do different things, and they have different constraints despite sharing the common purpose of providing sequential information). One could argue that using a pure causal mask is a position-embedding mechanism.
What I was thinking of is abstracting only the alibi function into a separate function so it can use the pretty decorator, but I was lazy ^^'
@log_debug_usage(logger, msg)
I agree with this statement :) Models that don't have any position embeddings (like sinusoidal or learned or alibi) are actually able to achieve good (but not great) PPL because the causal mask encodes some kind of order.
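For illustration of the refactor discussed above, a hypothetical sketch (not part of the PR): the alibi score computation pulled into its own function and wrapped in a log_debug_usage-style decorator instead of the hasattr bookkeeping. The decorator body below is a made-up minimal stand-in so the snippet is self-contained; the repo's actual helper and the surrounding attention code may differ.

```python
import functools
import logging

import torch

logger = logging.getLogger(__name__)


def log_debug_usage(log, msg):
    """Minimal stand-in: log `msg` once, the first time the wrapped function is called."""
    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if not getattr(wrapped, "_logged", False):
                log.debug(msg)
                wrapped._logged = True
            return func(*args, **kwargs)
        return wrapped
    return decorator


@log_debug_usage(logger, "Using Alibi.")
def alibi_attention_scores(matmul_result, query_layer, key_layer, beta, norm_factor):
    # Same baddbmm as in the diff, isolated so the decorator replaces the
    # hasattr(self, "logged_alibi") bookkeeping.
    return torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),                 # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),   # [b * np, hn, sk]
        beta=beta, alpha=(1.0 / norm_factor))
```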