Skip to content

Commit

Permalink
enable autorag to automatically generate the evaluation dataset and e…
Browse files Browse the repository at this point in the history
…valuate the RAG system (#36)

Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
  • Loading branch information
XuhuiRen authored Jun 28, 2024
1 parent 3f57e69 commit b24bff5
Show file tree
Hide file tree
Showing 11 changed files with 858 additions and 0 deletions.
14 changes: 14 additions & 0 deletions evals/benchmark/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

ground_truth_file: ./ground_truth.jsonl
use_openai_key: False
search_type: [similarity, mmr]
k: [1]
fetch_k: [5]
score_threshold: [0.3]
top_n: [1]
temperature: [0.01]
top_k: [1, 3, 5]
top_p: [0.1]
repetition_penalty: [1.0]
9 changes: 9 additions & 0 deletions evals/benchmark/ground_truth.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{"question": "What are Nike's primary business activities as of the fiscal year ended May 31, 2023?", "context": ["Our principal business activity is the design, development and worldwide marketing and selling of athletic footwear, apparel, equipment, accessories and services."], "ground_truth": "Nike's primary business activities include the design, development, worldwide marketing, and selling of athletic footwear, apparel, equipment, accessories, and services."}
{"question": "How does Nike categorize its product offerings?", "context": ["Our NIKE Brand product offerings are aligned around our consumer construct focused on Men's, Women's and Kids'. We also design products specifically for the Jordan Brand and Converse."], "ground_truth": "Nike categorizes its product offerings around consumer constructs focused on Men's, Women's, and Kids'. They also design products specifically for the Jordan Brand and Converse."}
{"question": "What was Nike's total revenue from non-U.S. operations for fiscal year 2023?", "context": ["For fiscal 2023, non-U.S. NIKE Brand and Converse sales accounted for approximately 57% of total revenues."], "ground_truth": "For fiscal year 2023, non-U.S. operations accounted for approximately 57% of Nike's total revenues."}
{"question": "How does Nike ensure the innovation and quality of its products?", "context": ["We place considerable emphasis on innovation and high-quality construction in the development and manufacturing of our products."], "ground_truth": "Nike emphasizes technical innovation and high-quality construction in the development and manufacturing of its products. They employ specialists in various fields and utilize research committees and advisory boards comprising athletes and other experts."}
{"question": "What are the risks associated with Nike's international operations?", "context": ["Our international operations and sources of supply are subject to the usual risks of doing business abroad, such as the implementation of, or potential changes in, foreign and domestic trade policies."], "ground_truth": "Nike's international operations are subject to risks such as changes in foreign and domestic trade policies, increases in import duties, and political and economic instability, among others."}
{"question": "How does Nike view the role of intellectual property in its business strategy?", "context": ["We believe that our intellectual property rights are important to our brand, our success and our competitive position."], "ground_truth": "Nike considers its intellectual property rights critical to its brand, success, and competitive position. They actively pursue protection of these rights and vigorously defend them against third-party infringement."}
{"question": "What is Nike's approach to diversity, equity, and inclusion within its workforce?", "context": ["Diversity, equity and inclusion ('DE&I') is a strategic priority for NIKE and we are committed to having an increasingly diverse team and culture."], "ground_truth": "Nike prioritizes fostering an inclusive and accessible workplace, aiming to expand representation across all dimensions of diversity. They have specific goals for increasing representation among women globally and U.S. racial and ethnic minorities by fiscal 2025."}
{"question": "How does Nike address the environmental impact of its operations?", "context": ["Our mission is aligned with our deep commitment to maintaining an environment where all NIKE employees have the opportunity to reach their full potential."], "ground_truth": "Nike is focused on sustainability, aiming to create products more sustainably, such as through using environmentally friendly materials and processes, and investing in global communities to promote a more equitable future."}
{"question": "What financial impact did Nike's U.S. operations have in fiscal year 2023?", "context": ["For fiscal 2023, NIKE Brand and Converse sales in the United States accounted for approximately 43% of total revenues."], "ground_truth": "Nike Brand and Converse sales in the United States accounted for approximately 43% of total revenues for fiscal 2023."}
92 changes: 92 additions & 0 deletions evals/benchmark/ragas_benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set -x

function main {

init_params "$@"
run_benchmark

}

# init params
function init_params {
search_type="similarity"
k=1
fetch_k=5
score_threshold=0.3
top_n=1
max_chuck_size=256
temperature=0.01
top_k=1
top_p=0.1
repetition_penalty=1.0

for var in "$@"
do
case $var in
--ground_truth_file=*)
ground_truth_file=$(echo $var |cut -f2 -d=)
;;
--use_openai_key=*)
use_openai_key=$(echo $var |cut -f2 -d=)
;;
--search_type=*)
search_type=$(echo $var |cut -f2 -d=)
;;
--k=*)
k=$(echo $var |cut -f2 -d=)
;;
--fetch_k=*)
fetch_k=$(echo $var |cut -f2 -d=)
;;
--score_threshold=*)
score_threshold=$(echo ${var} |cut -f2 -d=)
;;
--top_n=*)
top_n=$(echo ${var} |cut -f2 -d=)
;;
--temperature=*)
temperature=$(echo $var |cut -f2 -d=)
;;
--top_k=*)
top_k=$(echo $var |cut -f2 -d=)
;;
--top_p=*)
top_p=$(echo $var |cut -f2 -d=)
;;
--repetition_penalty=*)
repetition_penalty=$(echo ${var} |cut -f2 -d=)
;;
esac
done

}

