""" Use skypilot to launch managed jobs that will run the embedding calculation for RAG. This script is responsible for splitting the input dataset up among several workers, then using skypilot to launch managed jobs for each worker. We use compute_embeddings.yaml to define the managed job info. """ #!/usr/bin/env python3 import argparse import sky def calculate_job_range(start_idx: int, end_idx: int, job_rank: int, total_jobs: int) -> tuple[int, int]: """Calculate the range of indices this job should process. Args: start_idx: Global start index end_idx: Global end index job_rank: Current job's rank (0-based) total_jobs: Total number of jobs Returns: Tuple of [job_start_idx, job_end_idx) """ total_range = end_idx - start_idx chunk_size = total_range // total_jobs remainder = total_range % total_jobs # Distribute remainder across first few jobs job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder) if job_rank < remainder: chunk_size += 1 job_end = job_start + chunk_size return job_start, job_end def main(): parser = argparse.ArgumentParser( description='Launch batch RAG embedding computation jobs') parser.add_argument('--start-idx', type=int, default=0, help='Global start index in dataset') parser.add_argument( '--end-idx', type=int, # this is the last index of the reddit post dataset default=109740, help='Global end index in dataset, not inclusive') parser.add_argument('--num-jobs', type=int, default=1, help='Number of jobs to partition the work across') parser.add_argument("--embedding_bucket_name", type=str, default="sky-rag-embeddings", help="Name of the bucket to store embeddings") args = parser.parse_args() # Load the task template task = sky.Task.from_yaml('compute_embeddings.yaml') # Launch jobs for each partition for job_rank in range(args.num_jobs): # Calculate index range for this job job_start, job_end = calculate_job_range(args.start_idx, args.end_idx, job_rank, args.num_jobs) # Update environment variables for this job task_copy = task.update_envs({ 'START_IDX': job_start, 'END_IDX': job_end, 'EMBEDDINGS_BUCKET_NAME': args.embedding_bucket_name, }) sky.jobs.launch(task_copy, name=f'rag-compute-{job_start}-{job_end}') if __name__ == '__main__': main()