From cdb6ae01d0b2a13a43d76b03f624576ac370a33e Mon Sep 17 00:00:00 2001 From: Le Dong <74060032+ledong0110@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:12:40 +0700 Subject: [PATCH] fix: format code convention (#28) --- src/melt/__main__.py | 2 +- tests/test_execution.py | 30 ++++++++++++++++++++++++++---- tests/test_wrapper.py | 26 ++++++++++++++++++++++++-- 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/melt/__main__.py b/src/melt/__main__.py index e522cda..3695aeb 100644 --- a/src/melt/__main__.py +++ b/src/melt/__main__.py @@ -1,5 +1,6 @@ import spacy import nltk +from .cli import main nltk.download('punkt_tab') try: @@ -12,6 +13,5 @@ from spacy.cli import download download("en_core_web_sm") -from .cli import main main() diff --git a/tests/test_execution.py b/tests/test_execution.py index 5060b03..ac64098 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -1,6 +1,7 @@ import subprocess import unittest + class TestTasks(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestTasks, self).__init__(*args, **kwargs) @@ -12,7 +13,27 @@ def __init__(self, *args, **kwargs): self.smoke_test = True # Set the smoke_test argument to True def run_melt_command(self, dataset_name): - result = subprocess.run(["melt", "--wtype", self.wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True) + result = subprocess.run( + [ + "melt", + "--wtype", + self.wrapper_type, + "--model_name", + self.model_name, + "--dataset_name", + dataset_name, + "--ptemplate", + self.ptemplate, + "--lang", + self.lang, + "--seed", + str(self.seed), + "--smoke_test", + str(self.smoke_test), + ], + capture_output=True, + text=True, + ) self.assertEqual(result.returncode, 0) def test_sentiment_analysis(self): @@ -29,7 +50,7 @@ def test_toxic_detection(self): # Test toxic detection task dataset_name = "ViHSD" self.run_melt_command(dataset_name) - + def test_reasoning(self): # Test reasoning task dataset_name = "synthetic_natural_azr" @@ -70,5 +91,6 @@ def test_information_retrieval(self): dataset_name = "mmarco" self.run_melt_command(dataset_name) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index e9f956f..73b1450 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -1,6 +1,7 @@ import subprocess import unittest + class TestWrapper(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestWrapper, self).__init__(*args, **kwargs) @@ -11,7 +12,27 @@ def __init__(self, *args, **kwargs): self.smoke_test = True # Set the smoke_test argument to True def run_melt_command(self, dataset_name, wrapper_type): - result = subprocess.run(["melt", "--wtype", wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True) + result = subprocess.run( + [ + "melt", + "--wtype", + wrapper_type, + "--model_name", + self.model_name, + "--dataset_name", + dataset_name, + "--ptemplate", + self.ptemplate, + "--lang", + self.lang, + "--seed", + str(self.seed), + "--smoke_test", + str(self.smoke_test), + ], + capture_output=True, + text=True, + ) self.assertEqual(result.returncode, 0) def test_wrapper_hf(self): @@ -39,5 +60,6 @@ def test_wrapper_vllm(self): dataset_name = "zalo_e2eqa" self.run_melt_command(dataset_name, "vllm") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()