improve!: use multiprocessing in fetcher #1

Merged
merged 1 commit on Aug 13, 2024
147 changes: 88 additions & 59 deletions fast_s3/fetcher.py
@@ -1,9 +1,12 @@
import io
import multiprocessing
import warnings
from pathlib import Path
from typing import List, Union
from queue import Empty
from typing import Generator, List, Tuple, Union

import boto3

from .file import File, Status
from .transfer_manager import transfer_manager


class Fetcher:
@@ -15,70 +18,96 @@ def __init__(
aws_secret_access_key: str,
region_name: str,
bucket_name: str,
ordered=True,
buffer_size=1024,
buffer_size: int = 1000,
n_workers=32,
**transfer_manager_kwargs,
worker_batch_size=100,
callback=lambda x: x,
ordered: bool = False,
):
self.paths = paths
self.ordered = ordered
self.buffer_size = buffer_size
self.transfer_manager = transfer_manager(
endpoint_url=endpoint_url,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
n_workers=n_workers,
**transfer_manager_kwargs,
)
self.paths = multiprocessing.Manager().list(list(enumerate(paths))[::-1])
self.bucket_name = bucket_name
self.files: List[File] = []
self.current_path_index = 0

def __len__(self):
return len(self.paths)

def __iter__(self):
for _ in range(self.buffer_size):
self.queue_download_()
self.endpoint_url = endpoint_url
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.region_name = region_name
self.n_workers = n_workers
self.buffer_size = min(buffer_size, len(paths))
self.worker_batch_size = worker_batch_size
self.ordered = ordered
self.callback = callback

if self.ordered:
for _ in range(len(self)):
yield self.process_index(0)
if ordered:
# TODO: fix this issue
warnings.warn(
"buffer_size is ignored when ordered=True which can cause out of memory"
)
self.results = multiprocessing.Manager().dict()
self.result_order = multiprocessing.Manager().list(range(len(paths)))
else:
for _ in range(len(self)):
for index, file in enumerate(self.files):
if file.future.done():
break
else:
index = 0
yield self.process_index(index)
self.file_queue = multiprocessing.Queue(maxsize=buffer_size)

def process_index(self, index):
file = self.files.pop(index)
self.queue_download_()
try:
file.future.result()
return file.with_status(Status.done)
except Exception as e:
return file.with_status(Status.error, exception=e)
def _create_s3_client(self):
return boto3.client(
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
region_name=self.region_name,
)

def queue_download_(self):
if self.current_path_index < len(self):
buffer = io.BytesIO()
path = self.paths[self.current_path_index]
self.files.append(
File(
buffer=buffer,
future=self.transfer_manager.download(
fileobj=buffer,
bucket=self.bucket_name,
key=str(path),
def download_batch(self, batch: List[Tuple[int, Union[Path, str]]]):
client = self._create_s3_client()
for index, path in batch:
try:
file = File(
content=self.callback(
client.get_object(Bucket=self.bucket_name, Key=str(path))[
"Body"
].read()
),
path=path,
status=Status.succeeded,
)
)
self.current_path_index += 1
except Exception as e:
file = File(content=None, path=path, status=Status.failed, exception=e)
if self.ordered:
self.results[index] = file
else:
self.file_queue.put(file)

def _worker(self):
while len(self.paths) > 0:
batch = []
for _ in range(min(self.worker_batch_size, len(self.paths))):
try:
index, path = self.paths.pop()
batch.append((index, path))
except IndexError:
break
if len(batch) > 0:
self.download_batch(batch)

def __iter__(self) -> Generator[File, None, None]:
workers = []
for _ in range(self.n_workers):
worker_process = multiprocessing.Process(target=self._worker)
worker_process.start()
workers.append(worker_process)

if self.ordered:
for i in self.result_order:
while any(p.is_alive() for p in workers) and i not in self.results:
continue # wait for the item to appear
yield self.results.pop(i)
else:
while any(p.is_alive() for p in workers) or not self.file_queue.empty():
try:
yield self.file_queue.get(timeout=1)
except Empty:
pass

for worker in workers:
worker.join()

def close(self):
self.transfer_manager.shutdown()
def __len__(self):
return len(self.paths)
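
For orientation, a minimal usage sketch of the reworked Fetcher as it reads in this diff. The endpoint, credentials, bucket, and object keys are placeholders, and the top-level import assumes the fast_s3 package re-exports Fetcher and Status; iterating the fetcher is what starts the worker processes and yields File objects.

# Usage sketch (illustrative values; assumes fast_s3 re-exports Fetcher and Status)
from fast_s3 import Fetcher, Status

fetcher = Fetcher(
    paths=["images/0001.jpg", "images/0002.jpg"],  # placeholder object keys
    endpoint_url="https://s3.example.com",  # placeholder endpoint
    aws_access_key_id="...",
    aws_secret_access_key="...",
    region_name="us-east-1",
    bucket_name="my-bucket",  # placeholder bucket
    n_workers=8,
    worker_batch_size=10,
    ordered=False,  # unordered mode streams results through the bounded file_queue
)

# __iter__ spawns n_workers processes; each yielded File carries content, path, status, exception
for file in fetcher:
    if file.status == Status.succeeded:
        data = file.content  # raw object bytes, or the result of `callback` if one was passed
    else:
        print(f"failed to fetch {file.path}: {file.exception}")
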
14 changes: 5 additions & 9 deletions fast_s3/file.py
@@ -1,23 +1,19 @@
import io
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union

from pydantic import BaseModel
from s3transfer.futures import TransferFuture


class Status(str, Enum):
pending = "pending"
done = "done"
error = "error"
succeeded = "succeeded"
failed = "failed"


class File(BaseModel, arbitrary_types_allowed=True):
buffer: io.BytesIO
future: TransferFuture
content: Any
path: Union[str, Path]
status: Status = Status.pending
status: Status
exception: Optional[Exception] = None

def with_status(self, status: Status, exception: Optional[Exception] = None):
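
A small sketch of the revised File model and the renamed Status values from this diff, with illustrative data; the import path follows the module shown above.

# File/Status sketch (illustrative values; import path as in this diff)
from pathlib import Path

from fast_s3.file import File, Status

ok = File(content=b"...", path=Path("images/0001.jpg"), status=Status.succeeded)
bad = File(content=None, path="missing.jpg", status=Status.failed, exception=KeyError("NoSuchKey"))

assert ok.status is Status.succeeded
assert bad.exception is not None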