# run_benchmark
function run_benchmark {

if [[ ${use_openai_key} == True ]]; then
use_openai_key="--use_openai_key"
else
use_openai_key=""
fi

python -u ../evaluation/autorag/evaluation/ragas_evaluation_benchmark.py \
--ground_truth_file ${ground_truth_file} \
--input_path ${input_path} \
--use_openai_key ${use_openai_key} \
--search_type ${search_type} \
--k ${k} \
--fetch_k ${fetch_k} \
--score_threshold ${score_threshold} \
--top_n ${top_n} \
--temperature ${temperature} \
--top_k ${top_k} \
--top_p ${top_p} \
--repetition_penalty ${repetition_penalty}
}

main "$@"
71 changes: 71 additions & 0 deletions evals/benchmark/run_rag_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import subprocess

import jsonlines
import yaml


def read_yaml_file(file_path):
with open(file_path, "r") as stream:
try:
return yaml.safe_load(stream)
except yaml.YAMLError as exc:
print(exc)


if __name__ == "__main__":
if os.path.exists("result_ragas.jsonl"):
os.remove("result_ragas.jsonl")
script_path = "ragas_benchmark.sh"

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str)
args = parser.parse_args()

data = read_yaml_file(args.config)
data = {k: [str(item) for item in v] if isinstance(v, list) else str(v) for k, v in data.items()}

ground_truth_file = data["ground_truth_file"]
use_openai_key = data["use_openai_key"]
search_types = data["search_type"]
ks = data["k"]
fetch_ks = data["fetch_k"]
score_thresholds = data["score_threshold"]
top_ns = data["top_n"]
temperatures = data["temperature"]
top_ks = data["top_k"]
top_ps = data["top_p"]
repetition_penaltys = data["repetition_penalty"]

