From 5e9b0475f593000ce57053a4707b951349e5a034 Mon Sep 17 00:00:00 2001 From: Akarsh Gupta <69024958+akarshgupta7@users.noreply.github.com> Date: Thu, 26 Dec 2024 20:37:11 -0800 Subject: [PATCH] Changed the default embedding model to openai. (#1087) * Removed duplicate code in query execution. * Added nltk download logic to support string metrics. * Add ragas to imports. * Added ragas to imports. * Added string metrics for evaluation. * Removed the unwanted break statement. * Moves the import to top of file. * Moved the scorer definitions to the init function. * Refactoring. * Removed unused imports. * Moved async calls to the outer function. * Refactor to add error handling and change function names. * Added correctness score to metrics. * Changed the default embedding model to openai. * Removed unused imports. * Removed unused Huggingface embeddings method. --- apps/query-eval/queryeval/driver.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/apps/query-eval/queryeval/driver.py b/apps/query-eval/queryeval/driver.py index 6f6fa3147..21a772d5d 100644 --- a/apps/query-eval/queryeval/driver.py +++ b/apps/query-eval/queryeval/driver.py @@ -26,7 +26,10 @@ import asyncio from ragas.dataset_schema import SingleTurnSample from ragas.metrics import BleuScore, RougeScore, SemanticSimilarity -from ragas.embeddings.base import HuggingfaceEmbeddings, LangchainEmbeddingsWrapper +from ragas.embeddings.base import ( + OpenAIEmbeddings, + LangchainEmbeddingsWrapper, +) from ragas.metrics._factual_correctness import FactualCorrectness from langchain_openai.chat_models import ChatOpenAI from ragas.llms import LangchainLLMWrapper @@ -195,7 +198,7 @@ def __init__( self.rouge_scorer = RougeScore() self.semantic_similarity_scorer = SemanticSimilarity() self.semantic_similarity_scorer.embeddings = LangchainEmbeddingsWrapper( - HuggingfaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") + OpenAIEmbeddings(model="text-embedding-3-small") ) 
self.correctness_scorer = FactualCorrectness() self.correctness_scorer.llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))