diff --git a/qlib/config.py b/qlib/config.py index 0478b7659f..796cc5ca6a 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -92,6 +92,8 @@ def set_conf_from_C(self, config_c): "kernels": NUM_USABLE_CPU, # How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data. "maxtasksperchild": None, + # If joblib_backend is None, use loky + "joblib_backend": "multiprocessing", "default_disk_cache": 1, # 0:skip/1:use "mem_cache_size_limit": 500, # memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar' diff --git a/qlib/data/data.py b/qlib/data/data.py index ccd35006bd..1d51807352 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -9,16 +9,15 @@ import re import abc import copy -import time import queue import bisect -import logging -import importlib -import traceback from typing import List, Union + import numpy as np import pandas as pd -from multiprocessing import Pool + +# For supporting multiprocessing in outer code, joblib is used +from joblib import delayed from .cache import H from ..config import C @@ -29,6 +28,7 @@ from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path from ..utils.resam import resam_calendar +from ..utils.paral import ParallelExt class ProviderBackendMixin: @@ -418,16 +418,7 @@ def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day """ raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") - def _uri( - self, - instruments, - fields, - start_time=None, - end_time=None, - freq="day", - disk_cache=1, - **kwargs, - ): + def _uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, **kwargs): """Get task uri, used when generating rabbitmq task in qlib_server Parameters ---------- @@ -494,51 +485,37 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq): """ 
normalize_column_names = normalize_cache_fields(column_names) - data = dict() # One process for one task, so that the memory will be freed quicker. workers = max(min(C.kernels, len(instruments_d)), 1) - if C.maxtasksperchild is None: - p = Pool(processes=workers) - else: - p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild) + # create iterator if isinstance(instruments_d, dict): - for inst, spans in instruments_d.items(): - data[inst] = p.apply_async( - DatasetProvider.expression_calculator, - args=( - inst, - start_time, - end_time, - freq, - normalize_column_names, - spans, - C, - ), - ) + it = instruments_d.items() else: - for inst in instruments_d: - data[inst] = p.apply_async( - DatasetProvider.expression_calculator, - args=( - inst, - start_time, - end_time, - freq, - normalize_column_names, - None, - C, - ), + it = zip(instruments_d, [None] * len(instruments_d)) + + inst_l = [] + task_l = [] + for inst, spans in it: + inst_l.append(inst) + task_l.append( + delayed(DatasetProvider.expression_calculator)( + inst, start_time, end_time, freq, normalize_column_names, spans, C ) + ) - p.close() - p.join() + data = dict( + zip( + inst_l, + ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l), + ) + ) new_data = dict() for inst in sorted(data.keys()): - if len(data[inst].get()) > 0: + if len(data[inst]) > 0: # NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order - new_data[inst] = data[inst].get() + new_data[inst] = data[inst] if len(new_data) > 0: data = pd.concat(new_data, names=["instrument"], sort=False) @@ -755,25 +732,11 @@ def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq start_time = cal[0] end_time = cal[-1] workers = max(min(C.kernels, len(instruments_d)), 1) - if C.maxtasksperchild is None: - p = Pool(processes=workers) - else: - p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild) - - for inst in 
instruments_d: - p.apply_async( - LocalDatasetProvider.cache_walker, - args=( - inst, - start_time, - end_time, - freq, - column_names, - ), - ) - p.close() - p.join() + ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)( + delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names) + for inst in instruments_d + ) @staticmethod def cache_walker(inst, start_time, end_time, freq, column_names): @@ -803,12 +766,7 @@ def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, fu self.conn.send_request( request_type="calendar", - request_content={ - "start_time": str(start_time), - "end_time": str(end_time), - "freq": freq, - "future": future, - }, + request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future}, msg_queue=self.queue, msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content], ) @@ -871,16 +829,7 @@ def set_conn(self, conn): self.conn = conn self.queue = queue.Queue() - def dataset( - self, - instruments, - fields, - start_time=None, - end_time=None, - freq="day", - disk_cache=0, - return_uri=False, - ): + def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, return_uri=False): if Inst.get_inst_type(instruments) == Inst.DICT: get_module_logger("data").warning( "Getting features from a dict of instruments is not recommended because the features will not be " @@ -984,15 +933,7 @@ def instruments(self, market="all", filter_pipe=None, start_time=None, end_time= def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): return Inst.list_instruments(instruments, start_time, end_time, freq, as_list) - def features( - self, - instruments, - fields, - start_time=None, - end_time=None, - freq="day", - disk_cache=None, - ): + def features(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=None): 
""" Parameters: ----------- diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index a640b04ea6..075a1adb84 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -1,8 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from joblib import Parallel, delayed import pandas as pd +from joblib import Parallel, delayed +from joblib._parallel_backends import MultiprocessingBackend + + +class ParallelExt(Parallel): + def __init__(self, *args, **kwargs): + maxtasksperchild = kwargs.pop("maxtasksperchild", None) + super(ParallelExt, self).__init__(*args, **kwargs) + if isinstance(self._backend, MultiprocessingBackend): + self._backend_args["maxtasksperchild"] = maxtasksperchild def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False): @@ -31,7 +40,7 @@ def _naive_group_apply(df): return df.groupby(axis=axis, level=level).apply(apply_func) if n_jobs != 1: - dfs = Parallel(n_jobs=n_jobs)( + dfs = ParallelExt(n_jobs=n_jobs)( delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level) ) return pd.concat(dfs, axis=axis).sort_index() diff --git a/tests/misc/test_get_multi_proc.py b/tests/misc/test_get_multi_proc.py new file mode 100644 index 0000000000..7e27781b6e --- /dev/null +++ b/tests/misc/test_get_multi_proc.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import unittest + +import qlib +from qlib.data import D +from qlib.tests import TestAutoData +from multiprocessing import Pool + + +def get_features(fields): + qlib.init(provider_uri=TestAutoData.provider_uri, expression_cache=None, dataset_cache=None, joblib_backend="loky") + return D.features(D.instruments("csi300"), fields) + + +class TestGetData(TestAutoData): + FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",") + + def test_multi_proc(self): + """ + For testing if it will raise error + """ + iter_n = 2 + pool = Pool(iter_n) + + res = [] + for _ in range(iter_n): + res.append(pool.apply_async(get_features, (self.FIELDS,), {})) + + for r in res: + print(r.get()) + + pool.close() + pool.join() + + +if __name__ == "__main__": + unittest.main()