for search_type in search_types:
for k in ks:
for fetch_k in fetch_ks:
for score_threshold in score_thresholds:
for top_n in top_ns:
for temperature in temperatures:
for top_k in top_ks:
for top_p in top_ps:
for repetition_penalty in repetition_penaltys:
subprocess.run(
[
"bash",
script_path,
"--ground_truth_file=" + ground_truth_file,
"--use_openai_key=" + str(use_openai_key),
"--search_type=" + search_type,
"--k=" + k,
"--fetch_k=" + fetch_k,
"--score_threshold=" + score_threshold,
"--top_n=" + top_n,
"--temperature=" + temperature,
"--top_k=" + top_k,
"--top_p=" + top_p,
"--repetition_penalty=" + repetition_penalty,
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
62 changes: 62 additions & 0 deletions evals/evaluation/autorag/data_generation/gen_answer_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
import re

import jsonlines
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer # pylint: disable=E0401

from .prompt_dict import TRUTHGENERATE_PROMPT


def load_documents(document_file_jsonl_path):
document_list = []
with open(document_file_jsonl_path) as file:
for stu in jsonlines.Reader(file):
passages = [stu["query"], stu["pos"][0]]
document_list.append(passages)
return document_list


def answer_generate(llm, base_dir, file_json_path, generation_config):
documents = load_documents(base_dir)

try:
if isinstance(llm, str):
use_endpoint = False
tokenizer = AutoTokenizer.from_pretrained(llm)
llm = AutoModelForCausalLM.from_pretrained(llm, device_map="auto", torch_dtype=torch.float16)
llm.eval()
else:
use_endpoint = True
llm = llm
except:
print("Please check the setting llm!")

for question, context in enumerate(documents):
if context and question:
prompt = TRUTHGENERATE_PROMPT.format(question=question, context=context)
if not use_endpoint:
with torch.no_grad():
model_input = tokenizer(prompt, return_tensors="pt")
res = llm.generate(**model_input, generation_config=generation_config)[0]
res = tokenizer.decode(res, skip_special_tokens=True)
else:
res = llm.invoke(prompt)

res = res[res.find("Generated ground_truth:") :]
res = re.sub("Generated ground_truth:", "", res)
res = re.sub("---", "", res)

result_str = res.replace("#", " ").replace(r"\t", " ").replace("\n", " ").replace("\n\n", " ").strip()

if result_str and not result_str.isspace():
data = {
"question": question,
"context": [context],
"ground_truth": result_str,
}
with jsonlines.open(file_json_path, "a") as file_json:
file_json.write(data)
105 changes: 105 additions & 0 deletions evals/evaluation/autorag/data_generation/gen_eval_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

from comps.dataprep.utils import document_loader
from langchain_community.llms import HuggingFaceEndpoint
from sentence_transformers import SentenceTransformer
from transformers import GenerationConfig

from .gen_answer_dataset import answer_generate
from .gen_hard_negative import mine_hard_negatives
from .llm_generate_raw_data import raw_data_generation
from .utils import similarity_check

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--llm", type=str)
parser.add_argument("--embedding_model", type=str)
parser.add_argument("--input", type=str)
parser.add_argument("--output", type=str, default="./data")

parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--top_k", type=int, default=40)
parser.add_argument("--repetition_penalty", type=float, default=2.0)
parser.add_argument("--max_new_tokens", type=int, default=48)
parser.add_argument("--do_sample", type=bool, default=True)
parser.add_argument("--num_beams", type=int, default=2)
parser.add_argument("--num_return_sequences", type=int, default=2)
parser.add_argument("--use_cache", type=bool, default=True)

parser.add_argument("--range_for_sampling", type=str, default="2-10")
parser.add_argument("--negative_number", type=int, default=5)
parser.add_argument("--use_gpu_for_searching", type=bool, default=False)

parser.add_argument("--similarity_threshold", type=float, default=0.6)

args = parser.parse_args()

llm_model = args.llm
input_path = args.input
output = args.output

generation_config = GenerationConfig(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
num_beams=args.num_beams,
num_return_sequences=args.num_return_sequences,
use_cache=args.use_cache,
)

embedding_model = SentenceTransformer(args.embedding_model)

try:
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
max_new_tokens=512,
top_k=args.top_k,
top_p=args.top_p,
typical_p=args.typical_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
streaming=args.streaming,
timeout=600,
)
except:
print("Did not find the llm endpoint service, load model from huggingface hub as instead.")

try:
if not os.path.exists(output):
os.mkdir(output)
else:
if os.path.exists(os.path.join(output, "raw.jsonl")):
os.remove(os.path.join(output, "raw.jsonl"))
if os.path.exists(os.path.join(output, "minedHN.jsonl")):
os.remove(os.path.join(output, "minedHN.jsonl"))
if os.path.exists(os.path.join(output, "minedHN_split.jsonl")):
os.remove(os.path.join(output, "minedHN_split.jsonl"))
except:
pass

output_path = os.path.join(output, "raw_query.jsonl")
raw_data_generation(llm, input_path, output_path, generation_config)

output_hn_path = os.path.join(output, "query_doc.jsonl")
mine_hard_negatives(
embedding_model,
output_path,
output_hn_path,
args.range_for_sampling,
args.negative_number,
)

output_json_split_path = os.path.join(output, "query_doc_cleaned.jsonl")
similarity_check(output_hn_path, output_json_split_path, embedding_model, args.similarity_threshold)

output_answer_path = os.path.join(output, "answer.jsonl")
answer_generate(llm, input, output, generation_config)
Loading

0 comments on commit b24bff5

Please sign in to comment.