Skip to content

Commit

Permalink
Embedder fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NyanNyanovich committed Jan 3, 2025
1 parent 97a490d commit a24b9d5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 392 deletions.
3 changes: 3 additions & 0 deletions nyan/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __call__(self, texts: List[str]) -> torch.Tensor:
batch_embeddings = (
last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
)
elif self.pooling_method == "cls":
hidden_states = out.last_hidden_state
batch_embeddings = hidden_states[:, 0, :]
if self.normalize:
batch_embeddings = torch.nn.functional.normalize(batch_embeddings)
start_index = batch_num * self.batch_size
Expand Down
22 changes: 15 additions & 7 deletions scripts/eval_embeddings_toloka.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,27 @@ def read_jsonl(path):
model_path = sys.argv[2]


#embedder = Embedder(model_path, pooling_method="mean", normalize=True, text_prefix="query: ")
embedder = Embedder(model_path, pooling_method="cls", normalize=True)

markup = read_jsonl(markup_path)
embedder = Embedder(model_path)
y_pred, y_true = [], []

y_true, texts = [], []
for record in markup:
result = record["result"]
url1 = record["first_url"]
url2 = record["second_url"]
embeddings = embedder([record["first_text"], record["second_text"]])
distance = cosine(embeddings[0], embeddings[1])
y_pred.append(distance)
texts.extend([record["first_text"], record["second_text"]])
label = 1 - int(result == "ok")
y_true.append(label)

embeddings = embedder(texts)

y_pred = []
for e1, e2 in zip(embeddings[::2], embeddings[1::2]):
distance = cosine(e1, e2)
y_pred.append(distance)

assert len(y_true) == len(y_pred)

print("AUC:", roc_auc_score(y_true, y_pred))
precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
f1_scores = 2 * recall * precision / (recall + precision)
Expand Down
Loading

0 comments on commit a24b9d5

Please sign in to comment.