diff --git a/docs/source/index.rst b/docs/source/index.rst index db34dbfe46..65b4e180bc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -74,6 +74,7 @@ Supported datasets using_doctr/using_datasets using_doctr/sharing_models using_doctr/using_model_export + using_doctr/running_on_aws .. toctree:: diff --git a/docs/source/using_doctr/running_on_aws.rst b/docs/source/using_doctr/running_on_aws.rst new file mode 100644 index 0000000000..a824f354e9 --- /dev/null +++ b/docs/source/using_doctr/running_on_aws.rst @@ -0,0 +1,7 @@ +AWS Lambda +======================== + +AWS Lambda's (read more about Lambda https://aws.amazon.com/lambda/) security policy does not allow you to write anywhere outside `/tmp` directory. +There are two things you need to do to make `doctr` work on lambda: +1. Disable usage of `multiprocessing` package by setting `DOCTR_MULTIPROCESSING_DISABLE` enivronment variable to `TRUE`. You need to do this, because this package uses `/dev/shm` directory for shared memory. +2. Change directory `doctr` uses for caching models. By default it's `~/.cache/doctr` which is outside of `/tmp` on AWS Lambda'. You can do this by setting `DOCTR_CACHE_DIR` enivronment variable. diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index e93f4833aa..55665e4a26 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -92,8 +92,12 @@ def __init__( cache_subdir: Optional[str] = None, **kwargs: Any, ) -> None: + cache_dir = ( + str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr"))) + if cache_dir is None + else cache_dir + ) - cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr") if cache_dir is None else cache_dir cache_subdir = "datasets" if cache_subdir is None else cache_subdir file_name = file_name if isinstance(file_name, str) else os.path.basename(url) diff --git a/doctr/utils/data.py b/doctr/utils/data.py index 6f3633f7db..c0a4c8698d 100644 --- a/doctr/utils/data.py +++ b/doctr/utils/data.py @@ -64,13 +64,19 @@ def download_from_url( Returns: the location of the downloaded file + + Note: + You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable. """ if not isinstance(file_name, str): file_name = url.rpartition("/")[-1] - if not isinstance(cache_dir, str): - cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr") + cache_dir = ( + str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr"))) + if cache_dir is None + else cache_dir + ) # Check hash in file name if hash_prefix is None: @@ -84,8 +90,19 @@ def download_from_url( logging.info(f"Using downloaded & verified file: {file_path}") return file_path - # Create folder hierarchy - folder_path.mkdir(parents=True, exist_ok=True) + try: + # Create folder hierarchy + folder_path.mkdir(parents=True, exist_ok=True) + except OSError: + error_message = f"Failed creating cache direcotry at {folder_path}" + if os.environ.get("DOCTR_CACHE_DIR", ""): + error_message += " using path from 'DOCTR_CACHE_DIR' environment variable." + else: + error_message += ( + ". You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed." + ) + logging.error(error_message) + raise # Download the file try: print(f"Downloading {url} to {file_path}") diff --git a/doctr/utils/multithreading.py b/doctr/utils/multithreading.py index d0167f889e..682ba19c86 100644 --- a/doctr/utils/multithreading.py +++ b/doctr/utils/multithreading.py @@ -5,9 +5,12 @@ import multiprocessing as mp +import os from multiprocessing.pool import ThreadPool from typing import Any, Callable, Iterable, Iterator, Optional +from doctr.file_utils import ENV_VARS_TRUE_VALUES + __all__ = ["multithread_exec"] @@ -25,11 +28,16 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op Returns: iterator of the function's results using the iterable as inputs + + Notes: + This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. + If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance), + you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'. """ threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) # Single-thread - if threads < 2: + if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES: results = map(func, seq) # Multi-threading else: diff --git a/tests/common/test_utils_data.py b/tests/common/test_utils_data.py new file mode 100644 index 0000000000..91e675ae6a --- /dev/null +++ b/tests/common/test_utils_data.py @@ -0,0 +1,46 @@ +import os +from pathlib import PosixPath +from unittest.mock import patch + +import pytest + +from doctr.utils.data import download_from_url + + +@patch("doctr.utils.data._urlretrieve") +@patch("pathlib.Path.mkdir") +@patch.dict(os.environ, {"HOME": "/"}, clear=True) +def test_download_from_url(mkdir_mock, urlretrieve_mock): + download_from_url("test_url") + urlretrieve_mock.assert_called_with("test_url", PosixPath("/.cache/doctr/test_url")) + + +@patch.dict(os.environ, {"DOCTR_CACHE_DIR": "/test"}, clear=True) +@patch("doctr.utils.data._urlretrieve") +@patch("pathlib.Path.mkdir") +def test_download_from_url_customizing_cache_dir(mkdir_mock, urlretrieve_mock): + download_from_url("test_url") + urlretrieve_mock.assert_called_with("test_url", PosixPath("/test/test_url")) + + +@patch.dict(os.environ, {"HOME": "/"}, clear=True) +@patch("pathlib.Path.mkdir", side_effect=OSError) +@patch("logging.error") +def test_download_from_url_error_creating_directory(logging_mock, mkdir_mock): + with pytest.raises(OSError): + download_from_url("test_url") + logging_mock.assert_called_with( + "Failed creating cache direcotry at /.cache/doctr." + " You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed." + ) + + +@patch.dict(os.environ, {"HOME": "/", "DOCTR_CACHE_DIR": "/test"}, clear=True) +@patch("pathlib.Path.mkdir", side_effect=OSError) +@patch("logging.error") +def test_download_from_url_error_creating_directory_with_env_var(logging_mock, mkdir_mock): + with pytest.raises(OSError): + download_from_url("test_url") + logging_mock.assert_called_with( + "Failed creating cache direcotry at /test using path from 'DOCTR_CACHE_DIR' environment variable." + ) diff --git a/tests/common/test_utils_multithreading.py b/tests/common/test_utils_multithreading.py index 6f0117e83d..72de3a40d8 100644 --- a/tests/common/test_utils_multithreading.py +++ b/tests/common/test_utils_multithreading.py @@ -1,3 +1,7 @@ +import os +from multiprocessing.pool import ThreadPool +from unittest.mock import patch + import pytest from doctr.utils.multithreading import multithread_exec @@ -18,3 +22,10 @@ def test_multithread_exec(input_seq, func, output_seq): assert list(multithread_exec(func, input_seq)) == output_seq assert list(multithread_exec(func, input_seq, 0)) == output_seq + + +@patch.dict(os.environ, {"DOCTR_MULTIPROCESSING_DISABLE": "TRUE"}, clear=True) +def test_multithread_exec_multiprocessing_disable(): + with patch.object(ThreadPool, "map") as mock_tp_map: + multithread_exec(lambda x: x, [1, 2]) + assert not mock_tp_map.called