
Commit: Lint
thomasw21 committed Nov 26, 2021
1 parent e7a12e7 commit 16ed621
Showing 2 changed files with 0 additions and 12 deletions.
10 changes: 0 additions & 10 deletions megatron/model/fused_softmax.py
@@ -139,13 +139,6 @@ def forward(self, input, mask):
     def is_kernel_available(self, mask, b, np, sq, sk):
         attn_batches = b * np
 
-        print("is_kernel_available")
-        print(f"self.self.scaled_masked_softmax_fusion = {self.scaled_masked_softmax_fusion}")
-        print(f"self.input_in_float16 = {self.input_in_float16}")
-        print(f"mask is not None = {mask is not None}")
-        print(f"sq = {sq}")
-        print(f"sk = {sk}")
-        print(f"attn_batches = {attn_batches}")
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
@@ -154,16 +147,13 @@ def is_kernel_available(self, mask, b, np, sq, sk):
             and sk % 4 == 0  # sk must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            print("wtf")
             if 0 <= sk <= 2048:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)
 
                 if self.attn_mask_type == AttnMaskType.causal:
-                    print(f"causal {attn_batches} / {batch_per_block}")
                     if attn_batches % batch_per_block == 0:
                         return True
                 else:
-                    print(f"non causal {sq} / {batch_per_block}")
                     if sq % batch_per_block == 0:
                         return True
         return False
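With the debug prints gone, is_kernel_available reduces to a pure predicate over the attention shapes. The sketch below restates the checks visible in this diff as a standalone function: can_use_fused_softmax is a hypothetical name, batch_per_block is taken as a plain parameter (the real method obtains it from the compiled CUDA extension via get_batch_per_block), and any clauses hidden in the collapsed hunk are omitted.

# Hypothetical standalone sketch of the shape checks visible in the diff
# above; not Megatron's actual API. The real method lives on
# FusedScaleMaskSoftmax and gets batch_per_block from the CUDA extension.
def can_use_fused_softmax(
    fusion_enabled: bool,    # user requested softmax fusion
    input_in_float16: bool,  # fused kernel only handles fp16 inputs
    sq: int,                 # query sequence length
    sk: int,                 # key sequence length
    b: int,                  # micro-batch size
    np: int,                 # number of attention heads
    batch_per_block: int,    # would come from get_batch_per_block(sq, sk, b, np)
    causal: bool,            # AttnMaskType.causal vs. a padding mask
) -> bool:
    attn_batches = b * np
    if not (
        fusion_enabled
        and input_in_float16
        and sk % 4 == 0            # sk must be divisible by 4
        and attn_batches % 4 == 0  # np * b must be divisible by 4
        # (further clauses are collapsed in the diff view and omitted here)
    ):
        return False
    if not (0 <= sk <= 2048):      # kernel supports key lengths up to 2048
        return False
    if causal:
        # causal kernel: attention batches must split evenly across blocks
        return attn_batches % batch_per_block == 0
    # non-causal kernel: query length must split evenly across blocks
    return sq % batch_per_block == 0

For instance, can_use_fused_softmax(True, True, sq=2048, sk=2048, b=4, np=16, batch_per_block=4, causal=True) returns True; batch_per_block=4 is a made-up value here, since the real one depends on the kernel build.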
2 changes: 0 additions & 2 deletions tests/test_model.py
@@ -274,8 +274,6 @@ def test_prefix_lm_wo_reset_attention_mask(self):
             self.assertNotIn("Using torch softmax", cs.out)
             #TODO: Check all invariants
 
-
-
     def test_gpt_rotary_embeddings(self):
         """Test rotary embeddings"""
         command_args = get_default_args()
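The surrounding test context asserts on captured stdout (assertNotIn("Using torch softmax", cs.out)), which is one reason the stray prints deleted above mattered: they land in the same capture. Below is a minimal sketch of that capture pattern using only the standard library; the repo's own cs helper may differ, and run_model_forward is a placeholder for the actual model call.

import contextlib
import io
import unittest


def run_model_forward():
    # Placeholder for the real model invocation; the actual test drives a
    # Megatron model. This stub prints nothing, so the assertion passes.
    pass


class FusedSoftmaxOutputTest(unittest.TestCase):
    # Hypothetical sketch of the stdout-capture pattern used by the
    # assertNotIn(...) context line above, built on stdlib redirect_stdout.
    def test_no_torch_softmax_fallback(self):
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            run_model_forward()
        self.assertNotIn("Using torch softmax", buf.getvalue())


if __name__ == "__main__":
    unittest.main()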
