Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass kwargs allowing model_kwargs and tokenizer_kwargs to be passed. … #44

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ ENV/
env.bak/
venv.bak/
.envrc
.conda/

# mkdocs documentation
/site
Expand Down
6 changes: 5 additions & 1 deletion rerankers/models/colbert_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
verbose: int = 1,
query_token: str = "[unused0]",
document_token: str = "[unused1]",
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, self.verbose)
Expand All @@ -230,10 +231,13 @@ def __init__(
f"Loading model {model_name}, this might take a while...",
self.verbose,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = (
ColBERTModel.from_pretrained(
model_name,
**model_kwargs
)
.to(self.device)
.to(self.dtype)
Expand Down
15 changes: 12 additions & 3 deletions rerankers/models/llm_layerwise_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
cutoff_layers: Optional[List[int]] = None,
compress_ratio: Optional[int] = None,
compress_layer: Optional[List[int]] = None,
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
Expand All @@ -50,16 +51,24 @@ def __init__(
)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)

tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
tokenizer_trust_remote_code = tokenizer_kwargs.pop("trust_remote_code", True)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
model_name_or_path,
trust_remote_code=tokenizer_trust_remote_code,
**tokenizer_kwargs,
)
self.max_sequence_length = max_sequence_length
self.tokenizer.model_max_length = self.max_sequence_length
self.tokenizer.padding_side = "right"
model_kwargs = kwargs.get("model_kwargs", {})
model_trust_remote_code = model_kwargs.pop("trust_remote_code", True)

self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype
model_name_or_path,
trust_remote_code=model_trust_remote_code,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
self.model.eval()

Expand Down
16 changes: 11 additions & 5 deletions rerankers/models/t5ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from rerankers.documents import Document


import torch

from rerankers.results import RankedResults, Result
from rerankers.utils import (
vprint,
Expand Down Expand Up @@ -89,7 +87,8 @@ def __init__(
token_false: str = "auto",
token_true: str = "auto",
return_logits: bool = False,
inputs_template: str = "Query: {query} Document: {text} Relevant:"
inputs_template: str = "Query: {query} Document: {text} Relevant:",
**kwargs,
):
"""
Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py
Expand All @@ -113,11 +112,18 @@ def __init__(
)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path, torch_dtype=self.dtype
model_name_or_path,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
**tokenizer_kwargs,
)

token_false, token_true = _get_output_tokens(
model_name_or_path=model_name_or_path,
Expand Down
12 changes: 10 additions & 2 deletions rerankers/models/transformer_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,26 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
batch_size: int = 16,
verbose: int = 1,
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
self.dtype = get_dtype(dtype, self.device, self.verbose)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, torch_dtype=self.dtype
model_name_or_path,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
vprint(f"Loaded model {model_name_or_path}", self.verbose)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
**tokenizer_kwargs,
)
self.ranking_type = "pointwise"
self.batch_size = batch_size

Expand Down
3 changes: 2 additions & 1 deletion rerankers/results.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, Optional, List
from typing import List, Optional, Union

from pydantic import BaseModel, validator

from rerankers.documents import Document
Expand Down
9 changes: 7 additions & 2 deletions tests/test_crossenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ def test_transformer_ranker_rank(mock_rank):
expected_results = RankedResults(
results=[
Result(
document=Document(id=1, text="Gone with the wind is an all-time classic"),
document=Document(
doc_id=1, text="Gone with the wind is an all-time classic"
),
score=1.6181640625,
rank=1,
),
Result(
document=Document(id=0, text="Gone with the wind is a masterclass in bad storytelling."),
document=Document(
doc_id=0,
text="Gone with the wind is a masterclass in bad storytelling.",
),
score=0.88427734375,
rank=2,
),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
def test_ranked_results_functions():
results = RankedResults(
results=[
Result(document=Document(id=0, text="Doc 0"), score=0.9, rank=2),
Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1),
Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2),
Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1),
],
query="Test Query",
has_scores=True,
Expand All @@ -20,7 +20,7 @@ def test_ranked_results_functions():


def test_result_attributes():
result = Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1)
result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1)
assert result.doc_id == 1
assert result.text == "Doc 1"
assert result.score == 0.95
Expand Down