Commit

[minor] test block factory on device, faster unit tests (facebookrese…
blefaudeux authored May 12, 2021
1 parent 867b62e commit 110889b
Showing 2 changed files with 20 additions and 9 deletions.
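For context, the pattern applied in the diff below is standard pytest device parametrization: build the list of available devices once, parametrize each test over it, and move both the module under test and its inputs to that device. A minimal, self-contained sketch of the idea, assuming only pytest and torch (the test name and tensor sizes are illustrative, not taken from this commit):

import pytest
import torch

# Run on CPU only when no GPU is present, otherwise on CUDA
DEVICES = (
    [torch.device("cpu")]
    if not torch.cuda.is_available()
    else [torch.device("cuda")]
)


@pytest.mark.parametrize("device", DEVICES)
def test_forward_pass(device: torch.device):
    # Module and inputs must live on the same device for the forward pass
    model = torch.nn.Linear(16, 16).to(device)
    inputs = torch.rand(4, 16, device=device)
    _ = model(inputs)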
4 changes: 2 additions & 2 deletions tests/test_attentions.py
@@ -16,8 +16,8 @@
)

BATCH = 5
-SEQ = 1024
-MODEL = 384
+SEQ = 128
+MODEL = 96
GLOBAL_ATTENTION_RATIO = (
    _DENSITY_THRESHOLD * 0.9
) # Make sure that we test the sparse implementation, no matter the threshold
25 changes: 18 additions & 7 deletions tests/test_block_factory.py
@@ -17,10 +17,17 @@
)

BATCH = 20
-SEQ = 512
-MODEL = 384
+SEQ = 128
+MODEL = 96
DROPOUT = 0.5
GLOBAL_ATTENTION_RATIO = 0.1 # 10% of the tokens have a global view
+DEVICES = (
+    [torch.device("cpu")]
+    if not torch.cuda.is_available()
+    else [
+        torch.device("cuda")
+    ]  # save a bit on CI for now, we have separate cpu and gpu jobs
+)


@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@@ -30,6 +37,7 @@
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
+@pytest.mark.parametrize("device", DEVICES)
def test_xformer_encoder_block(
    attention_name: str,
    feedforward_name: str,
@@ -38,6 +46,7 @@ def test_xformer_encoder_block(
    residual_dropout: float,
    causal: bool,
    activation: Activation,
+    device: torch.device,
):

    attention_config = {
@@ -76,10 +85,10 @@
    )

    # Test that the whole block can be instantiated
-    block = xFormerEncoderBlock.from_config(block_config)
+    block = xFormerEncoderBlock.from_config(block_config).to(device)

    # Check that the dimensions make sense, do a FW pass
-    inputs = torch.rand(BATCH, SEQ, MODEL)
+    inputs = torch.rand(BATCH, SEQ, MODEL, device=device)
    _ = block(inputs)


@@ -90,6 +99,7 @@ def test_xformer_encoder_block(
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
+@pytest.mark.parametrize("device", DEVICES)
def test_xformer_decoder_block(
    attention_name: str,
    feedforward_name: str,
@@ -98,6 +108,7 @@ def test_xformer_decoder_block(
    residual_dropout: float,
    causal: bool,
    activation: Activation,
+    device: torch.device,
):

    attention_config = {
@@ -150,11 +161,11 @@
    )

    # Test that the whole block can be instantiated
-    encoder_block = xFormerEncoderBlock.from_config(encoder_block_config)
-    decoder_block = xFormerDecoderBlock.from_config(decoder_block_config)
+    encoder_block = xFormerEncoderBlock.from_config(encoder_block_config).to(device)
+    decoder_block = xFormerDecoderBlock.from_config(decoder_block_config).to(device)

    # Check that the dimensions make sense, do a FW pass
-    inputs = torch.rand(BATCH, SEQ, MODEL)
+    inputs = torch.rand(BATCH, SEQ, MODEL, device=device)
    encoded = encoder_block(inputs)
    _ = decoder_block(
        inputs, encoded
