Skip to content

Commit

Permalink
Merge pull request #20 from aalok-sathe/aalok-sathe/issue19
Browse files Browse the repository at this point in the history
Fix #19: Conflating causal LM and "gpt" model class (and masked LM and 'bert')
  • Loading branch information
aalok-sathe authored Nov 29, 2023
2 parents 24f2fca + b4a8227 commit a159a1a
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions surprisal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,30 +476,52 @@ def from_pretrained(
"""

model_class = model_class or ""
model_string = model_class.lower() + " " + model_id_or_path.lower()

if (
"gpt3" in model_class.lower() + " " + model_id_or_path.lower()
"gpt3" in model_string
or "openai" in model_string
or model_id_or_path.lower() in openai_models_list
):
if "gpt3" in model_class:
logger.warn(
'DEPRECATION WARNING: please use "openai" as the model class. '
'using "gpt3" as the model class will be deprecated in the future.'
)

return OpenAIModel(model_id_or_path, **kwargs)
elif "gpt" in model_class.lower() + " " + model_id_or_path.lower():

elif "gpt" in model_string or "causal" in model_string:
if "gpt" in model_class:
logger.warn(
'DEPRECATION WARNING: please use "causal" as the model class. '
'using "gpt" as the model class will be deprecated in the future.'
)
hfm = CausalHuggingFaceModel(model_id_or_path, **kwargs)
# for GPT-like tokenizers, pad token is not set as it is generally
# inconsequential for autoregressive models
hfm.tokenizer.pad_token = hfm.tokenizer.eos_token
return hfm
elif "bert" in model_class.lower() + " " + model_id_or_path.lower():

elif "bert" in model_string or "masked" in model_string:
if "bert" in model_class:
logger.warn(
'DEPRECATION WARNING: please use "masked" as the model class. '
'using "bert" as the model class will be deprecated in the future.'
)
return MaskedHuggingFaceModel(model_id_or_path)
# in order to support the bigscience bloom-petals
# distributed model, we make a special case.
elif "petals" in model_class.lower() + " " + model_id_or_path.lower():
elif "petals" in model_string:
hfm = DistributedBloomModel(model_id_or_path)
# for GPT-like tokenizers, pad token is not set as it is generally
# inconsequential for autoregressive models
hfm.tokenizer.pad_token = hfm.tokenizer.eos_token
return hfm
elif (
"kenlm" in model_class.lower()
"kenlm" in model_string
or model_id_or_path.endswith(".arpa")
# this may not be a great idea since pytorch models can also end in .bin
or model_id_or_path.endswith(".bin")
):
return KenLMModel(model_id_or_path, **kwargs)
Expand Down

0 comments on commit a159a1a

Please sign in to comment.