Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace multi processing with joblib #477

Merged
merged 11 commits into from
Sep 13, 2021
2 changes: 2 additions & 0 deletions qlib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def set_conf_from_C(self, config_c):
"kernels": NUM_USABLE_CPU,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
# Backend passed to joblib's Parallel; if set to None, joblib's default backend ("loky") is used
"joblib_backend": "multiprocessing",
"default_disk_cache": 1, # 0:skip/1:use
"mem_cache_size_limit": 500,
# memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar'
Expand Down
125 changes: 33 additions & 92 deletions qlib/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
import re
import abc
import copy
import time
import queue
import bisect
import logging
import importlib
import traceback
from typing import List, Union

import numpy as np
import pandas as pd
from multiprocessing import Pool

# To support multiprocessing in outer code, joblib is used
from joblib import delayed

from .cache import H
from ..config import C
Expand All @@ -29,6 +28,7 @@
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils.resam import resam_calendar
from ..utils.paral import ParallelExt


class ProviderBackendMixin:
Expand Down Expand Up @@ -418,16 +418,7 @@ def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day
"""
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")

def _uri(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=1,
**kwargs,
):
def _uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, **kwargs):
"""Get task uri, used when generating rabbitmq task in qlib_server

Parameters
Expand Down Expand Up @@ -494,51 +485,37 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq):

"""
normalize_column_names = normalize_cache_fields(column_names)
data = dict()
# One process per task, so that memory is freed more quickly.
workers = max(min(C.kernels, len(instruments_d)), 1)

if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
# create iterator
if isinstance(instruments_d, dict):
for inst, spans in instruments_d.items():
data[inst] = p.apply_async(
DatasetProvider.expression_calculator,
args=(
inst,
start_time,
end_time,
freq,
normalize_column_names,
spans,
C,
),
)
it = instruments_d.items()
else:
for inst in instruments_d:
data[inst] = p.apply_async(
DatasetProvider.expression_calculator,
args=(
inst,
start_time,
end_time,
freq,
normalize_column_names,
None,
C,
),
it = zip(instruments_d, [None] * len(instruments_d))

inst_l = []
task_l = []
for inst, spans in it:
inst_l.append(inst)
task_l.append(
delayed(DatasetProvider.expression_calculator)(
inst, start_time, end_time, freq, normalize_column_names, spans, C
)
)

p.close()
p.join()
data = dict(
zip(
inst_l,
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l),
)
)

new_data = dict()
for inst in sorted(data.keys()):
if len(data[inst].get()) > 0:
if len(data[inst]) > 0:
# NOTE: requires Python >= 3.6, where dict is guaranteed to preserve insertion order
new_data[inst] = data[inst].get()
new_data[inst] = data[inst]

if len(new_data) > 0:
data = pd.concat(new_data, names=["instrument"], sort=False)
Expand Down Expand Up @@ -755,25 +732,11 @@ def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq
start_time = cal[0]
end_time = cal[-1]
workers = max(min(C.kernels, len(instruments_d)), 1)
if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)

for inst in instruments_d:
p.apply_async(
LocalDatasetProvider.cache_walker,
args=(
inst,
start_time,
end_time,
freq,
column_names,
),
)

p.close()
p.join()
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(
delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names)
for inst in instruments_d
)

@staticmethod
def cache_walker(inst, start_time, end_time, freq, column_names):
Expand Down Expand Up @@ -803,12 +766,7 @@ def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, fu

self.conn.send_request(
request_type="calendar",
request_content={
"start_time": str(start_time),
"end_time": str(end_time),
"freq": freq,
"future": future,
},
request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future},
msg_queue=self.queue,
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
)
Expand Down Expand Up @@ -871,16 +829,7 @@ def set_conn(self, conn):
self.conn = conn
self.queue = queue.Queue()

def dataset(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=0,
return_uri=False,
):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, return_uri=False):
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
"Getting features from a dict of instruments is not recommended because the features will not be "
Expand Down Expand Up @@ -984,15 +933,7 @@ def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)

def features(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=None,
):
def features(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=None):
"""
Parameters:
-----------
Expand Down
13 changes: 11 additions & 2 deletions qlib/utils/paral.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from joblib import Parallel, delayed
import pandas as pd
from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend


class ParallelExt(Parallel):
    """A ``joblib.Parallel`` variant that forwards ``maxtasksperchild``.

    When the selected backend is the multiprocessing backend, the extra
    ``maxtasksperchild`` keyword is passed through to the underlying
    ``multiprocessing.Pool`` so worker processes are recycled after that
    many tasks.  For any other backend the keyword is silently discarded.
    """

    def __init__(self, *args, **kwargs):
        # Remove our extension keyword before joblib sees it; default mirrors
        # multiprocessing.Pool's own default of None (workers never recycled).
        max_tasks = kwargs.pop("maxtasksperchild", None)
        super().__init__(*args, **kwargs)
        # NOTE(review): relies on joblib's private attributes `_backend` and
        # `_backend_args` — confirm against the pinned joblib version.
        if isinstance(self._backend, MultiprocessingBackend):
            self._backend_args["maxtasksperchild"] = max_tasks


def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
Expand Down Expand Up @@ -31,7 +40,7 @@ def _naive_group_apply(df):
return df.groupby(axis=axis, level=level).apply(apply_func)

if n_jobs != 1:
dfs = Parallel(n_jobs=n_jobs)(
dfs = ParallelExt(n_jobs=n_jobs)(
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
)
return pd.concat(dfs, axis=axis).sort_index()
Expand Down
39 changes: 39 additions & 0 deletions tests/misc/test_get_multi_proc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import qlib
from qlib.data import D
from qlib.tests import TestAutoData
from multiprocessing import Pool


def get_features(fields):
    """Initialize qlib in this (worker) process and fetch *fields* for csi300.

    Runs in a child process spawned by the test below, so qlib must be
    (re-)initialized here; caches are disabled and the loky joblib backend
    is selected explicitly.
    """
    qlib.init(
        provider_uri=TestAutoData.provider_uri,
        expression_cache=None,
        dataset_cache=None,
        joblib_backend="loky",
    )
    instruments = D.instruments("csi300")
    return D.features(instruments, fields)


class TestGetData(TestAutoData):
    # Basic daily price/volume fields used for the smoke test.
    FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")

    def test_multi_proc(self):
        """
        Smoke test: fetching features from multiple processes at once
        should not raise (qlib init + joblib usage must be fork-safe here).
        """
        n_workers = 2
        pool = Pool(n_workers)

        pending = [pool.apply_async(get_features, (self.FIELDS,), {}) for _ in range(n_workers)]
        for task in pending:
            # .get() re-raises any exception from the worker process.
            print(task.get())

        pool.close()
        pool.join()


# Allow running this test module directly with `python`, outside a test runner.
if __name__ == "__main__":
    unittest.main()