Skip to content

Commit

Permalink
tests: skip attention-related parameterize when attn_layer is 0 (#3784)
Browse files Browse the repository at this point in the history
The tests make no sense in this case.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Improved test coverage by adding an optional `temperature` parameter
to the attention layer tests.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored May 16, 2024
1 parent d0d596b commit d62a41f
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from typing import (
Any,
Optional,
Tuple,
)

Expand Down Expand Up @@ -111,6 +112,15 @@ def data(self) -> dict:
"seed": 1145141919810,
}

def is_meaningless_zero_attention_layer_tests(
    self,
    attn_layer: int,
    attn_dotr: bool,
    normalize: bool,
    temperature: Optional[float],
) -> bool:
    """Return True when the parameter combination is pointless to test.

    With ``attn_layer == 0`` no attention layers exist, so any of the
    attention-specific knobs (``attn_dotr``, ``normalize``, or a non-None
    ``temperature``) would have no effect — such combinations are skipped.
    """
    if attn_layer != 0:
        # Attention layers are present; every knob is meaningful.
        return False
    has_attention_knob = attn_dotr or normalize or temperature is not None
    return has_attention_knob

@property
def skip_pt(self) -> bool:
(
Expand All @@ -133,7 +143,12 @@ def skip_pt(self) -> bool:
precision,
use_econf_tebd,
) = self.param
return CommonTest.skip_pt
return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)

@property
def skip_dp(self) -> bool:
Expand All @@ -157,7 +172,12 @@ def skip_dp(self) -> bool:
precision,
use_econf_tebd,
) = self.param
return CommonTest.skip_pt
return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)

@property
def skip_tf(self) -> bool:
Expand All @@ -183,12 +203,21 @@ def skip_tf(self) -> bool:
) = self.param
# TODO (excluded_types != [] and attn_layer > 0) need fix
return (
env_protection != 0.0
or smooth_type_embedding
or not normalize
or temperature != 1.0
or (excluded_types != [] and attn_layer > 0)
or (type_one_side and tebd_input_mode == "strip") # not consistent yet
CommonTest.skip_tf
or (
env_protection != 0.0
or smooth_type_embedding
or not normalize
or temperature != 1.0
or (excluded_types != [] and attn_layer > 0)
or (type_one_side and tebd_input_mode == "strip") # not consistent yet
)
or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)
)

tf_class = DescrptDPA1TF
Expand Down

0 comments on commit d62a41f

Please sign in to comment.