Skip to content

Commit

Permalink
Correct linting
Browse files Browse the repository at this point in the history
Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>
  • Loading branch information
thanawan-atc committed Jul 14, 2023
1 parent b706aaf commit f5093c9
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import yaml
from accelerate import Accelerator, notebook_launcher
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, Normalize, Transformer
from sentence_transformers.models import Normalize, Pooling, Transformer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import TrainingArguments, get_linear_schedule_with_warmup
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def make_model_config_json(
:param pooling_mode: Optional, the pooling mode of the model. If None, get pooling_mode
from the pre-trained hugging-face model object.
:type pooling_mode: string
:param normalize_result: Optional, whether to normalize the result of the model. If None, check from the pre-trained
:param normalize_result: Optional, whether to normalize the result of the model. If None, check from the pre-trained
hugging-face model object. If not found, do not include it.
:type normalize_result: bool
:param all_config:
Expand All @@ -1029,28 +1029,34 @@ def make_model_config_json(
model_name = self.model_id

# if user input model_type/embedding_dimension/pooling_mode, it will skip this step.

model = SentenceTransformer(self.model_id)
if model_type is None:
if len(model._modules) >= 1 and isinstance(model._modules['0'], Transformer):
if len(model._modules) >= 1 and isinstance(
model._modules["0"], Transformer
):
try:
model_type = model._modules['0'].auto_model.__class__.__name__
model_type = model_type.lower().rstrip('model')
except:
raise Exception("Raised exception while getting model_type")
model_type = model._modules["0"].auto_model.__class__.__name__
model_type = model_type.lower().rstrip("model")
except Exception as e:
raise Exception(f"Raised exception while getting model_type: {e}")

Check warning on line 1042 in opensearch_py_ml/ml_models/sentencetransformermodel.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_models/sentencetransformermodel.py#L1041-L1042

Added lines #L1041 - L1042 were not covered by tests

if embedding_dimension is None:
try:
embedding_dimension = model.get_sentence_embedding_dimension()
except:
raise Exception("Raised exception while calling get_sentence_embedding_dimension()")

except Exception as e:
raise Exception(

Check warning on line 1048 in opensearch_py_ml/ml_models/sentencetransformermodel.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_models/sentencetransformermodel.py#L1047-L1048

Added lines #L1047 - L1048 were not covered by tests
f"Raised exception while calling get_sentence_embedding_dimension(): {e}"
)

if pooling_mode is None:
if len(model._modules) >= 2 and isinstance(model._modules['1'], Pooling):
if len(model._modules) >= 2 and isinstance(model._modules["1"], Pooling):
try:
pooling_mode = model._modules['1'].get_pooling_mode_str().upper()
except:
raise Exception("Raised exception while calling get_pooling_mode_str()")
pooling_mode = model._modules["1"].get_pooling_mode_str().upper()
except Exception as e:
raise Exception(

Check warning on line 1057 in opensearch_py_ml/ml_models/sentencetransformermodel.py

View check run for this annotation

Codecov / codecov/patch

opensearch_py_ml/ml_models/sentencetransformermodel.py#L1056-L1057

Added lines #L1056 - L1057 were not covered by tests
f"Raised exception while calling get_pooling_mode_str(): {e}"
)

if all_config is None:
if not os.path.exists(config_json_file_path):
Expand All @@ -1062,7 +1068,7 @@ def make_model_config_json(
)
)
try:
with open(config_json_file_path) as f:
with open(config_json_file_path) as f:
if verbose:
print("reading config file from: " + config_json_file_path)
config_content = json.load(f)
Expand All @@ -1075,8 +1081,7 @@ def make_model_config_json(
". Please check the config.json ",
"file in the path.",
)



model_config_content = {
"name": model_name,
"version": version_number,
Expand All @@ -1094,7 +1099,7 @@ def make_model_config_json(
if normalize_result is not None:
model_config_content["model_config"]["normalize_result"] = normalize_result
else:
if len(model._modules) >= 3 and isinstance(model._modules['2'], Normalize):
if len(model._modules) >= 3 and isinstance(model._modules["2"], Normalize):
model_config_content["model_config"]["normalize_result"] = True

if verbose:
Expand Down

0 comments on commit f5093c9

Please sign in to comment.