Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes needed to be able to use doctr on AWS Lambda #1017

Merged
merged 14 commits into from
Sep 1, 2022
Merged
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Supported datasets
using_doctr/using_datasets
using_doctr/sharing_models
using_doctr/using_model_export
using_doctr/running_on_aws


.. toctree::
Expand Down
7 changes: 7 additions & 0 deletions docs/source/using_doctr/running_on_aws.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
AWS Lambda
mtvch marked this conversation as resolved.
Show resolved Hide resolved
========================

The security policy of AWS Lambda (read more about Lambda at https://aws.amazon.com/lambda/) does not allow you to write anywhere outside the `/tmp` directory.
There are two things you need to do to make `doctr` work on Lambda:

1. Disable usage of the `multiprocessing` package by setting the `DOCTR_MULTIPROCESSING_DISABLE` environment variable to `TRUE`. This is required because the package uses the `/dev/shm` directory for shared memory.
2. Change the directory `doctr` uses for caching models. By default it is `~/.cache/doctr`, which is outside of `/tmp` on AWS Lambda. You can do this by setting the `DOCTR_CACHE_DIR` environment variable.
6 changes: 5 additions & 1 deletion doctr/datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def __init__(
cache_subdir: Optional[str] = None,
**kwargs: Any,
) -> None:
cache_dir = (
str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
if cache_dir is None
else cache_dir
)

cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr") if cache_dir is None else cache_dir
cache_subdir = "datasets" if cache_subdir is None else cache_subdir

file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
Expand Down
25 changes: 21 additions & 4 deletions doctr/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,19 @@ def download_from_url(

Returns:
the location of the downloaded file

Note:
You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable.
"""

if not isinstance(file_name, str):
file_name = url.rpartition("/")[-1]

if not isinstance(cache_dir, str):
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr")
cache_dir = (
str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
if cache_dir is None
else cache_dir
)

# Check hash in file name
if hash_prefix is None:
Expand All @@ -84,8 +90,19 @@ def download_from_url(
logging.info(f"Using downloaded & verified file: {file_path}")
return file_path

# Create folder hierarchy
folder_path.mkdir(parents=True, exist_ok=True)
try:
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
# Create folder hierarchy
folder_path.mkdir(parents=True, exist_ok=True)
except OSError:
error_message = f"Failed creating cache direcotry at {folder_path}"
if os.environ.get("DOCTR_CACHE_DIR", ""):
error_message += " using path from 'DOCTR_CACHE_DIR' environment variable."
else:
error_message += (
". You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed."
)
logging.error(error_message)
raise
# Download the file
try:
print(f"Downloading {url} to {file_path}")
Expand Down
10 changes: 9 additions & 1 deletion doctr/utils/multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@


import multiprocessing as mp
import os
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Iterable, Iterator, Optional

from doctr.file_utils import ENV_VARS_TRUE_VALUES

__all__ = ["multithread_exec"]


Expand All @@ -25,11 +28,16 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op

Returns:
iterator of the function's results using the iterable as inputs

Notes:
This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory.
If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance),
you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'.
"""

threads = threads if isinstance(threads, int) else min(16, mp.cpu_count())
# Single-thread
if threads < 2:
if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES:
results = map(func, seq)
# Multi-threading
else:
Expand Down
46 changes: 46 additions & 0 deletions tests/common/test_utils_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from pathlib import PosixPath
from unittest.mock import patch

import pytest

from doctr.utils.data import download_from_url


@patch("doctr.utils.data._urlretrieve")
@patch("pathlib.Path.mkdir")
@patch.dict(os.environ, {"HOME": "/"}, clear=True)
def test_download_from_url(mkdir_mock, urlretrieve_mock):
    """Check the default download target when only `HOME` is set.

    The environment is reduced to `HOME=/`, and both directory creation and
    the retrieval helper are mocked, so nothing touches disk or network.
    """
    source_url = "test_url"
    expected_target = PosixPath("/.cache/doctr/test_url")

    download_from_url(source_url)

    # The retrieval helper must receive the URL and the default cache location.
    urlretrieve_mock.assert_called_with(source_url, expected_target)


@patch.dict(os.environ, {"DOCTR_CACHE_DIR": "/test"}, clear=True)
@patch("doctr.utils.data._urlretrieve")
@patch("pathlib.Path.mkdir")
def test_download_from_url_customizing_cache_dir(mkdir_mock, urlretrieve_mock):
    """Check that `DOCTR_CACHE_DIR` redirects the download target.

    The environment contains only `DOCTR_CACHE_DIR=/test`; directory creation
    and the retrieval helper are mocked.
    """
    source_url = "test_url"
    expected_target = PosixPath("/test/test_url")

    download_from_url(source_url)

    # The file must be placed directly inside the directory from the env var.
    urlretrieve_mock.assert_called_with(source_url, expected_target)


@patch.dict(os.environ, {"HOME": "/"}, clear=True)
@patch("pathlib.Path.mkdir", side_effect=OSError)
@patch("logging.error")
def test_download_from_url_error_creating_directory(logging_mock, mkdir_mock):
    """Check the logged hint when the default cache directory cannot be created.

    `Path.mkdir` is forced to raise `OSError` and `DOCTR_CACHE_DIR` is unset,
    so the logged message is expected to mention the environment variable as a
    way to change the default location.
    """
    expected_message = (
        "Failed creating cache direcotry at /.cache/doctr."
        " You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed."
    )

    # The original OSError must still propagate to the caller.
    with pytest.raises(OSError):
        download_from_url("test_url")

    logging_mock.assert_called_with(expected_message)


@patch.dict(os.environ, {"HOME": "/", "DOCTR_CACHE_DIR": "/test"}, clear=True)
@patch("pathlib.Path.mkdir", side_effect=OSError)
@patch("logging.error")
def test_download_from_url_error_creating_directory_with_env_var(logging_mock, mkdir_mock):
    """Check the logged message when the `DOCTR_CACHE_DIR` directory cannot be created.

    `Path.mkdir` is forced to raise `OSError` while `DOCTR_CACHE_DIR=/test` is
    set, so the logged message is expected to name the environment variable as
    the source of the path.
    """
    expected_message = (
        "Failed creating cache direcotry at /test using path from 'DOCTR_CACHE_DIR' environment variable."
    )

    # The original OSError must still propagate to the caller.
    with pytest.raises(OSError):
        download_from_url("test_url")

    logging_mock.assert_called_with(expected_message)
11 changes: 11 additions & 0 deletions tests/common/test_utils_multithreading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
from multiprocessing.pool import ThreadPool
from unittest.mock import patch

import pytest

from doctr.utils.multithreading import multithread_exec
Expand All @@ -18,3 +22,10 @@
def test_multithread_exec(input_seq, func, output_seq):
assert list(multithread_exec(func, input_seq)) == output_seq
assert list(multithread_exec(func, input_seq, 0)) == output_seq


@patch.dict(os.environ, {"DOCTR_MULTIPROCESSING_DISABLE": "TRUE"}, clear=True)
def test_multithread_exec_multiprocessing_disable():
    """Check that `ThreadPool.map` is not used when multiprocessing is disabled.

    The environment contains only `DOCTR_MULTIPROCESSING_DISABLE=TRUE`, and
    `ThreadPool.map` is replaced by a mock for the duration of the call.
    """
    sample_inputs = [1, 2]

    with patch.object(ThreadPool, "map") as mock_tp_map:
        multithread_exec(lambda x: x, sample_inputs)
        pool_map_used = mock_tp_map.called

    assert not pool_map_used