Skip to content

Commit

Permalink
fix: format code convention (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
ledong0110 committed Sep 10, 2024
1 parent 533bf11 commit cdb6ae0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/melt/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import spacy
import nltk
from .cli import main

nltk.download('punkt_tab')
try:
Expand All @@ -12,6 +13,5 @@
from spacy.cli import download

download("en_core_web_sm")
from .cli import main

main()
30 changes: 26 additions & 4 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import unittest


class TestTasks(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestTasks, self).__init__(*args, **kwargs)
Expand All @@ -12,7 +13,27 @@ def __init__(self, *args, **kwargs):
self.smoke_test = True # Set the smoke_test argument to True

def run_melt_command(self, dataset_name):
result = subprocess.run(["melt", "--wtype", self.wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True)
result = subprocess.run(
[
"melt",
"--wtype",
self.wrapper_type,
"--model_name",
self.model_name,
"--dataset_name",
dataset_name,
"--ptemplate",
self.ptemplate,
"--lang",
self.lang,
"--seed",
str(self.seed),
"--smoke_test",
str(self.smoke_test),
],
capture_output=True,
text=True,
)
self.assertEqual(result.returncode, 0)

def test_sentiment_analysis(self):
Expand All @@ -29,7 +50,7 @@ def test_toxic_detection(self):
# Test toxic detection task
dataset_name = "ViHSD"
self.run_melt_command(dataset_name)

def test_reasoning(self):
# Test reasoning task
dataset_name = "synthetic_natural_azr"
Expand Down Expand Up @@ -70,5 +91,6 @@ def test_information_retrieval(self):
dataset_name = "mmarco"
self.run_melt_command(dataset_name)

if __name__ == '__main__':
unittest.main()

if __name__ == "__main__":
unittest.main()
26 changes: 24 additions & 2 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import unittest


class TestWrapper(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestWrapper, self).__init__(*args, **kwargs)
Expand All @@ -11,7 +12,27 @@ def __init__(self, *args, **kwargs):
self.smoke_test = True # Set the smoke_test argument to True

def run_melt_command(self, dataset_name, wrapper_type):
result = subprocess.run(["melt", "--wtype", wrapper_type, "--model_name", self.model_name, "--dataset_name", dataset_name, "--ptemplate", self.ptemplate, "--lang", self.lang, "--seed", str(self.seed), "--smoke_test", str(self.smoke_test)], capture_output=True, text=True)
result = subprocess.run(
[
"melt",
"--wtype",
wrapper_type,
"--model_name",
self.model_name,
"--dataset_name",
dataset_name,
"--ptemplate",
self.ptemplate,
"--lang",
self.lang,
"--seed",
str(self.seed),
"--smoke_test",
str(self.smoke_test),
],
capture_output=True,
text=True,
)
self.assertEqual(result.returncode, 0)

def test_wrapper_hf(self):
Expand Down Expand Up @@ -39,5 +60,6 @@ def test_wrapper_vllm(self):
dataset_name = "zalo_e2eqa"
self.run_melt_command(dataset_name, "vllm")

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main()

0 comments on commit cdb6ae0

Please sign in to comment.