diff --git a/docs/_static/img/online_serving.png b/docs/_static/img/online_serving.png
new file mode 100644
index 0000000000..8647ebe773
Binary files /dev/null and b/docs/_static/img/online_serving.png differ
diff --git a/docs/advanced/serial.rst b/docs/advanced/serial.rst
index 8c0f837467..e0840069bf 100644
--- a/docs/advanced/serial.rst
+++ b/docs/advanced/serial.rst
@@ -14,6 +14,9 @@ Serializable Class
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
+However, users can use the ``config`` method or override the ``default_dump_all`` attribute to change this behavior.
+
+Users can also override the ``pickle_backend`` attribute to choose a pickle backend. The supported values are "pickle" (the default and most common) and "dill" (which can dump more things, such as functions; more information `here `_).
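+
+A brief, illustrative sketch of these options (the class is hypothetical; it only assumes a subclass of ``Serializable``):
+
+.. code-block:: python
+
+    from qlib.utils.serial import Serializable
+
+    class MyObject(Serializable):
+        pickle_backend = "dill"   # use dill so that more things (e.g. functions) can be dumped
+        default_dump_all = True   # also dump the attributes whose names start with `_`
+
+    obj = MyObject()
+    obj.config(dump_all=False)    # or control the dumping behavior per instance via `config`
+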
Example
==========================
diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst
new file mode 100644
index 0000000000..56a3137f9f
--- /dev/null
+++ b/docs/advanced/task_management.rst
@@ -0,0 +1,89 @@
+.. _task_management:
+
+=================================
+Task Management
+=================================
+.. currentmodule:: qlib
+
+
+Introduction
+=============
+
+The `Workflow <../component/introduction.html>`_ part introduces how to run the research workflow in a loosely-coupled way. However, it can only execute one ``task`` when you use ``qrun``.
+To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
+With this module, users can run their ``task`` automatically at different periods, with different losses, or even with different models.
+
+This whole process can be used in `Online Serving <../component/online.html>`_.
+
+An example of the entire process is shown `here `_.
+
+Task Generating
+===============
+A ``task`` consists of a `Model`, a `Dataset`, a `Record`, and anything else added by users.
+The specific task template can be viewed in
+`Task Section <../component/workflow.html#task-section>`_.
+Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` from the task template.
+
+Here is the base class of ``TaskGen``:
+
+.. autoclass:: qlib.workflow.task.gen.TaskGen
+ :members:
+
+``Qlib`` provides a class `RollingGen `_ to generate a list of ``task`` on the dataset with different date segments.
+This class allows users to verify, in one experiment, the effect of data from different periods on the model. More information is `here <../reference/api.html#TaskGen>`_.
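+
+As a brief, illustrative sketch (the step size and ``task_template`` are placeholders; see the rolling example for a complete script), rolling tasks can be generated like this:
+
+.. code-block:: python
+
+    from qlib.workflow.task.gen import RollingGen, task_generator
+
+    # `task_template` is a task config dict (or a list of them) as described in the Task Section
+    tasks = task_generator(
+        tasks=task_template,
+        generators=RollingGen(step=550, rtype=RollingGen.ROLL_SD),  # generate different date segments
+    )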
+
+Task Storing
+===============
+To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_.
+``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
+Users **MUST** finish the configuration of `MongoDB `_ when using this module.
+
+Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this:
+
+ .. code-block:: python
+
+ from qlib.config import C
+ C["mongo"] = {
+ "task_url" : "mongodb://localhost:27017/", # your MongoDB url
+ "task_db_name" : "rolling_db" # database name
+ }
+
+.. autoclass:: qlib.workflow.task.manage.TaskManager
+ :members:
+
+More information about ``Task Manager`` can be found `here <../reference/api.html#TaskManager>`_.
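+
+A minimal sketch of storing tasks with ``TaskManager`` (the pool name is illustrative, and ``create_task`` is assumed here to be the method for inserting task configs; please check the API reference above):
+
+.. code-block:: python
+
+    from qlib.workflow.task.manage import TaskManager
+
+    tm = TaskManager(task_pool="rolling_task")  # a task pool is a collection in MongoDB
+    tm.create_task(tasks)                       # store the generated task configs into the pool
+    # tm.remove() can clean up the pool when you want to reset everything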
+
+Task Training
+===============
+After generating and storing those ``task``, it's time to run the ``task`` in the *WAITING* status.
+``Qlib`` provides a method called ``run_task`` to run those ``task`` in the task pool; however, users can also customize how tasks are executed.
+An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
+It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
+
+.. autofunction:: qlib.workflow.task.manage.run_task
+
+Meanwhile, ``Qlib`` provides a module called ``Trainer``.
+
+.. autoclass:: qlib.model.trainer.Trainer
+ :members:
+
+``Trainer`` will train a list of tasks and return a list of model recorders.
+``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on ``TaskManager`` to help manage the task lifecycle automatically.
+If you do not want to use ``Task Manager`` to manage tasks, then using ``TrainerR`` to train a list of tasks generated by ``TaskGen`` is enough.
+`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
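+
+For example, a minimal sketch with ``TrainerR`` (the experiment name is illustrative, and the constructor is assumed to take the experiment name first, as ``TrainerRM`` does in the examples):
+
+.. code-block:: python
+
+    from qlib.model.trainer import TrainerR
+
+    trainer = TrainerR(experiment_name="my_exp")
+    recorders = trainer.train(tasks)  # `tasks` is a list of task configs, e.g. generated by a TaskGen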
+
+Task Collecting
+===============
+To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
+
+`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them, for example by merging, grouping, averaging and so on. It works in two steps: ``collect`` (collect anything into a dict) and ``process_collect`` (process the collected dict).
+
+`Group <../reference/api.html#Group>`_ also has two steps: ``group`` (group a set of objects based on `group_func` and change them into a dict) and ``reduce`` (make a dict become an ensemble based on some rule).
+For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
+
+`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
+For example: {C1: object, C2: object} ---``Ensemble``---> object
+
+So the hierarchy is: ``Collector``'s second step corresponds to ``Group``, and ``Group``'s second step corresponds to ``Ensemble``.
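+
+The sketch below, condensed from the rolling example, collects the recorders of one experiment, keys them by model class and rolling segment, and lets ``RollingGroup`` merge the rolling segments (the experiment name is illustrative):
+
+.. code-block:: python
+
+    from qlib.workflow.task.collect import RecorderCollector
+    from qlib.model.ens.group import RollingGroup
+
+    def rec_key(recorder):
+        task_config = recorder.load_object("task")
+        model_key = task_config["model"]["class"]
+        rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
+        return model_key, rolling_key
+
+    collector = RecorderCollector(
+        experiment="rolling_exp",
+        process_list=RollingGroup(),  # group by model and ensemble the rolling segments
+        rec_key_func=rec_key,
+    )
+    print(collector())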
+
+For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_.
\ No newline at end of file
diff --git a/docs/component/online.rst b/docs/component/online.rst
new file mode 100644
index 0000000000..accc936dd4
--- /dev/null
+++ b/docs/component/online.rst
@@ -0,0 +1,46 @@
+.. _online:
+
+=================================
+Online Serving
+=================================
+.. currentmodule:: qlib
+
+
+Introduction
+=============
+
+.. image:: ../_static/img/online_serving.png
+ :align: center
+
+
+In addition to backtesting, one way to test whether a model is effective is to make predictions in real market conditions, or even to do real trading based on those predictions.
+``Online Serving`` is a set of modules for serving online models using the latest data,
+which includes `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_ and `Updater <#Updater>`_.
+
+`Here `_ are several examples for reference, which demonstrate different features of ``Online Serving``.
+If you have many models or ``task`` that need to be managed, please consider `Task Management <../advanced/task_management.html>`_.
+The `examples `_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
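+
+As a small taste, the sketch below is condensed from the ``update_online_pred`` example: it trains a model, marks it as the online model, and later updates its online predictions (``task_config`` is a task config dict with model/dataset/record, and the experiment name is illustrative):
+
+.. code-block:: python
+
+    from qlib.model.trainer import task_train
+    from qlib.workflow.online.utils import OnlineToolR
+
+    online_tool = OnlineToolR("online_srv")
+    rec = task_train(task_config, experiment_name="online_srv")  # train a model and get its recorder
+    online_tool.reset_online_tag(rec)  # set the trained model as the online model
+    # ... later, e.g. once a day after new data arrives:
+    online_tool.update_online_pred()   # update the online predictions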
+
+Online Manager
+================
+
+.. automodule:: qlib.workflow.online.manager
+ :members:
+
+Online Strategy
+================
+
+.. automodule:: qlib.workflow.online.strategy
+ :members:
+
+Online Tool
+=============
+
+.. automodule:: qlib.workflow.online.utils
+ :members:
+
+Updater
+=============
+
+.. automodule:: qlib.workflow.online.update
+ :members:
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 3fa35fc60d..803aa97d2d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing
Qlib Recorder: Experiment Management
Analysis: Evaluation & Results Analysis
+ Online Serving: Online Management & Strategy & Tool
.. toctree::
:maxdepth: 3
@@ -50,6 +51,7 @@ Document Structure
Building Formulaic Alphas
Online & Offline mode
Serialization
+ Task Management
.. toctree::
:maxdepth: 3
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
index 3167d8a622..57f61f18b1 100644
--- a/docs/reference/api.rst
+++ b/docs/reference/api.rst
@@ -154,6 +154,70 @@ Record Template
.. automodule:: qlib.workflow.record_temp
:members:
+Task Management
+====================
+
+
+TaskGen
+--------------------
+.. automodule:: qlib.workflow.task.gen
+ :members:
+
+TaskManager
+--------------------
+.. automodule:: qlib.workflow.task.manage
+ :members:
+
+Trainer
+--------------------
+.. automodule:: qlib.model.trainer
+ :members:
+
+Collector
+--------------------
+.. automodule:: qlib.workflow.task.collect
+ :members:
+
+Group
+--------------------
+.. automodule:: qlib.model.ens.group
+ :members:
+
+Ensemble
+--------------------
+.. automodule:: qlib.model.ens.ensemble
+ :members:
+
+Utils
+--------------------
+.. automodule:: qlib.workflow.task.utils
+ :members:
+
+
+Online Serving
+====================
+
+
+Online Manager
+--------------------
+.. automodule:: qlib.workflow.online.manager
+ :members:
+
+Online Strategy
+--------------------
+.. automodule:: qlib.workflow.online.strategy
+ :members:
+
+Online Tool
+--------------------
+.. automodule:: qlib.workflow.online.utils
+ :members:
+
+RecordUpdater
+--------------------
+.. automodule:: qlib.workflow.online.update
+ :members:
+
Utils
====================
@@ -162,4 +226,7 @@ Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
- :members:
\ No newline at end of file
+ :members:
+
+
+
\ No newline at end of file
diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst
index 15aa957d1a..32c17ff837 100644
--- a/docs/start/initialization.rst
+++ b/docs/start/initialization.rst
@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
+- `mongo`
+    Type: dict, optional parameter, the setting of `MongoDB `_, which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
+    Users need to finish the `installation `_ first, and run MongoDB at a fixed URL.
+
+ .. code-block:: Python
+
+ # For example, you can initialize qlib below
+ qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
+ "task_url": "mongodb://localhost:27017/", # your mongo url
+ "task_db_name": "rolling_db", # the database name of Task Management
+ })
diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py
new file mode 100644
index 0000000000..4f3ac04b15
--- /dev/null
+++ b/examples/model_rolling/task_manager_rolling.py
@@ -0,0 +1,159 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+This example shows how a TrainerRM works based on TaskManager with rolling tasks.
+After training, how to collect the rolling results is shown in `task_collecting`.
+"""
+
+from pprint import pprint
+
+import fire
+import qlib
+from qlib.config import REG_CN
+from qlib.workflow import R
+from qlib.workflow.task.gen import RollingGen, task_generator
+from qlib.workflow.task.manage import TaskManager
+from qlib.workflow.task.collect import RecorderCollector
+from qlib.model.ens.group import RollingGroup
+from qlib.model.trainer import TrainerRM
+
+
+data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi100",
+}
+
+dataset_config = {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+}
+
+record_config = [
+ {
+ "class": "SignalRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+ {
+ "class": "SigAnaRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+]
+
+# use lgb
+task_lgb_config = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+# use xgboost
+task_xgboost_config = {
+ "model": {
+ "class": "XGBModel",
+ "module_path": "qlib.contrib.model.xgboost",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+
+class RollingTaskExample:
+ def __init__(
+ self,
+ provider_uri="~/.qlib/qlib_data/cn_data",
+ region=REG_CN,
+ task_url="mongodb://10.0.0.4:27017/",
+ task_db_name="rolling_db",
+ experiment_name="rolling_exp",
+ task_pool="rolling_task",
+ task_config=[task_xgboost_config, task_lgb_config],
+ rolling_step=550,
+ rolling_type=RollingGen.ROLL_SD,
+ ):
+ # TaskManager config
+ mongo_conf = {
+ "task_url": task_url,
+ "task_db_name": task_db_name,
+ }
+ qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
+ self.experiment_name = experiment_name
+ self.task_pool = task_pool
+ self.task_config = task_config
+ self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
+
+ # Reset all things to the first status, be careful to save important data
+ def reset(self):
+ print("========== reset ==========")
+ TaskManager(task_pool=self.task_pool).remove()
+ exp = R.get_exp(experiment_name=self.experiment_name)
+ for rid in exp.list_recorders():
+ exp.delete_recorder(rid)
+
+ def task_generating(self):
+ print("========== task_generating ==========")
+ tasks = task_generator(
+ tasks=self.task_config,
+ generators=self.rolling_gen, # generate different date segments
+ )
+ pprint(tasks)
+ return tasks
+
+ def task_training(self, tasks):
+ print("========== task_training ==========")
+ trainer = TrainerRM(self.experiment_name, self.task_pool)
+ trainer.train(tasks)
+
+ def task_collecting(self):
+ print("========== task_collecting ==========")
+
+ def rec_key(recorder):
+ task_config = recorder.load_object("task")
+ model_key = task_config["model"]["class"]
+ rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
+ return model_key, rolling_key
+
+ def my_filter(recorder):
+ # only choose the results of "LGBModel"
+ model_key, rolling_key = rec_key(recorder)
+ if model_key == "LGBModel":
+ return True
+ return False
+
+ collector = RecorderCollector(
+ experiment=self.experiment_name,
+ process_list=RollingGroup(),
+ rec_key_func=rec_key,
+ rec_filter_func=my_filter,
+ )
+ print(collector())
+
+ def main(self):
+ self.reset()
+ tasks = self.task_generating()
+ self.task_training(tasks)
+ self.task_collecting()
+
+
+if __name__ == "__main__":
+ ## to see the whole process with your own parameters, use the command below
+ # python task_manager_rolling.py main --experiment_name="your_exp_name"
+ fire.Fire(RollingTaskExample)
diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py
new file mode 100644
index 0000000000..4bb5022ee0
--- /dev/null
+++ b/examples/online_srv/online_management_simulate.py
@@ -0,0 +1,146 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+This example shows how to simulate the OnlineManager based on rolling tasks.
+"""
+
+import fire
+import qlib
+from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
+from qlib.workflow import R
+from qlib.workflow.online.manager import OnlineManager
+from qlib.workflow.online.strategy import RollingStrategy
+from qlib.workflow.task.gen import RollingGen
+from qlib.workflow.task.manage import TaskManager
+
+
+data_handler_config = {
+ "start_time": "2018-01-01",
+ "end_time": "2018-10-31",
+ "fit_start_time": "2018-01-01",
+ "fit_end_time": "2018-03-31",
+ "instruments": "csi100",
+}
+
+dataset_config = {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2018-01-01", "2018-03-31"),
+ "valid": ("2018-04-01", "2018-05-31"),
+ "test": ("2018-06-01", "2018-09-10"),
+ },
+ },
+}
+
+record_config = [
+ {
+ "class": "SignalRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+ {
+ "class": "SigAnaRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+]
+
+# use lgb model
+task_lgb_config = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+# use xgboost model
+task_xgboost_config = {
+ "model": {
+ "class": "XGBModel",
+ "module_path": "qlib.contrib.model.xgboost",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+
+class OnlineSimulationExample:
+ def __init__(
+ self,
+ provider_uri="~/.qlib/qlib_data/cn_data",
+ region="cn",
+ exp_name="rolling_exp",
+ task_url="mongodb://10.0.0.4:27017/",
+ task_db_name="rolling_db",
+ task_pool="rolling_task",
+ rolling_step=80,
+ start_time="2018-09-10",
+ end_time="2018-10-31",
+ tasks=[task_xgboost_config, task_lgb_config],
+ ):
+ """
+ Init OnlineManagerExample.
+
+ Args:
+ provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
+ region (str, optional): the stock region. Defaults to "cn".
+ exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
+ task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
+ task_db_name (str, optional): database name. Defaults to "rolling_db".
+ task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
+ rolling_step (int, optional): the step for rolling. Defaults to 80.
+ start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
+ end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
+ tasks (dict or list[dict]): a set of the task config waiting for rolling and training
+ """
+ self.exp_name = exp_name
+ self.task_pool = task_pool
+ self.start_time = start_time
+ self.end_time = end_time
+ mongo_conf = {
+ "task_url": task_url,
+ "task_db_name": task_db_name,
+ }
+ qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
+ self.rolling_gen = RollingGen(
+ step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
+ ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
+ self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
+ self.rolling_online_manager = OnlineManager(
+ RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
+ trainer=self.trainer,
+ begin_time=self.start_time,
+ )
+ self.tasks = tasks
+
+ # Reset all things to the first status, be careful to save important data
+ def reset(self):
+ TaskManager(self.task_pool).remove()
+ exp = R.get_exp(experiment_name=self.exp_name)
+ for rid in exp.list_recorders():
+ exp.delete_recorder(rid)
+
+    # Run this to run the whole workflow automatically
+ def main(self):
+ print("========== reset ==========")
+ self.reset()
+ print("========== simulate ==========")
+ self.rolling_online_manager.simulate(end_time=self.end_time)
+ print("========== collect results ==========")
+ print(self.rolling_online_manager.get_collector()())
+ print("========== signals ==========")
+ print(self.rolling_online_manager.get_signals())
+
+
+if __name__ == "__main__":
+    ## to run the whole workflow automatically with your own parameters, use the command below
+    # python online_management_simulate.py main --exp_name="your_exp_name" --rolling_step=60
+ fire.Fire(OnlineSimulationExample)
diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py
new file mode 100644
index 0000000000..25b8b2a0c0
--- /dev/null
+++ b/examples/online_srv/rolling_online_management.py
@@ -0,0 +1,181 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+This example shows how OnlineManager works with rolling tasks.
+There are four parts: first train, routine 1, add strategy, and routine 2.
+Firstly, the OnlineManager will finish the first training and set the trained models as `online` models.
+Next, the OnlineManager will finish a routine process, including updating online predictions -> preparing tasks -> preparing new models -> preparing signals.
+Then, we will add some new strategies to the OnlineManager. This will finish the first training of the new strategies.
+Finally, the OnlineManager will finish the second routine and update all strategies.
+"""
+
+import os
+import fire
+import qlib
+from qlib.workflow import R
+from qlib.workflow.online.strategy import RollingStrategy
+from qlib.workflow.task.gen import RollingGen
+from qlib.workflow.online.manager import OnlineManager
+
+data_handler_config = {
+ "start_time": "2013-01-01",
+ "end_time": "2020-09-25",
+ "fit_start_time": "2013-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi100",
+}
+
+dataset_config = {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2013-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2015-12-31"),
+ "test": ("2016-01-01", "2020-07-10"),
+ },
+ },
+}
+
+record_config = [
+ {
+ "class": "SignalRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+ {
+ "class": "SigAnaRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+]
+
+# use lgb model
+task_lgb_config = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+# use xgboost model
+task_xgboost_config = {
+ "model": {
+ "class": "XGBModel",
+ "module_path": "qlib.contrib.model.xgboost",
+ },
+ "dataset": dataset_config,
+ "record": record_config,
+}
+
+
+class RollingOnlineExample:
+ def __init__(
+ self,
+ provider_uri="~/.qlib/qlib_data/cn_data",
+ region="cn",
+ task_url="mongodb://10.0.0.4:27017/",
+ task_db_name="rolling_db",
+ rolling_step=550,
+ tasks=[task_xgboost_config],
+ add_tasks=[task_lgb_config],
+ ):
+ mongo_conf = {
+ "task_url": task_url, # your MongoDB url
+ "task_db_name": task_db_name, # database name
+ }
+ qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
+ self.tasks = tasks
+ self.add_tasks = add_tasks
+ self.rolling_step = rolling_step
+ strategies = []
+ for task in tasks:
+ name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
+ strategies.append(
+ RollingStrategy(
+ name_id,
+ task,
+ RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
+ )
+ )
+
+ self.rolling_online_manager = OnlineManager(strategies)
+
+ _ROLLING_MANAGER_PATH = (
+        ".RollingOnlineExample"  # the OnlineManager will dump to this file, so that it can be loaded when calling routine.
+ )
+
+ # Reset all things to the first status, be careful to save important data
+ def reset(self):
+ for task in self.tasks + self.add_tasks:
+ name_id = task["model"]["class"]
+ exp = R.get_exp(experiment_name=name_id)
+ for rid in exp.list_recorders():
+ exp.delete_recorder(rid)
+
+ if os.path.exists(self._ROLLING_MANAGER_PATH):
+ os.remove(self._ROLLING_MANAGER_PATH)
+
+ def first_run(self):
+ print("========== reset ==========")
+ self.reset()
+ print("========== first_run ==========")
+ self.rolling_online_manager.first_train()
+ print("========== collect results ==========")
+ print(self.rolling_online_manager.get_collector()())
+ print("========== dump ==========")
+ self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
+
+ def routine(self):
+ print("========== load ==========")
+ self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
+ print("========== routine ==========")
+ self.rolling_online_manager.routine()
+ print("========== collect results ==========")
+ print(self.rolling_online_manager.get_collector()())
+ print("========== signals ==========")
+ print(self.rolling_online_manager.get_signals())
+ print("========== dump ==========")
+ self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
+
+ def add_strategy(self):
+ print("========== load ==========")
+ self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
+ print("========== add strategy ==========")
+ strategies = []
+ for task in self.add_tasks:
+ name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
+ strategies.append(
+ RollingStrategy(
+ name_id,
+ task,
+ RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
+ )
+ )
+ self.rolling_online_manager.add_strategy(strategies=strategies)
+ print("========== dump ==========")
+ self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
+
+ def main(self):
+ self.first_run()
+ self.routine()
+ self.add_strategy()
+ self.routine()
+
+
+if __name__ == "__main__":
+ ####### to train the first version's models, use the command below
+ # python rolling_online_management.py first_run
+
+ ####### to update the models and predictions after the trading time, use the command below
+ # python rolling_online_management.py routine
+
+ ####### to define your own parameters, use `--`
+    # python rolling_online_management.py first_run --rolling_step=40
+ fire.Fire(RollingOnlineExample)
diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py
new file mode 100644
index 0000000000..228bc0dacb
--- /dev/null
+++ b/examples/online_srv/update_online_pred.py
@@ -0,0 +1,91 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+This example shows how OnlineTool works when we need to update predictions.
+There are two parts: first_train and update_online_pred.
+Firstly, we will finish the training and set the trained model as the `online` model.
+Next, we will finish updating the online predictions.
+"""
+import fire
+import qlib
+from qlib.config import REG_CN
+from qlib.model.trainer import task_train
+from qlib.workflow.online.utils import OnlineToolR
+
+data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi100",
+}
+
+task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
+ "record": {
+ "class": "SignalRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+}
+
+
+class UpdatePredExample:
+ def __init__(
+ self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
+ ):
+ qlib.init(provider_uri=provider_uri, region=region)
+ self.experiment_name = experiment_name
+ self.online_tool = OnlineToolR(self.experiment_name)
+ self.task_config = task_config
+
+ def first_train(self):
+ rec = task_train(self.task_config, experiment_name=self.experiment_name)
+ self.online_tool.reset_online_tag(rec) # set to online model
+
+ def update_online_pred(self):
+ self.online_tool.update_online_pred()
+
+ def main(self):
+ self.first_train()
+ self.update_online_pred()
+
+
+if __name__ == "__main__":
+ ## to train a model and set it to online model, use the command below
+ # python update_online_pred.py first_train
+ ## to update online predictions once a day, use the command below
+ # python update_online_pred.py update_online_pred
+ ## to see the whole process with your own parameters, use the command below
+ # python update_online_pred.py main --experiment_name="your_exp_name"
+ fire.Fire(UpdatePredExample)
diff --git a/qlib/__init__.py b/qlib/__init__.py
index 99035e5014..4fd48f8c19 100644
--- a/qlib/__init__.py
+++ b/qlib/__init__.py
@@ -3,6 +3,7 @@
__version__ = "0.6.3.99"
+__version__bak = __version__  # This version is a backup for QlibConfig.reset_qlib_version
import os
@@ -10,12 +11,13 @@
import logging
import platform
import subprocess
+from pathlib import Path
+from .log import get_module_logger
# init qlib
def init(default_conf="client", **kwargs):
from .config import C
- from .log import get_module_logger
from .data.cache import H
H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
def _mount_nfs_uri(C):
- from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs):
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
+
+
+def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
+ """
+    If users are building a project that follows the pattern below:
+    - Qlib is a sub-folder in the project path
+    - There is a file named `config.yaml` at the project root (a parent folder of qlib).
+
+    For example:
+        If your project file system structure follows such a pattern
+
+            <project_path>/
+                - config.yaml
+                - ...some folders...
+                - qlib/
+
+        this function will return <project_path>
+
+        NOTE: links are not supported here.
+
+
+ This method is often used when
+    - users want to use a relative config path instead of hard-coding the qlib config path in code
+
+ Raises
+ ------
+ FileNotFoundError:
+ If project path is not found
+ """
+ if cur_path is None:
+ cur_path = Path(__file__).absolute().resolve()
+ while True:
+ if (cur_path / config_name).exists():
+ return cur_path
+ if cur_path == cur_path.parent:
+ raise FileNotFoundError("We can't find the project path")
+ cur_path = cur_path.parent
+
+
+def auto_init(**kwargs):
+ """
+    This function will init qlib automatically with the following priority:
+ - Find the project configuration and init qlib
+ - The parsing process will be affected by the `conf_type` of the configuration file
+ - Init qlib with default config
+ """
+
+ try:
+ pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
+ except FileNotFoundError:
+ init(**kwargs)
+ else:
+
+ conf_pp = pp / "config.yaml"
+ with conf_pp.open() as f:
+ conf = yaml.safe_load(f)
+
+ conf_type = conf.get("conf_type", "origin")
+ if conf_type == "origin":
+ # The type of config is just like original qlib config
+ init_from_yaml_conf(conf_pp, **kwargs)
+ elif conf_type == "ref":
+            # This config type will be more convenient in the following scenarios:
+            # - There is a shared config file and you don't want to edit it in place.
+            # - The shared config may be updated later and you don't want to copy it.
+            # - You have some customized config.
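+            # An illustrative `config.yaml` of this type (the keys come from the code below; the values are assumptions):
+            #     conf_type: ref
+            #     qlib_cfg: <path to the shared qlib config yaml>
+            #     qlib_cfg_update:
+            #         region: cn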
+ qlib_conf_path = conf["qlib_cfg"]
+ qlib_conf_update = conf.get("qlib_cfg_update")
+ init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
+ logger = get_module_logger("Initialization")
+ logger.info(f"Auto load project config: {conf_pp}")
diff --git a/qlib/config.py b/qlib/config.py
index 75ab0fa3e8..4dedf75d06 100644
--- a/qlib/config.py
+++ b/qlib/config.py
@@ -33,6 +33,9 @@ def __getattr__(self, attr):
raise AttributeError(f"No such {attr} in self._config")
+ def get(self, key, default=None):
+ return self.__dict__["_config"].get(key, default)
+
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -131,7 +134,7 @@ def set_conf_from_C(self, config_c):
},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
},
- # Defatult config for experiment manager
+ # Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -140,6 +143,11 @@ def set_conf_from_C(self, config_c):
"default_exp_name": "Experiment",
},
},
+ # Default config for MongoDB
+ "mongo": {
+ "task_url": "mongodb://localhost:27017/",
+ "task_db_name": "default_task_db",
+ },
}
MODE_CONF = {
@@ -310,8 +318,22 @@ def register(self):
# clean up experiment when python program ends
experiment_exit_handler()
+ # Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
+ self.reset_qlib_version()
+
self._registered = True
+ def reset_qlib_version(self):
+ import qlib
+
+ reset_version = self.get("qlib_reset_version", None)
+ if reset_version is not None:
+ qlib.__version__ = reset_version
+ else:
+ qlib.__version__ = getattr(qlib, "__version__bak")
+            # NOTE: Accessing `qlib.__version__bak` directly inside this class would be rewritten
+            # by Python's name mangling into `_QlibConfig__version__bak`,
+            # so `getattr` with the string "__version__bak" is used instead of `__version__bak`.
+
@property
def registered(self):
return self._registered
diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py
index 970b032d6b..be2016ea32 100644
--- a/qlib/contrib/data/handler.py
+++ b/qlib/contrib/data/handler.py
@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
+            # FIXME: the `module_path` parameter is missing.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)
diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
index b3eaac7a33..206561aed9 100644
--- a/qlib/data/dataset/__init__.py
+++ b/qlib/data/dataset/__init__.py
@@ -27,7 +27,7 @@ def __init__(self, **kwargs):
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
- The data could specify the info to caculate the essential data for preparation
+ The data could specify the info to calculate the essential data for preparation
"""
self.setup_data(**kwargs)
super().__init__()
@@ -92,7 +92,7 @@ def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple
handler : Union[dict, DataHandler]
handler could be:
- - insntance of `DataHandler`
+ - instance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
@@ -112,8 +112,9 @@ def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
- self.handler = init_instance_by_config(handler, accept_types=DataHandler)
+ self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
+ self.fetch_kwargs = {}
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
@@ -123,7 +124,7 @@ def config(self, handler_kwargs: dict = None, **kwargs):
Parameters
----------
handler_kwargs : dict
- Config of DataHanlder, which could include the following arguments:
+ Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
@@ -147,11 +148,11 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs):
Parameters
----------
handler_kwargs : dict
- init arguments of DataHanlder, which could include the following arguments:
+ init arguments of DataHandler, which could include the following arguments:
- init_type : Init Type of Handler
- - enable_cache : wheter to enable cache
+ - enable_cache : whether to enable cache
"""
super().setup_data(**kwargs)
@@ -171,7 +172,10 @@ def _prepare_seg(self, slc: slice, **kwargs):
----------
slc : slice
"""
- return self.handler.fetch(slc, **kwargs)
+ if hasattr(self, "fetch_kwargs"):
+ return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
+ else:
+ return self.handler.fetch(slc, **kwargs)
def prepare(
self,
@@ -199,6 +203,12 @@ def prepare(
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
+ kwargs :
+ The parameters that kwargs may contain:
+ flt_col : str
+                    It only exists in TSDatasetH and can be used to specify a column of data (True or False) to filter the data.
+                    This parameter is only supported when the dataset is an instance of TSDatasetH.
+
Returns
-------
Union[List[pd.DataFrame], pd.DataFrame]:
@@ -231,7 +241,7 @@ class TSDataSampler:
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
- It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
+ It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
@@ -243,7 +253,9 @@ class TSDataSampler:
"""
- def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
+ def __init__(
+ self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
+ ):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -265,6 +277,11 @@ def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: s
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
+ flt_data : pd.Series
+            a column of data (True or False) to filter data.
+            None:
+                keep all data
+
"""
self.start = start
self.end = end
@@ -272,23 +289,51 @@ def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: s
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
- self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
- # NOTE: append last line with full NaN for better performance in `__getitem__`
- self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
+
+ kwargs = {"object": self.data}
+ if dtype is not None:
+ kwargs["dtype"] = dtype
+
+ self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
+ # NOTE:
+ # - append last line with full NaN for better performance in `__getitem__`
+ # - Keep the same dtype will result in a better performance
+ self.data_arr = np.append(
+ self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
+ )
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx
- self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_df, self.idx_map = self.build_index(self.data)
+ self.data_index = deepcopy(self.data.index)
+
+ if flt_data is not None:
+ self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
+ self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
+ self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
+
+ self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
+ del self.data # save memory
+
+ @staticmethod
+ def flt_idx_map(flt_data, idx_map):
+ idx = 0
+ new_idx_map = {}
+ for i, exist in enumerate(flt_data):
+ if exist:
+ new_idx_map[idx] = idx_map[i]
+ idx += 1
+ return new_idx_map
+
def get_index(self):
"""
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
- return self.data.index[self.start_idx : self.end_idx]
+ return self.data_index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -432,7 +477,7 @@ class TSDatasetH(DatasetH):
(T)ime-(S)eries Dataset (H)andler
- Covnert the tabular data to Time-Series data
+ Convert the tabular data to Time-Series data
Requirements analysis
@@ -461,7 +506,7 @@ def setup_data(self, **kwargs):
cal = sorted(cal)
self.cal = cal
- def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
+ def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
@@ -470,6 +515,25 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
+ return data
+
+ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
+ """
+        Splitting out `_prepare_raw_seg` leaves a hook for data preprocessing before creating the processed data.
+ """
+ dtype = kwargs.pop("dtype", None)
+ start, end = slc.start, slc.stop
+ flt_col = kwargs.pop("flt_col", None)
+ # TSDatasetH will retrieve more data for complete
+ data = self._prepare_raw_seg(slc, **kwargs)
+
+ flt_kwargs = deepcopy(kwargs)
+ if flt_col is not None:
+ flt_kwargs["col_set"] = flt_col
+ flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
+ assert len(flt_data.columns) == 1
+ else:
+ flt_data = None
- tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
+ tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py
index a6ca658d16..c6338832a5 100644
--- a/qlib/data/dataset/handler.py
+++ b/qlib/data/dataset/handler.py
@@ -7,7 +7,7 @@
import logging
import warnings
from inspect import getfullargspec
-from typing import Union, Tuple, List, Iterator, Optional
+from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -36,7 +36,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
- Any order of the index level can be suported (The order will be implied in the data).
+ Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -51,6 +51,9 @@ class DataHandler(Serializable):
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
+
+ Tips for improving the performance of datahandler
+    - Fetching data with `col_set=CS_RAW` will return the raw data and may prevent pandas from copying the data when calling `loc`
"""
def __init__(
@@ -74,7 +77,7 @@ def __init__(
data_loader : Union[dict, str, DataLoader]
data loader to load the data.
init_data :
- intialize the original data in the constructor.
+ initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
"""
@@ -125,7 +128,7 @@ def config(self, **kwargs):
def setup_data(self, enable_cache: bool = False):
"""
- Set Up the data in case of running intialization for multiple time
+ Set Up the data in case of running initialization for multiple time
It is responsible for maintaining following variable
1) self._data
@@ -163,6 +166,7 @@ def fetch(
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
+ proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -185,6 +189,14 @@ def fetch(
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
+ proc_func: Callable
+ - Give a hook for processing data before fetching
+ - An example to explain the necessity of the hook:
+ - A Dataset learned some processors to process data which is related to data segmentation
+ - It will apply them every time when preparing data.
+                - The learned processors require the dataframe to remain in the same format when fitting and applying
+                - However, the data format will change according to the parameters.
+                - So the processors should be applied to the underlying data.
squeeze : bool
whether squeeze columns and index
@@ -193,8 +205,15 @@ def fetch(
-------
pd.DataFrame.
"""
+ if proc_func is None:
+ df = self._data
+ else:
+ # FIXME: fetching by time first will be more friendly to `proc_func`
+ # Copy in case of `proc_func` changing the data inplace....
+ df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
+
# Fetch column first will be more friendly to SepDataFrame
- df = self._fetch_df_by_col(self._data, col_set)
+ df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -261,6 +280,10 @@ def get_range_iterator(
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
+
+    Tips for improving the performance of the data handler
+    - To reduce the memory cost
+        - `drop_raw=True`: this will modify the raw data in place;
"""
# data key
@@ -430,7 +453,7 @@ def config(self, processor_kwargs: dict = None, **kwargs):
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
"""
- Set up the data in case of running intialization for multiple time
+ Set up the data in case of running initialization for multiple time
Parameters
----------
@@ -474,6 +497,7 @@ def fetch(
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
+ proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -488,12 +512,18 @@ def fetch(
select a set of meaningful columns.(e.g. features, columns).
data_key : str
the data to fetch: DK_*.
+ proc_func: Callable
+ please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
+ if proc_func is not None:
+ # FIXME: fetch by time first will be more friendly to proc_func
+            # Copy in case of `proc_func` changing the data inplace....
+ df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py
index 58aca1d4f7..2ad110b89d 100644
--- a/qlib/data/dataset/loader.py
+++ b/qlib/data/dataset/loader.py
@@ -13,6 +13,7 @@
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
+from qlib.log import get_module_logger
class DataLoader(abc.ABC):
@@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader):
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
+
+    TODO: What makes this module not that easy to use:
+    - For the online scenario
+        - The underlying data handler should be configured, but the data loader doesn't provide such an interface & hook.
"""
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
@@ -265,7 +270,7 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
- LOG.warning(f"instruments[{instruments}] is ignored")
+ get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
if self.is_group:
df = pd.concat(
diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
index 7635a4127b..fce22ddfcf 100644
--- a/qlib/data/dataset/processor.py
+++ b/qlib/data/dataset/processor.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import abc
+from typing import Union, Text
import numpy as np
import pandas as pd
import copy
@@ -14,7 +15,7 @@
EPS = 1e-12
-def get_group_columns(df: pd.DataFrame, group: str):
+def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
"""
get a group of columns from multi-index columns DataFrame
diff --git a/qlib/model/ens/__init__.py b/qlib/model/ens/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py
new file mode 100644
index 0000000000..4fa6a5ec63
--- /dev/null
+++ b/qlib/model/ens/ensemble.py
@@ -0,0 +1,115 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+The Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels' predictions, we may need to merge them into an ensemble prediction.
+"""
+
+from typing import Union
+import pandas as pd
+from qlib.utils import FLATTEN_TUPLE, flatten_dict
+
+
+class Ensemble:
+ """Merge the ensemble_dict into an ensemble object.
+
+ For example: {Rollinga_b: object, Rollingb_c: object} -> object
+
+ When calling this class:
+
+ Args:
+ ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
+
+ Returns:
+ object: the ensemble object
+ """
+
+ def __call__(self, ensemble_dict: dict, *args, **kwargs):
+ raise NotImplementedError(f"Please implement the `__call__` method.")
+
+
+class SingleKeyEnsemble(Ensemble):
+
+ """
+ Extract the object if there is only one key and value in the dict. Make the result more readable.
+ {Only key: Only value} -> Only value
+
+    If there is more than one key or less than one key, then do nothing.
+    You can even run this recursively to make the dict more readable.
+
+ NOTE: Default runs recursively.
+
+ When calling this class:
+
+ Args:
+ ensemble_dict (dict): the dict. The key of the dict will be ignored.
+
+ Returns:
+ dict: the readable dict.
+ """
+
+ def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
+ if not isinstance(ensemble_dict, dict):
+ return ensemble_dict
+ if recursion:
+ tmp_dict = {}
+ for k, v in ensemble_dict.items():
+ tmp_dict[k] = self(v, recursion)
+ ensemble_dict = tmp_dict
+ keys = list(ensemble_dict.keys())
+ if len(keys) == 1:
+ ensemble_dict = ensemble_dict[keys[0]]
+ return ensemble_dict
+
+
+class RollingEnsemble(Ensemble):
+
+ """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
+
+ NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
+
+ When calling this class:
+
+ Args:
+ ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
+ The key of the dict will be ignored.
+
+ Returns:
+ pd.DataFrame: the complete result of rolling.
+ """
+
+ def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
+ artifact_list = list(ensemble_dict.values())
+ artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
+ artifact = pd.concat(artifact_list)
+        # If there are duplicated predictions, use the latest prediction
+ artifact = artifact[~artifact.index.duplicated(keep="last")]
+ artifact = artifact.sort_index()
+ return artifact
+
+
+class AverageEnsemble(Ensemble):
+ """
+ Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
+
+    NOTE: The values of the dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flatten it.
+
+ When calling this class:
+
+ Args:
+ ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
+ The key of the dict will be ignored.
+
+ Returns:
+ pd.DataFrame: the complete result of averaging and standardizing.
+ """
+
+ def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
+ # need to flatten the nested dict
+ ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
+ values = list(ensemble_dict.values())
+ results = pd.concat(values, axis=1)
+ results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
+ results = results.mean(axis=1)
+ results = results.sort_index()
+ return results
diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py
new file mode 100644
index 0000000000..7f45b06a5c
--- /dev/null
+++ b/qlib/model/ens/group.py
@@ -0,0 +1,113 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+Group can group a set of objects based on `group_func` and change them to a dict.
+After group, we provide a method to reduce them.
+
+For example:
+
+group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
+reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
+
+"""
+
+from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
+from typing import Callable, Union
+from joblib import Parallel, delayed
+
+
+class Group:
+ """Group the objects based on dict"""
+
+ def __init__(self, group_func=None, ens: Ensemble = None):
+ """
+ Init Group.
+
+ Args:
+ group_func (Callable, optional): Given a dict and return the group key and one of the group elements.
+
+ For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
+
+ Defaults to None.
+
+ ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
+ """
+ self._group_func = group_func
+ self._ens_func = ens
+
+ def group(self, *args, **kwargs) -> dict:
+ """
+ Group a set of objects and change them to a dict.
+
+ For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
+
+ Returns:
+ dict: grouped dict
+ """
+ if isinstance(getattr(self, "_group_func", None), Callable):
+ return self._group_func(*args, **kwargs)
+ else:
+ raise NotImplementedError(f"Please specify valid `group_func`.")
+
+ def reduce(self, *args, **kwargs) -> dict:
+ """
+ Reduce grouped dict.
+
+ For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
+
+ Returns:
+ dict: reduced dict
+ """
+ if isinstance(getattr(self, "_ens_func", None), Callable):
+ return self._ens_func(*args, **kwargs)
+ else:
+ raise NotImplementedError(f"Please specify valid `_ens_func`.")
+
+ def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
+ """
+ Group the ungrouped_dict into different groups.
+
+        Args:
+            ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
+            n_jobs (int): how many processes you need.
+            verbose (int): the print mode for Parallel.
+
+        Returns:
+            dict: grouped_dict like {G1: object, G2: object}
+ """
+
+ # NOTE: The multiprocessing will raise error if you use `Serializable`
+ # Because the `Serializable` will affect the behaviors of pickle
+ grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
+
+ key_l = []
+ job_l = []
+ for key, value in grouped_dict.items():
+ key_l.append(key)
+ job_l.append(delayed(Group.reduce)(self, value))
+ return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
+
+
+class RollingGroup(Group):
+ """Group the rolling dict"""
+
+ def group(self, rolling_dict: dict) -> dict:
+        """Given a rolling dict like {(A,B,R): things}, return the grouped dict like {(A,B): {R:things}}
+
+        NOTE: There is an assumption that the rolling key is at the end of the key tuple, because the rolling results always need to be ensembled first.
+
+        Args:
+            rolling_dict (dict): a rolling dict. If the key is not a tuple, then do nothing.
+
+ Returns:
+ dict: grouped dict
+ """
+ grouped_dict = {}
+ for key, values in rolling_dict.items():
+ if isinstance(key, tuple):
+ grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
+ return grouped_dict
+
+ def __init__(self):
+ super().__init__(ens=RollingEnsemble())
diff --git a/qlib/model/task.py b/qlib/model/task.py
deleted file mode 100644
index f29f513a4e..0000000000
--- a/qlib/model/task.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import abc
-import typing
-
-
-class TaskGen(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def __call__(self, *args, **kwargs) -> typing.List[dict]:
- """
- generate
-
- Parameters
- ----------
- args, kwargs:
- The info for generating tasks
- Example 1):
- input: a specific task template
- output: rolling version of the tasks
- Example 2):
- input: a specific task template
- output: a set of tasks with different losses
-
- Returns
- -------
- typing.List[dict]:
- A list of tasks
- """
- pass
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
index f0bc0b780a..fd76e67284 100644
--- a/qlib/model/trainer.py
+++ b/qlib/model/trainer.py
@@ -1,42 +1,446 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-from qlib.utils import init_instance_by_config, flatten_dict
+"""
+The Trainer will train a list of tasks and return a list of model recorders.
+There are two steps in each Trainer: ``train`` (make the model recorders) and ``end_train`` (modify the model recorders).
+
+There is a concept called ``DelayTrainer``, which can be used in online simulation for parallel training.
+In ``DelayTrainer``, the first step only saves some necessary info to the model recorders, and the second step, which will be finished at the end, can do some concurrent and time-consuming operations such as model fitting.
+
+``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
+"""
+
+import socket
+from typing import Callable, List
+
+from qlib.data.dataset import Dataset
+from qlib.model.base import Model
+from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
+from qlib.workflow.recorder import Recorder
+from qlib.workflow.task.manage import TaskManager, run_task
-def task_train(task_config: dict, experiment_name):
+def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
- task based training
+ Begin task training to start a recorder and save the task config.
- Parameters
- ----------
- task_config : dict
- A dict describes a task setting.
+ Args:
+ task_config (dict): the config of a task
+ experiment_name (str): the name of experiment
+ recorder_name (str): the given name will be the recorder name. None for using rid.
+
+ Returns:
+ Recorder: the model recorder
"""
+ with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
+ R.log_params(**flatten_dict(task_config))
+ R.save_objects(**{"task": task_config}) # keep the original format and datatype
+ R.set_tags(**{"hostname": socket.gethostname()})
+ recorder: Recorder = R.get_recorder()
+ return recorder
- # model initiaiton
- model = init_instance_by_config(task_config["model"])
- dataset = init_instance_by_config(task_config["dataset"])
- # start exp
- with R.start(experiment_name=experiment_name):
- # train model
- R.log_params(**flatten_dict(task_config))
+def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
+ """
+ Finish task training with real model fitting and saving.
+
+ Args:
+ rec (Recorder): the recorder will be resumed
+ experiment_name (str): the name of experiment
+
+ Returns:
+ Recorder: the model recorder
+ """
+ with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
+ task_config = R.load_object("task")
+        # model & dataset initialization
+ model: Model = init_instance_by_config(task_config["model"])
+ dataset: Dataset = init_instance_by_config(task_config["dataset"])
+ # model training
model.fit(dataset)
- recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
-
+ # this dataset is saved for online inference. So the concrete data should not be dumped
+ dataset.config(dump_all=False, recursive=True)
+ R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
- for record in task_config["record"]:
- if record["class"] == SignalRecord.__name__:
- srconf = {"model": model, "dataset": dataset, "recorder": recorder}
- record["kwargs"].update(srconf)
- sr = init_instance_by_config(record)
- sr.generate()
+ records = task_config.get("record", [])
+        if isinstance(records, dict):  # handle the case where only one record is given as a dict
+ records = [records]
+ for record in records:
+ cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
+ if cls is SignalRecord:
+ rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:
- rconf = {"recorder": recorder}
- record["kwargs"].update(rconf)
- ar = init_instance_by_config(record)
- ar.generate()
+ rconf = {"recorder": rec}
+ r = cls(**kwargs, **rconf)
+ r.generate()
+
+ return rec
+
+
+def task_train(task_config: dict, experiment_name: str) -> Recorder:
+ """
+    Task-based training; it will be divided into two steps.
+
+ Parameters
+ ----------
+ task_config : dict
+ The config of a task.
+ experiment_name: str
+ The name of experiment
+
+ Returns
+ ----------
+ Recorder: The instance of the recorder
+ """
+ recorder = begin_task_train(task_config, experiment_name)
+ recorder = end_task_train(recorder, experiment_name)
+ return recorder
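+
+# For the delayed workflow, the two steps above can also be called separately,
+# e.g. (a rough sketch; `task_config` and the experiment name are illustrative):
+#
+#     rec = begin_task_train(task_config, "example_exp")  # only logs the task config
+#     rec = end_task_train(rec, "example_exp")            # fits the model and generates records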
+
+
+class Trainer:
+ """
+ The trainer can train a list of models.
+    There are Trainer and DelayTrainer, which can be distinguished by when they finish the real training.
+ """
+
+ def __init__(self):
+ self.delay = False
+
+ def train(self, tasks: list, *args, **kwargs) -> list:
+ """
+ Given a list of task definitions, begin training, and return the models.
+
+ For Trainer, it finishes real training in this method.
+ For DelayTrainer, it only does some preparation in this method.
+
+ Args:
+ tasks: a list of tasks
+
+ Returns:
+ list: a list of models
+ """
+ raise NotImplementedError(f"Please implement the `train` method.")
+
+ def end_train(self, models: list, *args, **kwargs) -> list:
+ """
+        Given a list of models, finish something at the end of training if needed.
+ The models may be Recorder, txt file, database, and so on.
+
+ For Trainer, it does some finishing touches in this method.
+ For DelayTrainer, it finishes real training in this method.
+
+ Args:
+ models: a list of models
+
+ Returns:
+ list: a list of models
+ """
+ # do nothing if you finished all work in `train` method
+ return models
+
+ def is_delay(self) -> bool:
+ """
+        Whether the Trainer will delay finishing `end_train`.
+
+        Returns:
+            bool: True if it is a DelayTrainer.
+ """
+ return self.delay
+
+
+class TrainerR(Trainer):
+ """
+ Trainer based on (R)ecorder.
+ It will train a list of tasks and return a list of model recorders in a linear way.
+
+ Assumption: models were defined by `task` and the results will be saved to `Recorder`.
+ """
+
+    # These tags will help you distinguish whether the Recorder has finished training
+ STATUS_KEY = "train_status"
+ STATUS_BEGIN = "begin_task_train"
+ STATUS_END = "end_task_train"
+
+ def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
+ """
+ Init TrainerR.
+
+ Args:
+ experiment_name (str, optional): the default name of experiment.
+ train_func (Callable, optional): default training method. Defaults to `task_train`.
+ """
+ super().__init__()
+ self.experiment_name = experiment_name
+ self.train_func = train_func
+
+ def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+ """
+        Given a list of `task`s, train them, and return a list of trained Recorders. The order is guaranteed.
+
+ Args:
+ tasks (list): a list of definitions based on `task` dict
+ train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
+            experiment_name (str): the experiment name, None to use the default name.
+ kwargs: the params for train_func.
+
+ Returns:
+ List[Recorder]: a list of Recorders
+ """
+ if len(tasks) == 0:
+ return []
+ if train_func is None:
+ train_func = self.train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ recs = []
+ for task in tasks:
+ rec = train_func(task, experiment_name, **kwargs)
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
+ recs.append(rec)
+ return recs
+
+ def end_train(self, recs: list, **kwargs) -> List[Recorder]:
+ """
+ Set STATUS_END tag to the recorders.
+
+ Args:
+ recs (list): a list of trained recorders.
+
+ Returns:
+ List[Recorder]: the same list as the param.
+ """
+ for rec in recs:
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
+ return recs
+
+
+class DelayTrainerR(TrainerR):
+ """
+ A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
+ """
+
+ def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
+ """
+        Init DelayTrainerR.
+
+ Args:
+ experiment_name (str): the default name of experiment.
+ train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
+ end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
+ """
+ super().__init__(experiment_name, train_func)
+ self.end_train_func = end_train_func
+ self.delay = True
+
+ def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+ """
+        Given a list of Recorders, finish training, and return a list of trained Recorders.
+ This class will finish real data loading and model fitting.
+
+ Args:
+ recs (list): a list of Recorder, the tasks have been saved to them
+ end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
+            experiment_name (str): the experiment name, None to use the default name.
+ kwargs: the params for end_train_func.
+
+ Returns:
+ List[Recorder]: a list of Recorders
+ """
+ if end_train_func is None:
+ end_train_func = self.end_train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ for rec in recs:
+ if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
+ continue
+ end_train_func(rec, experiment_name, **kwargs)
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
+ return recs
+
+
+class TrainerRM(Trainer):
+ """
+ Trainer based on (R)ecorder and Task(M)anager.
+ It can train a list of tasks and return a list of model recorders in a multiprocessing way.
+
+    Assumption: `task` will be saved to TaskManager, and then fetched and trained from TaskManager.
+ """
+
+    # These tags will help you distinguish whether the Recorder has finished training
+ STATUS_KEY = "train_status"
+ STATUS_BEGIN = "begin_task_train"
+ STATUS_END = "end_task_train"
+
+ def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
+ """
+        Init TrainerRM.
+
+        Args:
+            experiment_name (str): the default name of experiment.
+            task_pool (str): task pool name in TaskManager. None to use the same name as experiment_name.
+ train_func (Callable, optional): default training method. Defaults to `task_train`.
+ """
+ super().__init__()
+ self.experiment_name = experiment_name
+ self.task_pool = task_pool
+ self.train_func = train_func
+
+ def train(
+ self,
+ tasks: list,
+ train_func: Callable = None,
+ experiment_name: str = None,
+ before_status: str = TaskManager.STATUS_WAITING,
+ after_status: str = TaskManager.STATUS_DONE,
+ **kwargs,
+ ) -> List[Recorder]:
+ """
+        Given a list of `task`s, train them, and return a list of trained Recorders. The order is guaranteed.
+
+        This method defaults to a single process, but TaskManager offers a great way to parallelize training.
+        Users can customize their train_func to realize multiple processes or even multiple machines.
+
+ Args:
+ tasks (list): a list of definitions based on `task` dict
+ train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
+            experiment_name (str): the experiment name, None to use the default name.
+ before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
+ after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
+ kwargs: the params for train_func.
+
+ Returns:
+ List[Recorder]: a list of Recorders
+ """
+ if len(tasks) == 0:
+ return []
+ if train_func is None:
+ train_func = self.train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ task_pool = self.task_pool
+ if task_pool is None:
+ task_pool = experiment_name
+ tm = TaskManager(task_pool=task_pool)
+ _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
+ run_task(
+ train_func,
+ task_pool,
+ experiment_name=experiment_name,
+ before_status=before_status,
+ after_status=after_status,
+ **kwargs,
+ )
+
+ recs = []
+ for _id in _id_list:
+ rec = tm.re_query(_id)["res"]
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
+ recs.append(rec)
+ return recs
+
+ def end_train(self, recs: list, **kwargs) -> List[Recorder]:
+ """
+ Set STATUS_END tag to the recorders.
+
+ Args:
+ recs (list): a list of trained recorders.
+
+ Returns:
+ List[Recorder]: the same list as the param.
+ """
+ for rec in recs:
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
+ return recs
+
+
+class DelayTrainerRM(TrainerRM):
+ """
+ A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
+
+ """
+
+ def __init__(
+ self,
+ experiment_name: str = None,
+ task_pool: str = None,
+ train_func=begin_task_train,
+ end_train_func=end_task_train,
+ ):
+ """
+ Init DelayTrainerRM.
+
+ Args:
+ experiment_name (str): the default name of experiment.
+            task_pool (str): task pool name in TaskManager. None to use the same name as experiment_name.
+ train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
+ end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
+ """
+ super().__init__(experiment_name, task_pool, train_func)
+ self.end_train_func = end_train_func
+ self.delay = True
+
+ def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+ """
+        Same as `train` of TrainerRM, but after_status will be STATUS_PART_DONE.
+
+        Args:
+            tasks (list): a list of definitions based on `task` dict
+            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. Defaults to None for using self.train_func.
+            experiment_name (str): the experiment name, None to use the default name.
+
+ Returns:
+ List[Recorder]: a list of Recorders
+ """
+ if len(tasks) == 0:
+ return []
+ return super().train(
+ tasks,
+ train_func=train_func,
+ experiment_name=experiment_name,
+ after_status=TaskManager.STATUS_PART_DONE,
+ **kwargs,
+ )
+
+ def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+ """
+        Given a list of Recorders, finish training, and return a list of trained Recorders.
+ This class will finish real data loading and model fitting.
+
+ NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
+
+ Args:
+ recs (list): a list of Recorder, the tasks have been saved to them.
+            end_train_func (Callable, optional): the end_train method which needs at least `recorders` and `experiment_name`. Defaults to None for using self.end_train_func.
+            experiment_name (str): the experiment name, None to use the default name.
+ kwargs: the params for end_train_func.
+
+ Returns:
+ List[Recorder]: a list of Recorders
+ """
+
+ if end_train_func is None:
+ end_train_func = self.end_train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ task_pool = self.task_pool
+ if task_pool is None:
+ task_pool = experiment_name
+ tasks = []
+ for rec in recs:
+ tasks.append(rec.load_object("task"))
+
+ run_task(
+ end_train_func,
+ task_pool,
+ query={"filter": {"$in": tasks}}, # only train these tasks
+ experiment_name=experiment_name,
+ before_status=TaskManager.STATUS_PART_DONE,
+ **kwargs,
+ )
+ for rec in recs:
+ rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
+ return recs
diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py
index 1ee6f07a1f..77857182d9 100644
--- a/qlib/utils/__init__.py
+++ b/qlib/utils/__init__.py
@@ -6,6 +6,7 @@
from __future__ import print_function
import os
+import pickle
import re
import copy
import json
@@ -24,7 +25,9 @@
import numpy as np
import pandas as pd
from pathlib import Path
-from typing import Union, Tuple, Text, Optional
+from typing import Union, Tuple, Any, Text, Optional
+from types import ModuleType
+from urllib.parse import urlparse
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -165,24 +168,25 @@ def parse_field(field):
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
-def get_module_by_module_path(module_path):
+def get_module_by_module_path(module_path: Union[str, ModuleType]):
"""Load module path
:param module_path:
:return:
"""
-
- if module_path.endswith(".py"):
- module_spec = importlib.util.spec_from_file_location("", module_path)
- module = importlib.util.module_from_spec(module_spec)
- module_spec.loader.exec_module(module)
+ if isinstance(module_path, ModuleType):
+ module = module_path
else:
- module = importlib.import_module(module_path)
-
+ if module_path.endswith(".py"):
+ module_spec = importlib.util.spec_from_file_location("", module_path)
+ module = importlib.util.module_from_spec(module_spec)
+ module_spec.loader.exec_module(module)
+ else:
+ module = importlib.import_module(module_path)
return module
-def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
+def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class and kwargs from config info
@@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
config : [dict, str]
similar to config
- module : Python module
+ default_module : Python module or str
It should be a python module to load the class type
+ This function will load class from the config['module_path'] first.
+ If config['module_path'] doesn't exists, it will load the class from default_module.
Returns
-------
@@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
the class object and it's arguments.
"""
if isinstance(config, dict):
+ module = get_module_by_module_path(config.get("module_path", default_module))
+
# raise AttributeError
klass = getattr(module, config["class"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
+ module = get_module_by_module_path(default_module)
+
klass = getattr(module, config)
kwargs = {}
else:
@@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def init_instance_by_config(
- config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
-) -> object:
+ config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
+) -> Any:
"""
get initialized instance with config
@@ -227,13 +237,19 @@ def init_instance_by_config(
'model_path': path, # It is optional if module is given
}
str example.
- "ClassName": getattr(module, config)() will be used.
+ 1) specify a pickle object
+ - path like 'file:////obj.pkl'
+ 2) specify a class name
+ - "ClassName": getattr(module, config)() will be used.
object example:
instance of accept_types
- module : Python module
+ default_module : Python module
Optional. It should be a python module.
NOTE: the "module_path" will be override by `module` arguments
+ This function will load class from the config['module_path'] first.
+ If config['module_path'] doesn't exists, it will load the class from default_module.
+
accept_types: Union[type, Tuple[type]]
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.
@@ -246,10 +262,14 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
- if module is None:
- module = get_module_by_module_path(config["module_path"])
+ if isinstance(config, str):
+ # path like 'file:////obj.pkl'
+ pr = urlparse(config)
+ if pr.scheme == "file":
+ with open(os.path.join(pr.netloc, pr.path), "rb") as f:
+ return pickle.load(f)
- klass, cls_kwargs = get_cls_kwargs(config, module)
+ klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
@@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
-def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
+def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
"""get trading date with shift bias wil cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
"""
from qlib.data import D
- cal = D.calendar(future=future)
+ cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
@@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
return df.sort_index(axis=axis)
-def flatten_dict(d, parent_key="", sep="."):
- """flatten_dict.
+FLATTEN_TUPLE = "_FLATTEN_TUPLE"
+
+
+def flatten_dict(d, parent_key="", sep=".") -> dict:
+ """
+ Flatten a nested dict.
+
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
- Parameters
- ----------
- d :
- d
- parent_key :
- parent_key
- sep :
- sep
+ >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
+ >>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
+
+ Args:
+ d (dict): the dict waiting for flatting
+ parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
+ sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
+
+ Returns:
+ dict: flatten dict
"""
items = []
for k, v in d.items():
- new_key = parent_key + sep + k if parent_key else k
+ if sep == FLATTEN_TUPLE:
+ new_key = (parent_key, k) if parent_key else k
+ else:
+ new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py
index 4bc57eccd5..263e632deb 100644
--- a/qlib/utils/serial.py
+++ b/qlib/utils/serial.py
@@ -3,16 +3,24 @@
from pathlib import Path
import pickle
+import typing
+import dill
+from typing import Union
class Serializable:
"""
- Serializable behaves like pickle.
- But it only saves the state whose name **does not** start with `_`
+ Serializable will change the behaviors of pickle.
+ - It only saves the state whose name **does not** start with `_`
+      It provides syntactic sugar for distinguishing the attributes which the user does not want to save.
+    - For example, a learnable DataHandler just wants to save the parameters without the data when dumping to disk
"""
+    pickle_backend = "pickle"  # another optional value is "dill", which can pickle more Python objects (e.g. functions)
+    default_dump_all = False  # whether to dump all attributes by default
+
def __init__(self):
- self._dump_all = False
+ self._dump_all = self.default_dump_all
self._exclude = []
def __getstate__(self) -> dict:
@@ -33,18 +41,86 @@ def dump_all(self):
@property
def exclude(self):
"""
- What attribute will be dumped
+ What attribute will not be dumped
"""
return getattr(self, "_exclude", [])
- def config(self, dump_all: bool = None, exclude: list = None):
- if dump_all is not None:
- self._dump_all = dump_all
+ FLAG_KEY = "_qlib_serial_flag"
+
+ def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
+ """
+ configure the serializable object
+
+ Parameters
+ ----------
+        dump_all : bool
+            whether the object will dump all attributes
+        exclude : list
+            the attributes that will not be dumped
+        recursive : bool
+            whether to apply the configuration recursively to nested Serializable attributes
+ """
+
+ params = {"dump_all": dump_all, "exclude": exclude}
+
+ for k, v in params.items():
+ if v is not None:
+ attr_name = f"_{k}"
+ setattr(self, attr_name, v)
- if exclude is not None:
- self._exclude = exclude
+ if recursive:
+ for obj in self.__dict__.values():
+ # set flag to prevent endless loop
+ self.__dict__[self.FLAG_KEY] = True
+ if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
+ obj.config(**params, recursive=True)
+ del self.__dict__[self.FLAG_KEY]
- def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
+ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
+ """
+ Dump self to a pickle file.
+
+ Args:
+ path (Union[Path, str]): the path to dump
+            dump_all (bool, optional): whether to dump all attributes. Defaults to None.
+ exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
+ """
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
- pickle.dump(self, f)
+ self.get_backend().dump(self, f)
+
+ @classmethod
+ def load(cls, filepath):
+ """
+        Load the serializable object from a filepath.
+
+        Args:
+            filepath (str): the path of the file
+
+        Raises:
+            TypeError: the pickled object must be an instance of `cls`
+
+        Returns:
+            the instance of `cls`
+ """
+ with open(filepath, "rb") as f:
+ object = cls.get_backend().load(f)
+ if isinstance(object, cls):
+ return object
+ else:
+            raise TypeError(f"The instance of {type(object)} is not a valid `{cls}`!")
+
+ @classmethod
+ def get_backend(cls):
+ """
+ Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
+
+ Returns:
+ module: pickle or dill module based on pickle_backend
+ """
+ if cls.pickle_backend == "pickle":
+ return pickle
+ elif cls.pickle_backend == "dill":
+ return dill
+ else:
+ raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
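+
+
+# A minimal sketch of switching the backend (the class and file names are illustrative):
+#
+#     class MyObj(Serializable):
+#         pickle_backend = "dill"  # allows dumping more things, e.g. lambdas
+#
+#     obj = MyObj()
+#     obj.to_pickle("obj.pkl")
+#     obj2 = MyObj.load("obj.pkl")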
diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py
index 8135bab60a..2b2535edca 100644
--- a/qlib/workflow/__init__.py
+++ b/qlib/workflow/__init__.py
@@ -331,7 +331,7 @@ def set_uri(self, uri: Optional[Text]):
"""
self.exp_manager.set_uri(uri)
- def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
+ def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
"""
Method for retrieving a recorder.
diff --git a/qlib/workflow/online/__init__.py b/qlib/workflow/online/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py
new file mode 100644
index 0000000000..443cd61ad8
--- /dev/null
+++ b/qlib/workflow/online/manager.py
@@ -0,0 +1,304 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
+
+As time passes, the decisive models will change. In this module, we call those contributing models `online` models.
+In every routine (such as every day or every minute), the `online` models may change and their predictions need to be updated.
+So this module provides a series of methods to control this process.
+
+This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history,
+which means you can verify your strategy or find a better one.
+
+There are 4 typical situations for using different trainers:
+
+
+
+========================= ===================================================================================
+Situations Description
+========================= ===================================================================================
+Online + Trainer          When you REALLY want to do a routine, the Trainer will help you train the models.
+
+Online + DelayTrainer     In a normal online routine, both Trainer and DelayTrainer will REALLY train models
+                          in this routine, so it is not necessary to use DelayTrainer when doing a REAL routine.
+
+Simulation + Trainer      When your models have some temporal dependence on the previous models, you need
+                          to consider using Trainer. This means it will REALLY train your models in
+                          every routine and prepare signals for every routine.
+
+Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
+                          for the ability of multitasking. It means all tasks in all routines
+                          can be REALLY trained at the end of simulating. The signals will be prepared well at
+                          different time segments (based on whether or not any new model is online).
+========================= ===================================================================================
+"""
+
+import logging
+from typing import Callable, Dict, List, Union
+
+import pandas as pd
+from qlib import get_module_logger
+from qlib.data.data import D
+from qlib.log import set_global_logger_level
+from qlib.model.ens.ensemble import AverageEnsemble
+from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
+from qlib.utils import flatten_dict
+from qlib.utils.serial import Serializable
+from qlib.workflow.online.strategy import OnlineStrategy
+from qlib.workflow.task.collect import MergeCollector
+
+
+class OnlineManager(Serializable):
+ """
+ OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
+ It also provides a history recording of which models are online at what time.
+ """
+
+ STATUS_SIMULATING = "simulating" # when calling `simulate`
+ STATUS_NORMAL = "normal" # the normal status
+
+ def __init__(
+ self,
+ strategies: Union[OnlineStrategy, List[OnlineStrategy]],
+ trainer: Trainer = None,
+ begin_time: Union[str, pd.Timestamp] = None,
+ freq="day",
+ ):
+ """
+ Init OnlineManager.
+ One OnlineManager must have at least one OnlineStrategy.
+
+ Args:
+ strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
+ begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date.
+ trainer (Trainer): the trainer to train task. None for using TrainerR.
+ freq (str, optional): data frequency. Defaults to "day".
+ """
+ self.logger = get_module_logger(self.__class__.__name__)
+ if not isinstance(strategies, list):
+ strategies = [strategies]
+ self.strategies = strategies
+ self.freq = freq
+ if begin_time is None:
+ begin_time = D.calendar(freq=self.freq).max()
+ self.begin_time = pd.Timestamp(begin_time)
+ self.cur_time = self.begin_time
+        # OnlineManager will record the history of online models, which is a dict like {pd.Timestamp: {strategy: [online_models]}}.
+ self.history = {}
+ if trainer is None:
+ trainer = TrainerR()
+ self.trainer = trainer
+ self.signals = None
+ self.status = self.STATUS_NORMAL
+
+ def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
+ """
+ Get tasks from every strategy's first_tasks method and train them.
+ If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
+
+ Args:
+            strategies (List[OnlineStrategy]): the strategies list (this param is needed when adding strategies). None to use the default strategies.
+ model_kwargs (dict): the params for `prepare_online_models`
+ """
+ if strategies is None:
+ strategies = self.strategies
+ for strategy in strategies:
+
+ self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
+ tasks = strategy.first_tasks()
+ models = self.trainer.train(tasks, experiment_name=strategy.name_id)
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+ self.logger.info(f"Finished training {len(models)} models.")
+
+ online_models = strategy.prepare_online_models(models, **model_kwargs)
+ self.history.setdefault(self.cur_time, {})[strategy] = online_models
+
+ def routine(
+ self,
+ cur_time: Union[str, pd.Timestamp] = None,
+ task_kwargs: dict = {},
+ model_kwargs: dict = {},
+ signal_kwargs: dict = {},
+ ):
+ """
+ Typical update process for every strategy and record the online history.
+
+ The typical update process after a routine, such as day by day or month by month.
+ The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.
+
+ If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
+
+ Args:
+            cur_time (Union[str, pd.Timestamp], optional): run the routine method at this time. Defaults to None for the latest calendar time.
+ task_kwargs (dict): the params for `prepare_tasks`
+ model_kwargs (dict): the params for `prepare_online_models`
+ signal_kwargs (dict): the params for `prepare_signals`
+ """
+ if cur_time is None:
+ cur_time = D.calendar(freq=self.freq).max()
+ self.cur_time = pd.Timestamp(cur_time) # None for latest date
+
+ for strategy in self.strategies:
+ self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
+ if self.status == self.STATUS_NORMAL:
+ strategy.tool.update_online_pred()
+
+ tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
+ models = self.trainer.train(tasks)
+ if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+ self.logger.info(f"Finished training {len(models)} models.")
+ online_models = strategy.prepare_online_models(models, **model_kwargs)
+ self.history.setdefault(self.cur_time, {})[strategy] = online_models
+
+ if not self.trainer.is_delay():
+ self.prepare_signals(**signal_kwargs)
+
+ def get_collector(self) -> MergeCollector:
+ """
+ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
+        This collector can be the basis for signal preparation.
+
+ Returns:
+ MergeCollector: the collector to merge other collectors.
+ """
+ collector_dict = {}
+ for strategy in self.strategies:
+ collector_dict[strategy.name_id] = strategy.get_collector()
+ return MergeCollector(collector_dict, process_list=[])
+
+ def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
+ """
+ Add some new strategies to OnlineManager.
+
+ Args:
+            strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
+ """
+ if not isinstance(strategies, list):
+ strategies = [strategies]
+ self.first_train(strategies)
+ self.strategies.extend(strategies)
+
+ def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
+ """
+        After preparing the data of the last routine (a box in box-plot), which means the end of the routine, we can prepare trading signals for the next routine.
+
+        NOTE: Given a set of predictions, all signals before the end times of these predictions will be prepared well.
+
+        Even if the latest signal already exists, it will be overwritten by the latest calculation result.
+
+ .. note::
+
+ Given a prediction of a certain time, all signals before this time will be prepared well.
+
+ Args:
+ prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}.
+            over_write (bool, optional): If True, the new signals will overwrite the old ones. If False, the new signals will be appended to the end of the old signals. Defaults to False.
+
+ Returns:
+ pd.DataFrame: the signals.
+ """
+ signals = prepare_func(self.get_collector()())
+ old_signals = self.signals
+ if old_signals is not None and not over_write:
+ old_max = old_signals.index.get_level_values("datetime").max()
+ new_signals = signals.loc[old_max:]
+ signals = pd.concat([old_signals, new_signals], axis=0)
+ else:
+ new_signals = signals
+ self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
+ self.signals = signals
+ return new_signals
+
+ def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
+ """
+ Get prepared online signals.
+
+ Returns:
+            Union[pd.Series, pd.DataFrame]: pd.Series when there is only one signal for each datetime;
+            pd.DataFrame when there are multiple signals, for example, when buy and sell operations use different trading signals.
+ """
+ return self.signals
+
+ SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information
+ SIM_LOG_NAME = "SIMULATE_INFO"
+
+ def simulate(
+ self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
+ ) -> Union[pd.Series, pd.DataFrame]:
+ """
+ Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
+
+        Considering the parallel training, the models and signals can be prepared after all the routine simulation is finished.
+
+        The delayed training can be done by ``DelayTrainer``, and the delayed signal preparation can be done by ``delay_prepare``.
+
+ Args:
+ end_time: the time the simulation will end
+ frequency: the calendar frequency
+ task_kwargs (dict): the params for `prepare_tasks`
+ model_kwargs (dict): the params for `prepare_online_models`
+ signal_kwargs (dict): the params for `prepare_signals`
+
+ Returns:
+            Union[pd.Series, pd.DataFrame]: pd.Series when there is only one signal for each datetime;
+            pd.DataFrame when there are multiple signals, for example, when buy and sell operations use different trading signals.
+ """
+ self.status = self.STATUS_SIMULATING
+ cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
+ self.first_train()
+
+ simulate_level = self.SIM_LOG_LEVEL
+ set_global_logger_level(simulate_level)
+ logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
+
+ for cur_time in cal:
+ self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
+ self.routine(
+ cur_time,
+ task_kwargs=task_kwargs,
+ model_kwargs=model_kwargs,
+ signal_kwargs=signal_kwargs,
+ )
+ # delay prepare the models and signals
+ if self.trainer.is_delay():
+ self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
+
+ # FIXME: get logging level firstly and restore it here
+ set_global_logger_level(logging.DEBUG)
+ self.logger.info(f"Finished preparing signals")
+ self.status = self.STATUS_NORMAL
+ return self.get_signals()
+
+ def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
+ """
+ Prepare all models and signals if something is waiting for preparation.
+
+ Args:
+ model_kwargs: the params for `end_train`
+ signal_kwargs: the params for `prepare_signals`
+ """
+ last_models = {}
+ signals_time = D.calendar()[0]
+ need_prepare = False
+ for cur_time, strategy_models in self.history.items():
+ self.cur_time = cur_time
+
+ for strategy, models in strategy_models.items():
+ # only new online models need to prepare
+ if last_models.setdefault(strategy, set()) != set(models):
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
+ strategy.tool.reset_online_tag(models)
+ need_prepare = True
+ last_models[strategy] = set(models)
+
+ if need_prepare:
+                # NOTE: Assumption: the predictions of the online models should not extend beyond the next cur_time, otherwise this method will work incorrectly.
+ self.prepare_signals(**signal_kwargs)
+ if signals_time > cur_time:
+ self.logger.warn(
+                        f"The signals have already been prepared up to {signals_time} by the last preparation, but the current time is only {cur_time}. This may be because the online models predict more than they should, which can cause the signals to be contaminated by the offline models."
+ )
+ need_prepare = False
+ signals_time = self.signals.index.get_level_values("datetime").max()
diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py
new file mode 100644
index 0000000000..a54eb32bfe
--- /dev/null
+++ b/qlib/workflow/online/strategy.py
@@ -0,0 +1,211 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+OnlineStrategy module is an element of online serving.
+"""
+
+from copy import deepcopy
+from typing import List, Tuple, Union
+from qlib.data.data import D
+from qlib.log import get_module_logger
+from qlib.model.ens.group import RollingGroup
+from qlib.workflow.online.utils import OnlineTool, OnlineToolR
+from qlib.workflow.recorder import Recorder
+from qlib.workflow.task.collect import Collector, RecorderCollector
+from qlib.workflow.task.gen import RollingGen, task_generator
+from qlib.workflow.task.utils import TimeAdjuster
+
+
+class OnlineStrategy:
+ """
+    OnlineStrategy is working with `Online Manager <#Online Manager>`_, defining how the tasks are generated, how the models are updated and how the signals are prepared.
+ """
+
+ def __init__(self, name_id: str):
+ """
+ Init OnlineStrategy.
+        This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finish model training.
+
+        Args:
+            name_id (str): a unique name or id.
+ """
+ self.name_id = name_id
+ self.logger = get_module_logger(self.__class__.__name__)
+ self.tool = OnlineTool()
+
+ def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
+ """
+        After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for the latest).
+ Return the new tasks waiting for training.
+
+ You can find the last online models by OnlineTool.online_models.
+ """
+ raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
+
+ def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
+ """
+ Select some models from trained models and set them to online models.
+        This is a typical implementation which sets all trained models online; you can override it to implement more complex logic.
+ You can find the last online models by OnlineTool.online_models if you still need them.
+
+ NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
+
+ Args:
+            trained_models (list): a list of trained models.
+            cur_time (pd.Timestamp): the current time from OnlineManager. None for the latest.
+
+ Returns:
+ List[object]: a list of online models.
+ """
+ if not trained_models:
+ return self.tool.online_models()
+ self.tool.reset_online_tag(trained_models)
+ return trained_models
+
+ def first_tasks(self) -> List[dict]:
+ """
+        Generate a series of tasks at the very beginning and return them.
+ """
+ raise NotImplementedError(f"Please implement the `first_tasks` method.")
+
+ def get_collector(self) -> Collector:
+ """
+ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.
+
+ For example:
+ 1) collect predictions in Recorder
+ 2) collect signals in a txt file
+
+ Returns:
+ Collector
+ """
+ raise NotImplementedError(f"Please implement the `get_collector` method.")
+
+
+class RollingStrategy(OnlineStrategy):
+
+ """
+    This example strategy always uses the latest rolling models as the online models.
+ """
+
+ def __init__(
+ self,
+ name_id: str,
+ task_template: Union[dict, List[dict]],
+ rolling_gen: RollingGen,
+ ):
+ """
+ Init RollingStrategy.
+
+        Assumption: the name_id, the experiment name, and the trainer's experiment name are the same.
+
+ Args:
+ name_id (str): a unique name or id. Will be also the name of the Experiment.
+ task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
+ rolling_gen (RollingGen): an instance of RollingGen
+ """
+ super().__init__(name_id=name_id)
+ self.exp_name = self.name_id
+ if not isinstance(task_template, list):
+ task_template = [task_template]
+ self.task_template = task_template
+ self.rg = rolling_gen
+ self.tool = OnlineToolR(self.exp_name)
+ self.ta = TimeAdjuster()
+
+ def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
+ """
+        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results between different models.
+
+        Assumption: the models can be distinguished based on the model name and rolling test segments.
+        If you do not want this assumption, please implement your own method or use another rec_key_func.
+
+ Args:
+ rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
+ rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
+ artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
+ """
+
+ def rec_key(recorder):
+ task_config = recorder.load_object("task")
+ model_key = task_config["model"]["class"]
+ rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
+ return model_key, rolling_key
+
+ if rec_key_func is None:
+ rec_key_func = rec_key
+
+ artifacts_collector = RecorderCollector(
+ experiment=self.exp_name,
+ process_list=process_list,
+ rec_key_func=rec_key_func,
+ rec_filter_func=rec_filter_func,
+ artifacts_key=artifacts_key,
+ )
+
+ return artifacts_collector
+
+ def first_tasks(self) -> List[dict]:
+ """
+ Use rolling_gen to generate different tasks based on task_template.
+
+ Returns:
+ List[dict]: a list of tasks
+ """
+ return task_generator(
+ tasks=self.task_template,
+ generators=self.rg, # generate different date segment
+ )
+
+ def prepare_tasks(self, cur_time) -> List[dict]:
+ """
+ Prepare new tasks based on cur_time (None for the latest).
+
+ You can find the last online models by OnlineToolR.online_models.
+
+ Returns:
+ List[dict]: a list of new tasks.
+ """
+ latest_records, max_test = self._list_latest(self.tool.online_models())
+ if max_test is None:
+ self.logger.warn(f"No latest online recorders, no new tasks.")
+ return []
+ calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
+ self.logger.info(
+ f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
+ )
+ if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
+ old_tasks = []
+ tasks_tmp = []
+ for rec in latest_records:
+ task = rec.load_object("task")
+ old_tasks.append(deepcopy(task))
+ test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
+ # modify the test segment to generate new tasks
+ task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
+ tasks_tmp.append(task)
+ new_tasks_tmp = task_generator(tasks_tmp, self.rg)
+ new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
+ return new_tasks
+ return []
+
+ def _list_latest(self, rec_list: List[Recorder]):
+ """
+        List the latest recorders from rec_list.
+
+ Args:
+ rec_list (List[Recorder]): a list of Recorder
+
+ Returns:
+            List[Recorder], tuple: the latest recorders and the latest test segment
+ """
+ if len(rec_list) == 0:
+ return rec_list, None
+ max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
+ latest_rec = []
+ for rec in rec_list:
+ if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
+ latest_rec.append(rec)
+ return latest_rec, max_test
diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py
new file mode 100644
index 0000000000..561f7e18ae
--- /dev/null
+++ b/qlib/workflow/online/update.py
@@ -0,0 +1,160 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+Updater is a module to update artifacts such as predictions when the stock data is updating.
+"""
+
+from abc import ABCMeta, abstractmethod
+
+import pandas as pd
+from qlib import get_module_logger
+from qlib.data import D
+from qlib.data.dataset import DatasetH
+from qlib.data.dataset.handler import DataHandlerLP
+from qlib.model import Model
+from qlib.utils import get_date_by_shift
+from qlib.workflow.recorder import Recorder
+
+
+class RMDLoader:
+ """
+ Recorder Model Dataset Loader
+ """
+
+ def __init__(self, rec: Recorder):
+ self.rec = rec
+
+ def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
+ """
+ Load, config and setup dataset.
+
+ This dataset is for inference.
+
+ Args:
+ start_time :
+ the start_time of underlying data
+ end_time :
+ the end_time of underlying data
+ segments : dict
+ the segments config for dataset
+                Due to the time series dataset (TSDatasetH), the test segments may be different from start_time and end_time
+
+ Returns:
+ DatasetH: the instance of DatasetH
+
+ """
+ if segments is None:
+ segments = {"test": (start_time, end_time)}
+ dataset: DatasetH = self.rec.load_object("dataset")
+ dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
+ dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
+ return dataset
+
+ def get_model(self) -> Model:
+ return self.rec.load_object("params.pkl")
+
+
+class RecordUpdater(metaclass=ABCMeta):
+ """
+    Update a specific recorder.
+ """
+
+ def __init__(self, record: Recorder, *args, **kwargs):
+ self.record = record
+ self.logger = get_module_logger(self.__class__.__name__)
+
+ @abstractmethod
+ def update(self, *args, **kwargs):
+ """
+        Update the info of a specific recorder.
+ """
+ ...
+
+
+class PredUpdater(RecordUpdater):
+ """
+ Update the prediction in the Recorder
+ """
+
+ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
+ """
+ Init PredUpdater.
+
+ Args:
+ record : Recorder
+ to_date :
+ update to prediction to the `to_date`
+                update the prediction to the `to_date`
+            hist_ref : int
+                Sometimes, the dataset will have historical dependencies.
+                Leave it to users to set the length of the historical dependency.
+ .. note::
+
+ the start_time is not included in the hist_ref
+
+ """
+ # TODO: automate this hist_ref in the future.
+ super().__init__(record=record)
+
+ self.to_date = to_date
+ self.hist_ref = hist_ref
+ self.freq = freq
+ self.rmdl = RMDLoader(rec=record)
+
+        if to_date is None:
+ to_date = D.calendar(freq=freq)[-1]
+ self.to_date = pd.Timestamp(to_date)
+ self.old_pred = record.load_object("pred.pkl")
+ self.last_end = self.old_pred.index.get_level_values("datetime").max()
+
+ def prepare_data(self) -> DatasetH:
+ """
+ Load dataset
+
+ Separating this function will make it easier to reuse the dataset
+
+ Returns:
+ DatasetH: the instance of DatasetH
+ """
+ start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
+ start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
+ seg = {"test": (start_time, self.to_date)}
+ dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
+ return dataset
+
+ def update(self, dataset: DatasetH = None):
+ """
+ Update the prediction in a recorder.
+
+ Args:
+            dataset (DatasetH, optional): the instance of DatasetH. None to prepare it again.
+ """
+ # FIXME: the problem below is not solved
+        # The model dumped on a GPU instance cannot be loaded on a CPU instance. The following exception will be raised:
+ # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
+ # https://github.com/pytorch/pytorch/issues/16797
+
+ start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
+ if start_time >= self.to_date:
+ self.logger.info(
+                f"The predictions in {self.record.info['id']} are already the latest ({start_time}). No need to update to {self.to_date}."
+ )
+ return
+
+ # load dataset
+ if dataset is None:
+ # For reusing the dataset
+ dataset = self.prepare_data()
+
+ # Load model
+ model = self.rmdl.get_model()
+
+ new_pred: pd.Series = model.predict(dataset)
+
+ cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
+ cb_pred = cb_pred.sort_index()
+
+ self.record.save_objects(**{"pred.pkl": cb_pred})
+
+        self.logger.info(f"Finished updating {new_pred.shape[0]} new predictions in {self.record.info['id']}.")
diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py
new file mode 100644
index 0000000000..f3ef13aa93
--- /dev/null
+++ b/qlib/workflow/online/utils.py
@@ -0,0 +1,168 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+OnlineTool is a module to set and unset a series of `online` models.
+The `online` models are the decisive models at some time points, which can change as time passes.
+This allows us to use efficient submodels as the market style changes.
+"""
+
+from typing import List, Union
+
+from qlib.log import get_module_logger
+from qlib.workflow.online.update import PredUpdater
+from qlib.workflow.recorder import Recorder
+from qlib.workflow.task.utils import list_recorders
+
+
+class OnlineTool:
+ """
+ OnlineTool will manage `online` models in an experiment that includes the model recorders.
+ """
+
+ ONLINE_KEY = "online_status" # the online status key in recorder
+ ONLINE_TAG = "online" # the 'online' model
+ OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
+
+ def __init__(self):
+ """
+ Init OnlineTool.
+ """
+ self.logger = get_module_logger(self.__class__.__name__)
+
+ def set_online_tag(self, tag, recorder: Union[list, object]):
+ """
+        Set `tag` to the model to mark whether it is online.
+
+ Args:
+ tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
+ recorder (Union[list,object]): the model's recorder
+ """
+ raise NotImplementedError(f"Please implement the `set_online_tag` method.")
+
+ def get_online_tag(self, recorder: object) -> str:
+ """
+ Given a model recorder and return its online tag.
+
+ Args:
+ recorder (Object): the model's recorder
+
+ Returns:
+ str: the online tag
+ """
+ raise NotImplementedError(f"Please implement the `get_online_tag` method.")
+
+ def reset_online_tag(self, recorder: Union[list, object]):
+ """
+        Set all models to 'offline' and set the given recorders to 'online'.
+
+ Args:
+ recorder (Union[list,object]):
+ the recorder you want to reset to 'online'.
+
+ """
+ raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
+
+ def online_models(self) -> list:
+ """
+ Get current `online` models
+
+ Returns:
+ list: a list of `online` models.
+ """
+ raise NotImplementedError(f"Please implement the `online_models` method.")
+
+ def update_online_pred(self, to_date=None):
+ """
+ Update the predictions of `online` models to to_date.
+
+ Args:
+            to_date (pd.Timestamp): the predictions before this date will be updated. None for updating to the latest.
+
+ """
+ raise NotImplementedError(f"Please implement the `update_online_pred` method.")
+
+
+class OnlineToolR(OnlineTool):
+ """
+ The implementation of OnlineTool based on (R)ecorder.
+ """
+
+ def __init__(self, experiment_name: str):
+ """
+ Init OnlineToolR.
+
+ Args:
+ experiment_name (str): the experiment name.
+ """
+ super().__init__()
+ self.exp_name = experiment_name
+
+ def set_online_tag(self, tag, recorder: Union[Recorder, List]):
+ """
+        Set `tag` to the model's recorder to mark whether it is online.
+
+ Args:
+            tag (str): the tag in `ONLINE_TAG` or `OFFLINE_TAG`
+ recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
+ """
+ if isinstance(recorder, Recorder):
+ recorder = [recorder]
+ for rec in recorder:
+ rec.set_tags(**{self.ONLINE_KEY: tag})
+ self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
+
+ def get_online_tag(self, recorder: Recorder) -> str:
+ """
+ Given a model recorder and return its online tag.
+
+ Args:
+ recorder (Recorder): an instance of recorder
+
+ Returns:
+ str: the online tag
+ """
+ tags = recorder.list_tags()
+ return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
+
+ def reset_online_tag(self, recorder: Union[Recorder, List]):
+ """
+        Set all models to 'offline' and set the given recorders to 'online'.
+
+ Args:
+ recorder (Union[Recorder, List]):
+ the recorder you want to reset to 'online'.
+
+ """
+ if isinstance(recorder, Recorder):
+ recorder = [recorder]
+ recs = list_recorders(self.exp_name)
+ self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
+ self.set_online_tag(self.ONLINE_TAG, recorder)
+
+ def online_models(self) -> list:
+ """
+ Get current `online` models
+
+ Returns:
+ list: a list of `online` models.
+ """
+ return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
+
+ def update_online_pred(self, to_date=None):
+ """
+ Update the predictions of online models to to_date.
+
+ Args:
+ to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
+ """
+ online_models = self.online_models()
+ for rec in online_models:
+ hist_ref = 0
+ task = rec.load_object("task")
+ # Special treatment of historical dependencies
+ if task["dataset"]["class"] == "TSDatasetH":
+ hist_ref = task["dataset"]["kwargs"]["step_len"]
+ PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
+
+ self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py
index 5732c95a9e..fc71b3f9a2 100644
--- a/qlib/workflow/record_temp.py
+++ b/qlib/workflow/record_temp.py
@@ -151,6 +151,10 @@ def generate(self, **kwargs):
del params["data_key"]
# The backend handler should be DataHandler
raw_label = self.dataset.prepare(**params)
+ except AttributeError:
+            # The data handler was initialized with `drop_raw=True`...
+ # So raw_label is not available
+ raw_label = None
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls
@@ -236,6 +240,9 @@ def generate(self, **kwargs):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
+ if label is None or not isinstance(label, pd.DataFrame) or label.empty:
+            logger.warning("Empty label.")
+ return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py
index b9b2fd1b36..0c9abf7318 100644
--- a/qlib/workflow/recorder.py
+++ b/qlib/workflow/recorder.py
@@ -39,6 +39,9 @@ def __repr__(self):
def __str__(self):
return str(self.info)
+ def __hash__(self) -> int:
+ return hash(self.info["id"])
+
@property
def info(self):
output = dict()
@@ -232,6 +235,14 @@ def __repr__(self):
client=self.client,
)
+ def __hash__(self) -> int:
+ return hash(self.info["id"])
+
+ def __eq__(self, o: object) -> bool:
+ if isinstance(o, MLflowRecorder):
+ return self.info["id"] == o.info["id"]
+ return False
+
@property
def uri(self):
return self._uri
diff --git a/qlib/workflow/task/__init__.py b/qlib/workflow/task/__init__.py
new file mode 100644
index 0000000000..cc338cca4d
--- /dev/null
+++ b/qlib/workflow/task/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+Task related workflow is implemented in this folder
+
+A typical task workflow
+
+| Step | Description |
+|-----------------------+------------------------------------------------|
+| TaskGen | Generating tasks. |
+| TaskManager(optional) | Manage generated tasks |
+| run task              | Retrieve tasks from TaskManager and run them.  |
+"""
diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py
new file mode 100644
index 0000000000..9410c2b9c2
--- /dev/null
+++ b/qlib/workflow/task/collect.py
@@ -0,0 +1,219 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+The Collector module can collect objects from different sources and process them, with operations such as merging, grouping, averaging, and so on.
+"""
+
+from typing import Callable, Dict, List
+from qlib.utils.serial import Serializable
+from qlib.workflow import R
+
+
+class Collector(Serializable):
+ """The collector to collect different results"""
+
+ pickle_backend = "dill" # use dill to dump user method
+
+ def __init__(self, process_list=[]):
+ """
+ Init Collector.
+
+ Args:
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
+ """
+ if not isinstance(process_list, list):
+ process_list = [process_list]
+ self.process_list = process_list
+
+ def collect(self) -> dict:
+ """
+ Collect the results and return a dict like {key: things}
+
+ Returns:
+ dict: the dict after collecting.
+
+ For example:
+
+ {"prediction": pd.Series}
+
+ {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
+
+ ......
+ """
+ raise NotImplementedError(f"Please implement the `collect` method.")
+
+ @staticmethod
+ def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
+ """
+        Apply a series of processors to the dict returned by `collect` and return a dict like {key: things}.
+ For example, you can group and ensemble.
+
+ Args:
+            collected_dict (dict): the dict returned by `collect`
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
+ The processor order is the same as the list order.
+ For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
+
+ Returns:
+ dict: the dict after processing.
+ """
+ if not isinstance(process_list, list):
+ process_list = [process_list]
+ result = {}
+ for artifact in collected_dict:
+ value = collected_dict[artifact]
+ for process in process_list:
+ if not callable(process):
+ raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
+ value = process(value, *args, **kwargs)
+ result[artifact] = value
+ return result
+
+ def __call__(self, *args, **kwargs) -> dict:
+ """
+ Do the workflow including ``collect`` and ``process_collect``
+
+ Returns:
+ dict: the dict after collecting and processing.
+ """
+ collected = self.collect()
+ return self.process_collect(collected, self.process_list, *args, **kwargs)
+
+
+class MergeCollector(Collector):
+ """
+ A collector to collect the results of other Collectors
+
+ For example:
+
+        We have two collectors named A and B.
+        A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
+        Then, after this class's collect, we can get {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
+        (assuming a ``merge_func`` that joins ``collector_key`` and ``key`` with "_"; by default the keys are tuples such as ("A", "prediction")).
+
+ ......
+
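+    A minimal sketch (``collector_a`` and ``collector_b`` are assumed to be ready-made Collectors):
+
+    .. code-block:: python
+
+        merged = MergeCollector(
+            collector_dict={"A": collector_a, "B": collector_b},
+            merge_func=lambda collector_key, key: f"{collector_key}_{key}",
+        )
+        result = merged()  # e.g. {"A_prediction": ..., "B_IC": {...}}
+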
+ """
+
+ def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
+ """
+ Init MergeCollector.
+
+ Args:
+            collector_dict (Dict[str,Collector]): the dict like {collector_key: Collector}
+ process_list (List[Callable]): the list of processors or the instance of processor to process dict.
+ merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
+ None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
+ """
+ super().__init__(process_list=process_list)
+ self.collector_dict = collector_dict
+ self.merge_func = merge_func
+
+ def collect(self) -> dict:
+ """
+        Collect all results of collector_dict and change the outermost key to a recombined key (combining the collector key with the original key).
+
+ Returns:
+ dict: the dict after collecting.
+ """
+ collect_dict = {}
+ for collector_key, collector in self.collector_dict.items():
+ tmp_dict = collector()
+ for key, value in tmp_dict.items():
+ if self.merge_func is not None:
+ collect_dict[self.merge_func(collector_key, key)] = value
+ else:
+ collect_dict[(collector_key, key)] = value
+ return collect_dict
+
+
+class RecorderCollector(Collector):
+ ART_KEY_RAW = "__raw"
+
+ def __init__(
+ self,
+ experiment,
+ process_list=[],
+ rec_key_func=None,
+ rec_filter_func=None,
+ artifacts_path={"pred": "pred.pkl"},
+ artifacts_key=None,
+ ):
+ """
+ Init RecorderCollector.
+
+ Args:
+ experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
+ rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
+ rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
+            artifacts_path (dict, optional): the artifact name and its path in Recorder. Defaults to {"pred": "pred.pkl"}.
+ artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
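+
+        A minimal sketch (the experiment name "my_exp" is an assumption):
+
+        .. code-block:: python
+
+            collector = RecorderCollector(
+                experiment="my_exp",
+                artifacts_path={"pred": "pred.pkl"},
+            )
+            pred_dict = collector()["pred"]  # {recorder_key: prediction object}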
+ """
+ super().__init__(process_list=process_list)
+ if isinstance(experiment, str):
+ experiment = R.get_exp(experiment_name=experiment)
+ self.experiment = experiment
+ self.artifacts_path = artifacts_path
+ if rec_key_func is None:
+ rec_key_func = lambda rec: rec.info["id"]
+ if artifacts_key is None:
+ artifacts_key = list(self.artifacts_path.keys())
+ self.rec_key_func = rec_key_func
+ self.artifacts_key = artifacts_key
+ self.rec_filter_func = rec_filter_func
+
+ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
+ """
+ Collect different artifacts based on recorder after filtering.
+
+ Args:
+ artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
+ rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
+            only_exist (bool, optional): whether to collect an artifact only when a recorder really has it.
+                If True, a recorder that raises an exception when loading will be skipped; if False, the exception will be raised.
+
+ Returns:
+            dict: the dict after collecting, like {artifact: {rec_key: object}}
+ """
+ if artifacts_key is None:
+ artifacts_key = self.artifacts_key
+ if rec_filter_func is None:
+ rec_filter_func = self.rec_filter_func
+
+ if isinstance(artifacts_key, str):
+ artifacts_key = [artifacts_key]
+
+ collect_dict = {}
+ # filter records
+ recs = self.experiment.list_recorders()
+ recs_flt = {}
+ for rid, rec in recs.items():
+ if rec_filter_func is None or rec_filter_func(rec):
+ recs_flt[rid] = rec
+
+ for _, rec in recs_flt.items():
+ rec_key = self.rec_key_func(rec)
+ for key in artifacts_key:
+ if self.ART_KEY_RAW == key:
+ artifact = rec
+ else:
+ try:
+ artifact = rec.load_object(self.artifacts_path[key])
+ except Exception as e:
+ if only_exist:
+ # only collect existing artifact
+ continue
+ raise e
+ collect_dict.setdefault(key, {})[rec_key] = artifact
+
+ return collect_dict
+
+ def get_exp_name(self) -> str:
+ """
+ Get experiment name
+
+ Returns:
+ str: experiment name
+ """
+ return self.experiment.name
diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py
new file mode 100644
index 0000000000..cdebf50494
--- /dev/null
+++ b/qlib/workflow/task/gen.py
@@ -0,0 +1,231 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+TaskGenerator module can generate many tasks based on TaskGen and some task templates.
+"""
+import abc
+import copy
+from typing import List, Union, Callable
+from .utils import TimeAdjuster
+
+
+def task_generator(tasks, generators) -> list:
+ """
+ Use a list of TaskGen and a list of task templates to generate different tasks.
+
+    For example:
+
+    There are 3 task templates a, b, c and 2 TaskGen A, B. A will generate 2 tasks from a template and B will generate 3 tasks from a template.
+ task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
+
+ Parameters
+ ----------
+ tasks : List[dict] or dict
+ a list of task templates or a single task
+ generators : List[TaskGen] or TaskGen
+ a list of TaskGen or a single TaskGen
+
+ Returns
+ -------
+ list
+ a list of tasks
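+
+    A minimal sketch (``task_template`` stands for any task dict such as the one shown in ``RollingGen.generate``):
+
+    .. code-block:: python
+
+        from qlib.workflow.task.gen import RollingGen, task_generator
+
+        tasks = task_generator(
+            tasks=task_template,             # a single template is also accepted
+            generators=RollingGen(step=40),  # a single TaskGen is also accepted
+        )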
+ """
+
+ if isinstance(tasks, dict):
+ tasks = [tasks]
+ if isinstance(generators, TaskGen):
+ generators = [generators]
+
+ # generate gen_task_list
+ for gen in generators:
+ new_task_list = []
+ for task in tasks:
+ new_task_list.extend(gen.generate(task))
+ tasks = new_task_list
+
+ return tasks
+
+
+class TaskGen(metaclass=abc.ABCMeta):
+ """
+ The base class for generating different tasks
+
+ Example 1:
+
+ input: a specific task template and rolling steps
+
+ output: rolling version of the tasks
+
+ Example 2:
+
+ input: a specific task template and losses list
+
+ output: a set of tasks with different losses
+
+ """
+
+ @abc.abstractmethod
+ def generate(self, task: dict) -> List[dict]:
+ """
+ Generate different tasks based on a task template
+
+ Parameters
+ ----------
+ task: dict
+ a task template
+
+ Returns
+ -------
+ typing.List[dict]:
+ A list of tasks
+ """
+ pass
+
+ def __call__(self, *args, **kwargs):
+ """
+        This is just syntactic sugar for `generate`
+ """
+ return self.generate(*args, **kwargs)
+
+
+def handler_mod(task: dict, rolling_gen):
+ """
+ Help to modify the handler end time when using RollingGen
+
+ Args:
+ task (dict): a task template
+        rolling_gen (RollingGen): an instance of RollingGen
+ """
+ try:
+ interval = rolling_gen.ta.cal_interval(
+ task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
+ task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
+ )
+        # if end_time < the end of test_segments, then change end_time to allow loading more data
+ if interval < 0:
+ task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
+ task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
+ )
+ except KeyError:
+        # Maybe the dataset does not have a handler; in that case, do nothing.
+ pass
+
+
+class RollingGen(TaskGen):
+ ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
+ ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
+
+ def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
+ """
+ Generate tasks for rolling
+
+ Parameters
+ ----------
+ step : int
+            rolling step
+ rtype : str
+ rolling type (expanding, sliding)
+ ds_extra_mod_func: Callable
+ A method like: handler_mod(task: dict, rg: RollingGen)
+ Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
+ """
+ self.step = step
+ self.rtype = rtype
+ self.ds_extra_mod_func = ds_extra_mod_func
+ self.ta = TimeAdjuster(future=True)
+
+ self.test_key = "test"
+ self.train_key = "train"
+
+ def generate(self, task: dict) -> List[dict]:
+ """
+        Convert the task into a rolling task.
+
+ Parameters
+ ----------
+ task: dict
+            A dict describing a task. For example:
+
+ .. code-block:: python
+
+ DEFAULT_TASK = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi100",
+ },
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
+ "record": [
+ {
+ "class": "SignalRecord",
+ "module_path": "qlib.workflow.record_temp",
+ },
+ ]
+ }
+
+ Returns
+        -------
+ List[dict]: a list of tasks
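+
+        A minimal sketch (``DEFAULT_TASK`` refers to the template above):
+
+        .. code-block:: python
+
+            rolling_tasks = RollingGen(step=40, rtype=RollingGen.ROLL_EX).generate(DEFAULT_TASK)
+            # the generated tasks only differ in the dataset segments (and the handler end_time)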
+ """
+ res = []
+
+ prev_seg = None
+ test_end = None
+ while True:
+ t = copy.deepcopy(task)
+
+ # calculate segments
+ if prev_seg is None:
+ # First rolling
+ # 1) prepare the end point
+ segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
+ test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
+ # 2) and init test segments
+ test_start_idx = self.ta.align_idx(segments[self.test_key][0])
+ segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
+ else:
+ segments = {}
+ try:
+ for k, seg in prev_seg.items():
+ # decide how to shift
+ # expanding only for train data, the segments size of test data and valid data won't change
+ if k == self.train_key and self.rtype == self.ROLL_EX:
+ rtype = self.ta.SHIFT_EX
+ else:
+ rtype = self.ta.SHIFT_SD
+ # shift the segments data
+ segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
+ if segments[self.test_key][0] > test_end:
+ break
+ except KeyError:
+ # We reach the end of tasks
+ # No more rolling
+ break
+
+ # update segments of this task
+ t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
+ prev_seg = segments
+ if self.ds_extra_mod_func is not None:
+ self.ds_extra_mod_func(t, self)
+ res.append(t)
+ return res
diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py
new file mode 100644
index 0000000000..658eec4d6e
--- /dev/null
+++ b/qlib/workflow/task/manage.py
@@ -0,0 +1,493 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
+These features can run tasks concurrently and ensure every task will be used only once.
+Task Manager will store all tasks in `MongoDB `_.
+Users **MUST** finish the configuration of `MongoDB `_ when using this module.
+
+A task in TaskManager consists of 3 parts:
+
+- task description: the desc will define the task
+- task status: the status of the task
+- task result: a user can get the task with the task description and task result.
+"""
+import concurrent
+import pickle
+import time
+from contextlib import contextmanager
+from typing import Callable, List
+
+import fire
+import pymongo
+from bson.binary import Binary
+from bson.objectid import ObjectId
+from pymongo.errors import InvalidDocument
+from qlib import auto_init, get_module_logger
+from tqdm.cli import tqdm
+
+from .utils import get_mongodb
+
+
+class TaskManager:
+ """
+ TaskManager
+
+    Here is what a task looks like when it is created by TaskManager
+
+ .. code-block:: python
+
+ {
+ 'def': pickle serialized task definition. using pickle will make it easier
+ 'filter': json-like data. This is for filtering the tasks.
+            'status': 'waiting' | 'running' | 'part_done' | 'done'
+ 'res': pickle serialized task result,
+ }
+
+    The task manager assumes that you will only update the tasks you fetched.
+    The MongoDB fetch-one-and-update operation makes the data updating secure.
+
+ .. note::
+
+ Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
+
+    Here are the four statuses:
+
+ STATUS_WAITING: waiting for training
+
+ STATUS_RUNNING: training
+
+ STATUS_PART_DONE: finished some step and waiting for next step
+
+ STATUS_DONE: all work done
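+
+    A minimal usage sketch (the pool name "my_task_pool" and ``task_list`` are assumptions):
+
+    .. code-block:: python
+
+        tm = TaskManager(task_pool="my_task_pool")
+        tm.create_task(task_list)  # insert new task definitions and skip the known ones
+        print(tm.task_stat())      # e.g. {"waiting": 10}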
+ """
+
+ STATUS_WAITING = "waiting"
+ STATUS_RUNNING = "running"
+ STATUS_DONE = "done"
+ STATUS_PART_DONE = "part_done"
+
+ ENCODE_FIELDS_PREFIX = ["def", "res"]
+
+ def __init__(self, task_pool: str = None):
+ """
+        Init Task Manager. Remember to declare the MongoDB URL and database name first.
+
+ Parameters
+ ----------
+ task_pool: str
+ the name of Collection in MongoDB
+ """
+ self.mdb = get_mongodb()
+ if task_pool is not None:
+ self.task_pool = getattr(self.mdb, task_pool)
+ self.logger = get_module_logger(self.__class__.__name__)
+
+ def list(self) -> list:
+ """
+        List all the collections (task_pool) of the db
+
+ Returns:
+ list
+ """
+ return self.mdb.list_collection_names()
+
+ def _encode_task(self, task):
+ for prefix in self.ENCODE_FIELDS_PREFIX:
+ for k in list(task.keys()):
+ if k.startswith(prefix):
+ task[k] = Binary(pickle.dumps(task[k]))
+ return task
+
+ def _decode_task(self, task):
+ for prefix in self.ENCODE_FIELDS_PREFIX:
+ for k in list(task.keys()):
+ if k.startswith(prefix):
+ task[k] = pickle.loads(task[k])
+ return task
+
+ def _dict_to_str(self, flt):
+ return {k: str(v) for k, v in flt.items()}
+
+ def replace_task(self, task, new_task):
+ """
+        Use a new task to replace an old one
+
+ Args:
+ task: old task
+ new_task: new task
+ """
+ new_task = self._encode_task(new_task)
+ query = {"_id": ObjectId(task["_id"])}
+ try:
+ self.task_pool.replace_one(query, new_task)
+ except InvalidDocument:
+ task["filter"] = self._dict_to_str(task["filter"])
+ self.task_pool.replace_one(query, new_task)
+
+ def insert_task(self, task):
+ """
+ Insert a task.
+
+ Args:
+            task: the task waiting to be inserted
+
+ Returns:
+ pymongo.results.InsertOneResult
+ """
+ try:
+ insert_result = self.task_pool.insert_one(task)
+ except InvalidDocument:
+ task["filter"] = self._dict_to_str(task["filter"])
+ insert_result = self.task_pool.insert_one(task)
+ return insert_result
+
+ def insert_task_def(self, task_def):
+ """
+ Insert a task to task_pool
+
+ Parameters
+ ----------
+ task_def: dict
+ the task definition
+
+ Returns
+ -------
+ pymongo.results.InsertOneResult
+ """
+ task = self._encode_task(
+ {
+ "def": task_def,
+ "filter": task_def, # FIXME: catch the raised error
+ "status": self.STATUS_WAITING,
+ }
+ )
+ insert_result = self.insert_task(task)
+ return insert_result
+
+ def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
+ """
+ If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.
+ If a task is not new, then just query its _id.
+
+ Parameters
+ ----------
+ task_def_l: list
+ a list of task
+ dry_run: bool
+            if True, the new tasks will not be inserted into the task pool (and an empty list will be returned)
+ print_nt: bool
+            whether to print the new tasks
+
+ Returns
+ -------
+ List[str]
+ a list of the _id of task_def_l
+ """
+ new_tasks = []
+ _id_list = []
+ for t in task_def_l:
+ try:
+ r = self.task_pool.find_one({"filter": t})
+ except InvalidDocument:
+ r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
+ if r is None:
+ new_tasks.append(t)
+ if not dry_run:
+ insert_result = self.insert_task_def(t)
+ _id_list.append(insert_result.inserted_id)
+ else:
+ _id_list.append(None)
+ else:
+ _id_list.append(self._decode_task(r)["_id"])
+
+ self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
+
+ if print_nt: # print new task
+ for t in new_tasks:
+ print(t)
+
+ if dry_run:
+ return []
+
+ return _id_list
+
+ def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
+ """
+ Use query to fetch tasks.
+
+ Args:
+ query (dict, optional): query dict. Defaults to {}.
+            status (str, optional): the status of the tasks to fetch. Defaults to STATUS_WAITING.
+
+ Returns:
+ dict: a task(document in collection) after decoding
+ """
+ query = query.copy()
+ if "_id" in query:
+ query["_id"] = ObjectId(query["_id"])
+ query.update({"status": status})
+ task = self.task_pool.find_one_and_update(
+ query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
+ )
+        # null would be at the top after sorting when using ASCENDING, so DESCENDING is used here: the larger the number, the higher the priority
+ if task is None:
+ return None
+ task["status"] = self.STATUS_RUNNING
+ return self._decode_task(task)
+
+ @contextmanager
+ def safe_fetch_task(self, query={}, status=STATUS_WAITING):
+ """
+ Fetch task from task_pool using query with contextmanager
+
+ Parameters
+ ----------
+ query: dict
+ the dict of query
+
+ Returns
+ -------
+ dict: a task(document in collection) after decoding
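+
+        A minimal sketch (``tm`` is a TaskManager instance and ``train`` stands for any user training function):
+
+        .. code-block:: python
+
+            with tm.safe_fetch_task() as task:
+                if task is not None:
+                    res = train(task["def"])
+                    tm.commit_task_res(task, res)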
+ """
+ task = self.fetch_task(query=query, status=status)
+ try:
+ yield task
+ except Exception:
+ if task is not None:
+ self.logger.info("Returning task before raising error")
+ self.return_task(task)
+ self.logger.info("Task returned")
+ raise
+
+ def task_fetcher_iter(self, query={}):
+ while True:
+ with self.safe_fetch_task(query=query) as task:
+ if task is None:
+ break
+ yield task
+
+ def query(self, query={}, decode=True):
+ """
+ Query task in collection.
+ This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
+
+ Parameters
+ ----------
+ query: dict
+ the dict of query
+ decode: bool
+
+ Returns
+ -------
+        generator: a generator of tasks (documents in the collection) after decoding
+ """
+ query = query.copy()
+ if "_id" in query:
+ query["_id"] = ObjectId(query["_id"])
+ for t in self.task_pool.find(query):
+ yield self._decode_task(t)
+
+ def re_query(self, _id):
+ """
+ Use _id to query task.
+
+ Args:
+ _id (str): _id of a document
+
+ Returns:
+ dict: a task(document in collection) after decoding
+ """
+ t = self.task_pool.find_one({"_id": ObjectId(_id)})
+ return self._decode_task(t)
+
+ def commit_task_res(self, task, res, status=STATUS_DONE):
+ """
+ Commit the result to task['res'].
+
+ Args:
+            task (dict): the task to be committed
+ res (object): the result you want to save
+ status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
+ """
+ # A workaround to use the class attribute.
+ if status is None:
+ status = TaskManager.STATUS_DONE
+ self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
+
+ def return_task(self, task, status=STATUS_WAITING):
+ """
+        Return a task to the given status. Always used in error handling.
+
+ Args:
+            task (dict): the task to be returned to the given status
+ status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
+ """
+ if status is None:
+ status = TaskManager.STATUS_WAITING
+ update_dict = {"$set": {"status": status}}
+ self.task_pool.update_one({"_id": task["_id"]}, update_dict)
+
+ def remove(self, query={}):
+ """
+ Remove the task using query
+
+ Parameters
+ ----------
+ query: dict
+ the dict of query
+
+ """
+ query = query.copy()
+ if "_id" in query:
+ query["_id"] = ObjectId(query["_id"])
+ self.task_pool.delete_many(query)
+
+ def task_stat(self, query={}) -> dict:
+ """
+ Count the tasks in every status.
+
+ Args:
+ query (dict, optional): the query dict. Defaults to {}.
+
+ Returns:
+ dict
+ """
+ query = query.copy()
+ if "_id" in query:
+ query["_id"] = ObjectId(query["_id"])
+ tasks = self.query(query=query, decode=False)
+ status_stat = {}
+ for t in tasks:
+ status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
+ return status_stat
+
+ def reset_waiting(self, query={}):
+ """
+        Reset all running tasks to the waiting status. Can be used when some running tasks exit unexpectedly.
+
+ Args:
+ query (dict, optional): the query dict. Defaults to {}.
+ """
+ query = query.copy()
+ # default query
+ if "status" not in query:
+ query["status"] = self.STATUS_RUNNING
+ return self.reset_status(query=query, status=self.STATUS_WAITING)
+
+ def reset_status(self, query, status):
+ query = query.copy()
+ if "_id" in query:
+ query["_id"] = ObjectId(query["_id"])
+ print(self.task_pool.update_many(query, {"$set": {"status": status}}))
+
+ def prioritize(self, task, priority: int):
+ """
+ Set priority for task
+
+ Parameters
+ ----------
+ task : dict
+ The task query from the database
+ priority : int
+ the target priority
+ """
+ update_dict = {"$set": {"priority": priority}}
+ self.task_pool.update_one({"_id": task["_id"]}, update_dict)
+
+ def _get_undone_n(self, task_stat):
+ return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
+
+ def _get_total(self, task_stat):
+ return sum(task_stat.values())
+
+ def wait(self, query={}):
+ task_stat = self.task_stat(query)
+ total = self._get_total(task_stat)
+ last_undone_n = self._get_undone_n(task_stat)
+ with tqdm(total=total, initial=total - last_undone_n) as pbar:
+ while True:
+ time.sleep(10)
+ undone_n = self._get_undone_n(self.task_stat(query))
+ pbar.update(last_undone_n - undone_n)
+ last_undone_n = undone_n
+ if undone_n == 0:
+ break
+
+ def __str__(self):
+ return f"TaskManager({self.task_pool})"
+
+
+def run_task(
+ task_func: Callable,
+ task_pool: str,
+ query: dict = {},
+ force_release: bool = False,
+ before_status: str = TaskManager.STATUS_WAITING,
+ after_status: str = TaskManager.STATUS_DONE,
+ **kwargs,
+):
+ """
+ While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
+
+ After running this method, here are 4 situations (before_status -> after_status):
+
+ STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
+
+ STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
+
+ STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
+
+ STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
+
+ Parameters
+ ----------
+ task_func : Callable
+        a callable with the signature ``(task_def, **kwargs)``,
+        the function to run the task
+ task_pool : str
+ the name of the task pool (Collection in MongoDB)
+ query: dict
+ will use this dict to query task_pool when fetching task
+ force_release : bool
+        whether the program will forcibly release the resource (the task is run in a separate process so that memory is released when it finishes)
+    before_status : str
+        the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
+    after_status : str
+        the tasks after being trained will become after_status. Can be STATUS_DONE, STATUS_PART_DONE.
+ kwargs
+ the params for `task_func`
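+
+    A minimal sketch (``my_train_func`` and the pool name "my_task_pool" are assumptions):
+
+    .. code-block:: python
+
+        def my_train_func(task_def, experiment_name):
+            ...  # train the model described by task_def and return something picklable
+
+        run_task(my_train_func, task_pool="my_task_pool", experiment_name="rolling_exp")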
+ """
+ tm = TaskManager(task_pool)
+
+ ever_run = False
+
+ while True:
+ with tm.safe_fetch_task(status=before_status, query=query) as task:
+ if task is None:
+ break
+ get_module_logger("run_task").info(task["def"])
+ # when fetching `WAITING` task, use task["def"] to train
+ if before_status == TaskManager.STATUS_WAITING:
+ param = task["def"]
+ # when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
+ elif before_status == TaskManager.STATUS_PART_DONE:
+ param = task["res"]
+ else:
+ raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
+ if force_release:
+ with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
+ res = executor.submit(task_func, param, **kwargs).result()
+ else:
+ res = task_func(param, **kwargs)
+ tm.commit_task_res(task, res, status=after_status)
+ ever_run = True
+
+ return ever_run
+
+
+if __name__ == "__main__":
+ # This is for using it in cmd
+ # E.g. : `python -m qlib.workflow.task.manage list`
+ auto_init()
+ fire.Fire(TaskManager)
diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py
new file mode 100644
index 0000000000..174b4b9bfc
--- /dev/null
+++ b/qlib/workflow/task/utils.py
@@ -0,0 +1,258 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+Some tools for task management.
+"""
+
+import bisect
+import pandas as pd
+from qlib.data import D
+from qlib.workflow import R
+from qlib.config import C
+from qlib.log import get_module_logger
+from pymongo import MongoClient
+from pymongo.database import Database
+from typing import Union
+
+
+def get_mongodb() -> Database:
+
+ """
+    Get the database in MongoDB, which means you need to declare the address and the name of the database first.
+
+ For example:
+
+ Using qlib.init():
+
+ mongo_conf = {
+ "task_url": task_url, # your MongoDB url
+ "task_db_name": task_db_name, # database name
+ }
+ qlib.init(..., mongo=mongo_conf)
+
+ After qlib.init():
+
+ C["mongo"] = {
+ "task_url" : "mongodb://localhost:27017/",
+ "task_db_name" : "rolling_db"
+ }
+
+ Returns:
+ Database: the Database instance
+ """
+ try:
+ cfg = C["mongo"]
+ except KeyError:
+ get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
+ raise
+
+ client = MongoClient(cfg["task_url"])
+ return client.get_database(name=cfg["task_db_name"])
+
+
+def list_recorders(experiment, rec_filter_func=None):
+ """
+ List all recorders which can pass the filter in an experiment.
+
+ Args:
+ experiment (str or Experiment): the name of an Experiment or an instance
+ rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
+
+ Returns:
+ dict: a dict {rid: recorder} after filtering.
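+
+    A minimal sketch (the experiment name and the tag key/value are assumptions):
+
+    .. code-block:: python
+
+        online_recs = list_recorders(
+            "my_exp", lambda rec: rec.list_tags().get("online_status") == "online"
+        )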
+ """
+ if isinstance(experiment, str):
+ experiment = R.get_exp(experiment_name=experiment)
+ recs = experiment.list_recorders()
+ recs_flt = {}
+ for rid, rec in recs.items():
+ if rec_filter_func is None or rec_filter_func(rec):
+ recs_flt[rid] = rec
+
+ return recs_flt
+
+
+class TimeAdjuster:
+ """
+    Find the appropriate date and adjust dates to the trading calendar.
+ """
+
+ def __init__(self, future=True, end_time=None):
+ self._future = future
+ self.cals = D.calendar(future=future, end_time=end_time)
+
+ def set_end_time(self, end_time=None):
+ """
+        Set the end time. None for using the calendar's end time.
+
+ Args:
+ end_time
+ """
+ self.cals = D.calendar(future=self._future, end_time=end_time)
+
+ def get(self, idx: int):
+ """
+ Get datetime by index.
+
+ Parameters
+ ----------
+ idx : int
+ index of the calendar
+ """
+ if idx >= len(self.cals):
+ return None
+ return self.cals[idx]
+
+ def max(self) -> pd.Timestamp:
+ """
+ Return the max calendar datetime
+ """
+ return max(self.cals)
+
+ def align_idx(self, time_point, tp_type="start") -> int:
+ """
+ Align the index of time_point in the calendar.
+
+ Parameters
+ ----------
+ time_point
+        tp_type : str
+            time point type (`"start"` or `"end"`)
+
+ Returns
+ -------
+ index : int
+ """
+ time_point = pd.Timestamp(time_point)
+ if tp_type == "start":
+ idx = bisect.bisect_left(self.cals, time_point)
+ elif tp_type == "end":
+ idx = bisect.bisect_right(self.cals, time_point) - 1
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ return idx
+
+ def cal_interval(self, time_point_A, time_point_B) -> int:
+ """
+ Calculate the trading day interval (time_point_A - time_point_B)
+
+ Args:
+            time_point_A : the later time point
+            time_point_B : the earlier time point (in the past of time_point_A)
+
+ Returns:
+ int: the interval between A and B
+ """
+ return self.align_idx(time_point_A) - self.align_idx(time_point_B)
+
+ def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
+ """
+ Align time_point to trade date of calendar
+
+ Args:
+ time_point
+ Time point
+ tp_type : str
+ time point type (`"start"`, `"end"`)
+
+ Returns:
+ pd.Timestamp
+ """
+ return self.cals[self.align_idx(time_point, tp_type=tp_type)]
+
+ def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
+ """
+ Align the given date to the trade date
+
+ for example:
+
+ .. code-block:: python
+
+ input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
+
+ output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
+ 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
+ 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
+
+ Parameters
+ ----------
+ segment
+
+ Returns
+ -------
+ Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
+ """
+ if isinstance(segment, dict):
+ return {k: self.align_seg(seg) for k, seg in segment.items()}
+ elif isinstance(segment, tuple) or isinstance(segment, list):
+ return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ def truncate(self, segment: tuple, test_start, days: int) -> tuple:
+ """
+ Truncate the segment based on the test_start date
+
+ Parameters
+ ----------
+ segment : tuple
+ time segment
+ test_start
+ days : int
+            the number of trading days to be truncated,
+            since the data in this segment may need 'days' of extra data
+
+ Returns
+ ---------
+ tuple: new segment
+ """
+ test_idx = self.align_idx(test_start)
+ if isinstance(segment, tuple):
+ new_seg = []
+ for time_point in segment:
+ tp_idx = min(self.align_idx(time_point), test_idx - days)
+ assert tp_idx > 0
+ new_seg.append(self.get(tp_idx))
+ return tuple(new_seg)
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ SHIFT_SD = "sliding"
+ SHIFT_EX = "expanding"
+
+ def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
+ """
+        Shift the datetime of the segment
+
+ Parameters
+ ----------
+ seg :
+ datetime segment
+ step : int
+ rolling step
+ rtype : str
+ rolling type ("sliding" or "expanding")
+
+ Returns
+ --------
+ tuple: new segment
+
+ Raises
+ ------
+ KeyError:
+            shift will raise an error if the index (both start and end) is out of self.cals
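+
+        A minimal sketch (the dates are illustrative; actual results depend on the trading calendar):
+
+        .. code-block:: python
+
+            ta = TimeAdjuster(future=True)
+            new_seg = ta.shift(("2019-01-01", "2019-12-31"), step=20, rtype=TimeAdjuster.SHIFT_SD)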
+ """
+ if isinstance(seg, tuple):
+ start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
+ if rtype == self.SHIFT_SD:
+ start_idx += step
+ end_idx += step
+ elif rtype == self.SHIFT_EX:
+ end_idx += step
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ if start_idx > len(self.cals):
+ raise KeyError("The segment is out of valid calendar")
+ return self.get(start_idx), self.get(end_idx)
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
diff --git a/setup.py b/setup.py
index 747d885f4f..92c9ccc0cc 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,9 @@
"tornado",
"joblib>=0.17.0",
"ruamel.yaml>=0.16.12",
+ "pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
+ "dill",
]
# Numpy include