Commit 74f4ad8: Format files (#541)
## Summary

Format

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Austin Liu <austin362667@gmail.com>
austin362667 authored Jan 27, 2025
1 parent 1441a31 commit 74f4ad8
Showing 4 changed files with 20 additions and 18 deletions.
test/convergence/test_mini_models_multimodal.py (9 additions, 10 deletions)

@@ -1,6 +1,15 @@
 import functools
 import os
 
+import pytest
+import torch
+
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerFast
+
+from liger_kernel.transformers import apply_liger_kernel_to_mllama
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from test.utils import FAKE_CONFIGS_PATH
 from test.utils import UNTOKENIZED_DATASET_PATH
 from test.utils import MiniModelConfig
@@ -13,16 +22,6 @@
 from test.utils import supports_bfloat16
 from test.utils import train_bpe_tokenizer
 
-import pytest
-import torch
-
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerFast
-
-from liger_kernel.transformers import apply_liger_kernel_to_mllama
-from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
-
 try:
     # Qwen2-VL is only available in transformers>=4.45.0
     from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
test/transformers/test_fused_linear_cross_entropy.py (4 additions, 3 deletions)

@@ -1,11 +1,12 @@
-from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
-from test.utils import assert_verbose_allclose
-from test.utils import set_seed
 from typing import Optional
 
 import pytest
 import torch
 
+from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
test/transformers/test_rms_norm.py (4 additions, 4 deletions)

@@ -1,13 +1,13 @@
 import os
 
-from test.utils import assert_verbose_allclose
-from test.utils import set_seed
-from test.utils import supports_bfloat16
-
 import pytest
 import torch
 import torch.nn as nn
 
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+from test.utils import supports_bfloat16
+
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.transformers.functional import liger_rms_norm
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
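
The three test-file hunks above are the same mechanical change: import blocks are regrouped so that standard-library imports come first, third-party packages second, and repo-local imports last, with one imported name per line. Below is a minimal illustrative sketch of that layout, inferred only from the diffs above; the authoritative rules are whatever `make checkstyle` invokes in this repository, and the exact placement of the `liger_kernel` group relative to the `test.*` group varies slightly between the files shown.

```python
# Illustrative sketch of the import layout this formatting pass enforces.
# Grouping is inferred from the diffs above; the repo's own `make checkstyle`
# tooling is the source of truth, not this snippet.

import os  # 1) standard-library imports first

import pytest  # 2) third-party packages next
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm  # 3) project imports last
from test.utils import set_seed  # single-line style: one imported name per line
```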
test/utils.py (3 additions, 1 deletion)

@@ -541,7 +541,9 @@ def get_batch_loss_metrics(
         **loss_kwargs,
     ):
         """Compute the loss metrics for the given batch of inputs for train or test."""
-        forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, preference_labels, nll_target)
+        forward_output = self.concatenated_forward(
+            _input, weight, target, bias, average_log_prob, preference_labels, nll_target
+        )
         (
             policy_chosen_logps,
             policy_rejected_logps,
