batch_compute_embeddings.py
"""
Use skypilot to launch managed jobs that will run the embedding calculation for RAG.
This script is responsible for splitting the input dataset up among several workers,
then using skypilot to launch managed jobs for each worker. We use compute_embeddings.yaml
to define the managed job info.
"""
#!/usr/bin/env python3
import argparse
import sky
def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
                        total_jobs: int) -> tuple[int, int]:
    """Calculate the range of indices this job should process.

    Args:
        start_idx: Global start index.
        end_idx: Global end index.
        job_rank: Current job's rank (0-based).
        total_jobs: Total number of jobs.

    Returns:
        Tuple of [job_start_idx, job_end_idx).
    """
    total_range = end_idx - start_idx
    chunk_size = total_range // total_jobs
    remainder = total_range % total_jobs

    # Distribute the remainder across the first few jobs.
    job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
    if job_rank < remainder:
        chunk_size += 1
    job_end = job_start + chunk_size

    return job_start, job_end
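
# Illustrative check (not part of the original script): splitting 10 indices
# across 3 jobs gives contiguous, non-overlapping ranges, with the single
# leftover index going to the earliest job.
#
#     >>> [calculate_job_range(0, 10, rank, 3) for rank in range(3)]
#     [(0, 4), (4, 7), (7, 10)]
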
def main():
    parser = argparse.ArgumentParser(
        description='Launch batch RAG embedding computation jobs')
    parser.add_argument('--start-idx',
                        type=int,
                        default=0,
                        help='Global start index in dataset')
    parser.add_argument(
        '--end-idx',
        type=int,
        # This is the last index of the Reddit post dataset.
        default=109740,
        help='Global end index in dataset, not inclusive')
    parser.add_argument('--num-jobs',
                        type=int,
                        default=1,
                        help='Number of jobs to partition the work across')
    parser.add_argument('--embedding_bucket_name',
                        type=str,
                        default='sky-rag-embeddings',
                        help='Name of the bucket to store embeddings')
    args = parser.parse_args()

    # Load the task template.
    task = sky.Task.from_yaml('compute_embeddings.yaml')

    # Launch one job per partition.
    for job_rank in range(args.num_jobs):
        # Calculate the index range for this job.
        job_start, job_end = calculate_job_range(args.start_idx, args.end_idx,
                                                 job_rank, args.num_jobs)
        # Set this job's environment variables. The integer indices are cast
        # to strings since they are consumed as environment variables by the
        # task defined in compute_embeddings.yaml.
        task_copy = task.update_envs({
            'START_IDX': str(job_start),
            'END_IDX': str(job_end),
            'EMBEDDINGS_BUCKET_NAME': args.embedding_bucket_name,
        })

        sky.jobs.launch(task_copy, name=f'rag-compute-{job_start}-{job_end}')

if __name__ == '__main__':
    main()
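
# Example invocation (illustrative; the flag names come from the argparse
# setup above, and `my-embeddings-bucket` is a placeholder bucket name):
#
#     python batch_compute_embeddings.py --start-idx 0 --end-idx 109740 \
#         --num-jobs 4 --embedding_bucket_name my-embeddings-bucket
#
# This launches four managed jobs, each covering a contiguous quarter of the
# 109,740 indices.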