Skip to content

Commit

Permalink
Update Downloader utility class to use static dask cluster (#1161)
Browse files Browse the repository at this point in the history
- `file_to_df_loader` creates a new `Downloader` object for every file-list batch. Each downloader creates its own dask cluster, which makes processing very slow and eventually causes the DFP module pipelines to fail.
- Update the `Downloader` utility class to use a static dask cluster that is shared by all `Downloader` instances.

Fixes #1146

Authors:
  - Eli Fajardo (https://github.com/efajardo-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1161
  • Loading branch information
efajardo-nv authored Sep 22, 2023
1 parent a88276e commit f8b774a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ To run the DFP pipelines with the example datasets within the container, run:
--start_time "2022-08-01" \
--duration "60d" \
--train_users generic \
--input_file "./control_messages/duo_payload_load_training_inference.json"
--input_file "./control_messages/duo_payload_load_train_inference.json"
```

* Azure Training Pipeline
Expand Down
32 changes: 19 additions & 13 deletions morpheus/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import multiprocessing as mp
import os
import threading
import typing
from enum import Enum

Expand Down Expand Up @@ -62,12 +63,16 @@ class Downloader:
The heartbeat interval to use when using dask or dask_thread.
"""

# This cluster is shared by all Downloader instances that use the dask download method.
_dask_cluster = None

_mutex = threading.RLock()

def __init__(self,
download_method: typing.Union[DownloadMethods, str] = DownloadMethods.DASK_THREAD,
dask_heartbeat_interval: str = "30s"):

self._merlin_distributed = None
self._dask_cluster = None
self._dask_heartbeat_interval = dask_heartbeat_interval

download_method = os.environ.get("MORPHEUS_FILE_DOWNLOAD_TYPE", download_method)
Expand Down Expand Up @@ -96,23 +101,21 @@ def get_dask_cluster(self):
dask_cuda.LocalCUDACluster
"""

if self._dask_cluster is None:
import dask
import dask.distributed
import dask_cuda.utils
with Downloader._mutex:
if Downloader._dask_cluster is None:
import dask_cuda.utils

logger.debug("Creating dask cluster...")
logger.debug("Creating dask cluster...")

# Up the heartbeat interval which can get violated with long download times
dask.config.set({"distributed.client.heartbeat": self._dask_heartbeat_interval})
n_workers = dask_cuda.utils.get_n_gpus()
threads_per_worker = mp.cpu_count() // n_workers
n_workers = dask_cuda.utils.get_n_gpus()
threads_per_worker = mp.cpu_count() // n_workers

self._dask_cluster = dask_cuda.LocalCUDACluster(n_workers=n_workers, threads_per_worker=threads_per_worker)
Downloader._dask_cluster = dask_cuda.LocalCUDACluster(n_workers=n_workers,
threads_per_worker=threads_per_worker)

logger.debug("Creating dask cluster... Done. Dashboard: %s", self._dask_cluster.dashboard_link)
logger.debug("Creating dask cluster... Done. Dashboard: %s", Downloader._dask_cluster.dashboard_link)

return self._dask_cluster
return Downloader._dask_cluster

def get_dask_client(self):
"""
Expand All @@ -124,6 +127,9 @@ def get_dask_client(self):
"""
import dask.distributed

# Up the heartbeat interval which can get violated with long download times
dask.config.set({"distributed.client.heartbeat": self._dask_heartbeat_interval})

if (self._merlin_distributed is None):
self._merlin_distributed = Distributed(client=dask.distributed.Client(self.get_dask_cluster()))

Expand Down
3 changes: 3 additions & 0 deletions tests/examples/digital_fingerprinting/test_dfp_file_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pandas as pd
import pytest

import morpheus.utils.downloader
from _utils import TEST_DIRS
from _utils.dataset_manager import DatasetManager
from morpheus.common import FileTypes
Expand Down Expand Up @@ -99,9 +100,11 @@ def test_constructor(config: Config):


# pylint: disable=redefined-outer-name
@pytest.mark.reload_modules(morpheus.utils.downloader)
@pytest.mark.usefixtures("restore_environ")
@pytest.mark.parametrize('dl_type', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"])
@pytest.mark.parametrize('use_convert_to_dataframe', [True, False])
@pytest.mark.usefixtures("reload_modules")
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.distributed.Client')
@mock.patch('dask_cuda.LocalCUDACluster')
Expand Down
29 changes: 17 additions & 12 deletions tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import fsspec
import pytest

import morpheus.utils.downloader
from _utils import TEST_DIRS
from _utils import import_or_skip
from morpheus.utils.downloader import DOWNLOAD_METHODS_MAP
Expand Down Expand Up @@ -87,29 +88,32 @@ def test_constructor_invalid_dltype(use_env: bool):
Downloader(**kwargs)


@pytest.mark.usefixtures("restore_environ")
@pytest.mark.reload_modules(morpheus.utils.downloader)
@pytest.mark.parametrize("dl_method", ["dask", "dask_thread"])
@mock.patch('dask.config')
@pytest.mark.usefixtures("reload_modules")
@mock.patch('dask_cuda.LocalCUDACluster')
def test_get_dask_cluster(mock_dask_cluster: mock.MagicMock, mock_dask_config: mock.MagicMock, dl_method: str):
def test_get_dask_cluster(mock_dask_cluster: mock.MagicMock, dl_method: str):
mock_dask_cluster.return_value = mock_dask_cluster
downloader = Downloader(download_method=dl_method)
assert downloader.get_dask_cluster() is mock_dask_cluster
downloader1 = Downloader(download_method=dl_method)
assert downloader1.get_dask_cluster() is mock_dask_cluster

# create another downloader then assert that cluster was only created once
downloader2 = Downloader(download_method=dl_method)
downloader2.get_dask_cluster()
assert downloader2.get_dask_cluster() is mock_dask_cluster

mock_dask_config.set.assert_called_once()
mock_dask_cluster.assert_called_once()


@mock.patch('dask.config')
@mock.patch('dask_cuda.LocalCUDACluster')
@pytest.mark.reload_modules(morpheus.utils.downloader)
@pytest.mark.parametrize('dl_method', ["dask", "dask_thread"])
def test_close(mock_dask_cluster: mock.MagicMock, mock_dask_config: mock.MagicMock, dl_method: str):
@pytest.mark.usefixtures("reload_modules")
@mock.patch('dask_cuda.LocalCUDACluster')
def test_close(mock_dask_cluster: mock.MagicMock, dl_method: str):
mock_dask_cluster.return_value = mock_dask_cluster
downloader = Downloader(download_method=dl_method)
assert downloader.get_dask_cluster() is mock_dask_cluster

mock_dask_config.set.assert_called_once()

mock_dask_cluster.close.assert_not_called()
downloader.close()

Expand All @@ -127,7 +131,8 @@ def test_close_noop(mock_dask_cluster: mock.MagicMock, dl_method: str):
mock_dask_cluster.close.assert_not_called()


@pytest.mark.usefixtures("restore_environ")
@pytest.mark.reload_modules(morpheus.utils.downloader)
@pytest.mark.usefixtures("reload_modules", "restore_environ")
@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"])
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.config')
Expand Down

0 comments on commit f8b774a

Please sign in to comment.