diff --git a/docs/_static/img/online_serving.png b/docs/_static/img/online_serving.png new file mode 100644 index 0000000000..8647ebe773 Binary files /dev/null and b/docs/_static/img/online_serving.png differ diff --git a/docs/advanced/serial.rst b/docs/advanced/serial.rst index 8c0f837467..e0840069bf 100644 --- a/docs/advanced/serial.rst +++ b/docs/advanced/serial.rst @@ -14,6 +14,9 @@ Serializable Class ``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format. When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk. +However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature. + +Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here `_). Example ========================== diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst new file mode 100644 index 0000000000..56a3137f9f --- /dev/null +++ b/docs/advanced/task_management.rst @@ -0,0 +1,89 @@ +.. _task_management: + +================================= +Task Management +================================= +.. currentmodule:: qlib + + +Introduction +============= + +The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``. +To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_. +With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models. + +This whole process can be used in `Online Serving <../component/online.html>`_. + +An example of the entire process is shown `here `_. + +Task Generating +=============== +A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users. +The specific task template can be viewed in +`Task Section <../component/workflow.html#task-section>`_. +Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template. + +Here is the base class of ``TaskGen``: + +.. autoclass:: qlib.workflow.task.gen.TaskGen + :members: + +``Qlib`` provides a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments. +This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_. + +Task Storing +=============== +To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_. +``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling. +Users **MUST** finish the configuration of `MongoDB `_ when using this module. + +Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this. + + .. code-block:: python + + from qlib.config import C + C["mongo"] = { + "task_url" : "mongodb://localhost:27017/", # your MongoDB url + "task_db_name" : "rolling_db" # database name + } + +.. 
autoclass:: qlib.workflow.task.manage.TaskManager
+    :members:
+
+More information about ``Task Manager`` can be found `here <../reference/api.html#TaskManager>`_.
+
+Task Training
+===============
+After generating and storing those ``task``, it's time to run the ones that are in the *WAITING* status.
+``Qlib`` provides a method called ``run_task`` to run those ``task`` in the task pool; however, users can also customize how tasks are executed.
+An easy way to get the ``task_func`` is to use ``qlib.model.trainer.task_train`` directly.
+It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
+
+.. autofunction:: qlib.workflow.task.manage.run_task
+
+Meanwhile, ``Qlib`` provides a module called ``Trainer``.
+
+.. autoclass:: qlib.model.trainer.Trainer
+    :members:
+
+``Trainer`` will train a list of tasks and return a list of model recorders.
+``Qlib`` offers two kinds of ``Trainer``: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on ``TaskManager`` to help manage the lifecycle of tasks automatically.
+If you do not want to use ``Task Manager`` to manage tasks, then using ``TrainerR`` to train a list of tasks generated by ``TaskGen`` is enough.
+`Here <../reference/api.html#Trainer>`_ are the details about the different ``Trainer`` classes.
+
+Task Collecting
+===============
+To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
+
+`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them, such as merging, grouping, averaging and so on. It has a 2-step action, including ``collect`` (collect anything into a dict) and ``process_collect`` (process the collected dict).
+
+`Group <../reference/api.html#Group>`_ also has 2 steps, including ``group`` (which can group a set of objects based on `group_func` and change them into a dict) and ``reduce`` (which can make a dict become an ensemble based on some rule).
+For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
+
+`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
+For example: {C1: object, C2: object} ---``Ensemble``---> object
+
+So the hierarchy is: ``Collector``'s second step corresponds to ``Group``, and ``Group``'s second step corresponds to ``Ensemble``.
+
+For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_.
\ No newline at end of file
diff --git a/docs/component/online.rst b/docs/component/online.rst
new file mode 100644
index 0000000000..accc936dd4
--- /dev/null
+++ b/docs/component/online.rst
@@ -0,0 +1,46 @@
+.. _online:
+
+=================================
+Online Serving
+=================================
+.. currentmodule:: qlib
+
+
+Introduction
+=============
+
+.. image:: ../_static/img/online_serving.png
+    :align: center
+
+
+In addition to backtesting, one way to test whether a model is effective is to make predictions with it under real market conditions, or even to do real trading based on those predictions.
+``Online Serving`` is a set of modules for serving online models with the latest data,
+which includes `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_ and `Updater <#Updater>`_.
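+
+For instance, after ``qlib`` has been initialized, a typical serving process can be driven as below (a minimal sketch distilled from the examples referenced below; the experiment name, task config and rolling step are illustrative):
+
+.. code-block:: python
+
+    from qlib.workflow.online.manager import OnlineManager
+    from qlib.workflow.online.strategy import RollingStrategy
+    from qlib.workflow.task.gen import RollingGen
+
+    # `task_config` is a task dict (model/dataset/record) like the ones used in the examples
+    strategy = RollingStrategy("my_exp", task_config, RollingGen(step=550, rtype=RollingGen.ROLL_SD))
+    online_manager = OnlineManager(strategy)
+
+    online_manager.first_train()             # train the first batch of models and set them online
+    online_manager.routine()                 # later on: update online predictions, prepare new tasks, models and signals
+    print(online_manager.get_collector()())  # collect the results of all strategies
+    print(online_manager.get_signals())      # the latest prepared signals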
+ +`Here `_ are several examples for reference, which demonstrate different features of ``Online Serving``. +If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_. +The `examples `_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``. + +Online Manager +============= + +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +============= + +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +============= + +.. automodule:: qlib.workflow.online.utils + :members: + +Updater +============= + +.. automodule:: qlib.workflow.online.update + :members: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 3fa35fc60d..803aa97d2d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -42,6 +42,7 @@ Document Structure Intraday Trading: Model&Strategy Testing Qlib Recorder: Experiment Management Analysis: Evaluation & Results Analysis + Online Serving: Online Management & Strategy & Tool .. toctree:: :maxdepth: 3 @@ -50,6 +51,7 @@ Document Structure Building Formulaic Alphas Online & Offline mode Serialization + Task Management .. toctree:: :maxdepth: 3 diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 3167d8a622..57f61f18b1 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -154,6 +154,70 @@ Record Template .. automodule:: qlib.workflow.record_temp :members: +Task Management +==================== + + +TaskGen +-------------------- +.. automodule:: qlib.workflow.task.gen + :members: + +TaskManager +-------------------- +.. automodule:: qlib.workflow.task.manage + :members: + +Trainer +-------------------- +.. automodule:: qlib.model.trainer + :members: + +Collector +-------------------- +.. automodule:: qlib.workflow.task.collect + :members: + +Group +-------------------- +.. automodule:: qlib.model.ens.group + :members: + +Ensemble +-------------------- +.. automodule:: qlib.model.ens.ensemble + :members: + +Utils +-------------------- +.. automodule:: qlib.workflow.task.utils + :members: + + +Online Serving +==================== + + +Online Manager +-------------------- +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +-------------------- +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +-------------------- +.. automodule:: qlib.workflow.online.utils + :members: + +RecordUpdater +-------------------- +.. automodule:: qlib.workflow.online.update + :members: + Utils ==================== @@ -162,4 +226,7 @@ Serializable -------------------- .. automodule:: qlib.utils.serial.Serializable - :members: \ No newline at end of file + :members: + + + \ No newline at end of file diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst index 15aa957d1a..32c17ff837 100644 --- a/docs/start/initialization.rst +++ b/docs/start/initialization.rst @@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo "default_exp_name": "Experiment", } }) +- `mongo` + Type: dict, optional parameter, the setting of `MongoDB `_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing. + Users need finished `installation `_ firstly, and run it in a fixed URL. + + .. 
code-block:: Python + + # For example, you can initialize qlib below + qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={ + "task_url": "mongodb://localhost:27017/", # your mongo url + "task_db_name": "rolling_db", # the database name of Task Management + }) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py new file mode 100644 index 0000000000..4f3ac04b15 --- /dev/null +++ b/examples/model_rolling/task_manager_rolling.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example shows how a TrainerRM works based on TaskManager with rolling tasks. +After training, how to collect the rolling results will be shown in task_collecting. +""" + +from pprint import pprint + +import fire +import qlib +from qlib.config import REG_CN +from qlib.workflow import R +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager +from qlib.workflow.task.collect import RecorderCollector +from qlib.model.ens.group import RollingGroup +from qlib.model.trainer import TrainerRM + + +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi100", +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, +} + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + +# use lgb +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} + + +class RollingTaskExample: + def __init__( + self, + provider_uri="~/.qlib/qlib_data/cn_data", + region=REG_CN, + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + experiment_name="rolling_exp", + task_pool="rolling_task", + task_config=[task_xgboost_config, task_lgb_config], + rolling_step=550, + rolling_type=RollingGen.ROLL_SD, + ): + # TaskManager config + mongo_conf = { + "task_url": task_url, + "task_db_name": task_db_name, + } + qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) + self.experiment_name = experiment_name + self.task_pool = task_pool + self.task_config = task_config + self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type) + + # Reset all things to the first status, be careful to save important data + def reset(self): + print("========== reset ==========") + TaskManager(task_pool=self.task_pool).remove() + exp = R.get_exp(experiment_name=self.experiment_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + + def task_generating(self): + print("========== task_generating ==========") + tasks = task_generator( + tasks=self.task_config, + generators=self.rolling_gen, # generate different date segments + ) + pprint(tasks) + return tasks + + def task_training(self, tasks): + print("========== 
task_training ==========") + trainer = TrainerRM(self.experiment_name, self.task_pool) + trainer.train(tasks) + + def task_collecting(self): + print("========== task_collecting ==========") + + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key + + def my_filter(recorder): + # only choose the results of "LGBModel" + model_key, rolling_key = rec_key(recorder) + if model_key == "LGBModel": + return True + return False + + collector = RecorderCollector( + experiment=self.experiment_name, + process_list=RollingGroup(), + rec_key_func=rec_key, + rec_filter_func=my_filter, + ) + print(collector()) + + def main(self): + self.reset() + tasks = self.task_generating() + self.task_training(tasks) + self.task_collecting() + + +if __name__ == "__main__": + ## to see the whole process with your own parameters, use the command below + # python task_manager_rolling.py main --experiment_name="your_exp_name" + fire.Fire(RollingTaskExample) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py new file mode 100644 index 0000000000..4bb5022ee0 --- /dev/null +++ b/examples/online_srv/online_management_simulate.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example is about how can simulate the OnlineManager based on rolling tasks. +""" + +import fire +import qlib +from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM +from qlib.workflow import R +from qlib.workflow.online.manager import OnlineManager +from qlib.workflow.online.strategy import RollingStrategy +from qlib.workflow.task.gen import RollingGen +from qlib.workflow.task.manage import TaskManager + + +data_handler_config = { + "start_time": "2018-01-01", + "end_time": "2018-10-31", + "fit_start_time": "2018-01-01", + "fit_end_time": "2018-03-31", + "instruments": "csi100", +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2018-01-01", "2018-03-31"), + "valid": ("2018-04-01", "2018-05-31"), + "test": ("2018-06-01", "2018-09-10"), + }, + }, +} + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + +# use lgb model +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost model +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} + + +class OnlineSimulationExample: + def __init__( + self, + provider_uri="~/.qlib/qlib_data/cn_data", + region="cn", + exp_name="rolling_exp", + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + task_pool="rolling_task", + rolling_step=80, + start_time="2018-09-10", + end_time="2018-10-31", + tasks=[task_xgboost_config, task_lgb_config], + ): + """ + Init OnlineManagerExample. + + Args: + provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data". + region (str, optional): the stock region. 
Defaults to "cn". + exp_name (str, optional): the experiment name. Defaults to "rolling_exp". + task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/". + task_db_name (str, optional): database name. Defaults to "rolling_db". + task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task". + rolling_step (int, optional): the step for rolling. Defaults to 80. + start_time (str, optional): the start time of simulating. Defaults to "2018-09-10". + end_time (str, optional): the end time of simulating. Defaults to "2018-10-31". + tasks (dict or list[dict]): a set of the task config waiting for rolling and training + """ + self.exp_name = exp_name + self.task_pool = task_pool + self.start_time = start_time + self.end_time = end_time + mongo_conf = { + "task_url": task_url, + "task_db_name": task_db_name, + } + qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) + self.rolling_gen = RollingGen( + step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None + ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time. + self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR + self.rolling_online_manager = OnlineManager( + RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), + trainer=self.trainer, + begin_time=self.start_time, + ) + self.tasks = tasks + + # Reset all things to the first status, be careful to save important data + def reset(self): + TaskManager(self.task_pool).remove() + exp = R.get_exp(experiment_name=self.exp_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + + # Run this to run all workflow automatically + def main(self): + print("========== reset ==========") + self.reset() + print("========== simulate ==========") + self.rolling_online_manager.simulate(end_time=self.end_time) + print("========== collect results ==========") + print(self.rolling_online_manager.get_collector()()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) + + +if __name__ == "__main__": + ## to run all workflow automatically with your own parameters, use the command below + # python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60 + fire.Fire(OnlineSimulationExample) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py new file mode 100644 index 0000000000..25b8b2a0c0 --- /dev/null +++ b/examples/online_srv/rolling_online_management.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example shows how OnlineManager works with rolling tasks. +There are four parts including first train, routine 1, add strategy and routine 2. +Firstly, the OnlineManager will finish the first training and set trained models to `online` models. +Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals +Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies. +Finally, the OnlineManager will finish second routine and update all strategies. 
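+
+Note that between these steps the OnlineManager is dumped to and reloaded from a pickle file (see `_ROLLING_MANAGER_PATH` below),
+so each step can be run as a separate process, e.g. on different days.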
+""" + +import os +import fire +import qlib +from qlib.workflow import R +from qlib.workflow.online.strategy import RollingStrategy +from qlib.workflow.task.gen import RollingGen +from qlib.workflow.online.manager import OnlineManager + +data_handler_config = { + "start_time": "2013-01-01", + "end_time": "2020-09-25", + "fit_start_time": "2013-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi100", +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2013-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2015-12-31"), + "test": ("2016-01-01", "2020-07-10"), + }, + }, +} + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + +# use lgb model +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost model +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} + + +class RollingOnlineExample: + def __init__( + self, + provider_uri="~/.qlib/qlib_data/cn_data", + region="cn", + task_url="mongodb://10.0.0.4:27017/", + task_db_name="rolling_db", + rolling_step=550, + tasks=[task_xgboost_config], + add_tasks=[task_lgb_config], + ): + mongo_conf = { + "task_url": task_url, # your MongoDB url + "task_db_name": task_db_name, # database name + } + qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) + self.tasks = tasks + self.add_tasks = add_tasks + self.rolling_step = rolling_step + strategies = [] + for task in tasks: + name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + strategies.append( + RollingStrategy( + name_id, + task, + RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), + ) + ) + + self.rolling_online_manager = OnlineManager(strategies) + + _ROLLING_MANAGER_PATH = ( + ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. 
+ ) + + # Reset all things to the first status, be careful to save important data + def reset(self): + for task in self.tasks + self.add_tasks: + name_id = task["model"]["class"] + exp = R.get_exp(experiment_name=name_id) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + + if os.path.exists(self._ROLLING_MANAGER_PATH): + os.remove(self._ROLLING_MANAGER_PATH) + + def first_run(self): + print("========== reset ==========") + self.reset() + print("========== first_run ==========") + self.rolling_online_manager.first_train() + print("========== collect results ==========") + print(self.rolling_online_manager.get_collector()()) + print("========== dump ==========") + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) + + def routine(self): + print("========== load ==========") + self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH) + print("========== routine ==========") + self.rolling_online_manager.routine() + print("========== collect results ==========") + print(self.rolling_online_manager.get_collector()()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) + print("========== dump ==========") + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) + + def add_strategy(self): + print("========== load ==========") + self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH) + print("========== add strategy ==========") + strategies = [] + for task in self.add_tasks: + name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + strategies.append( + RollingStrategy( + name_id, + task, + RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD), + ) + ) + self.rolling_online_manager.add_strategy(strategies=strategies) + print("========== dump ==========") + self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) + + def main(self): + self.first_run() + self.routine() + self.add_strategy() + self.routine() + + +if __name__ == "__main__": + ####### to train the first version's models, use the command below + # python rolling_online_management.py first_run + + ####### to update the models and predictions after the trading time, use the command below + # python rolling_online_management.py routine + + ####### to define your own parameters, use `--` + # python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40 + fire.Fire(RollingOnlineExample) diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py new file mode 100644 index 0000000000..228bc0dacb --- /dev/null +++ b/examples/online_srv/update_online_pred.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example shows how OnlineTool works when we need update prediction. +There are two parts including first_train and update_online_pred. +Firstly, we will finish the training and set the trained models to the `online` models. +Next, we will finish updating online predictions. 
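+
+`first_train` only needs to be run once, while `update_online_pred` is meant to be run periodically
+(e.g. once a day after the data is updated), as the commands at the bottom of this script show.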
+""" +import fire +import qlib +from qlib.config import REG_CN +from qlib.model.trainer import task_train +from qlib.workflow.online.utils import OnlineToolR + +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi100", +} + +task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, + "record": { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, +} + + +class UpdatePredExample: + def __init__( + self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task + ): + qlib.init(provider_uri=provider_uri, region=region) + self.experiment_name = experiment_name + self.online_tool = OnlineToolR(self.experiment_name) + self.task_config = task_config + + def first_train(self): + rec = task_train(self.task_config, experiment_name=self.experiment_name) + self.online_tool.reset_online_tag(rec) # set to online model + + def update_online_pred(self): + self.online_tool.update_online_pred() + + def main(self): + self.first_train() + self.update_online_pred() + + +if __name__ == "__main__": + ## to train a model and set it to online model, use the command below + # python update_online_pred.py first_train + ## to update online predictions once a day, use the command below + # python update_online_pred.py update_online_pred + ## to see the whole process with your own parameters, use the command below + # python update_online_pred.py main --experiment_name="your_exp_name" + fire.Fire(UpdatePredExample) diff --git a/qlib/__init__.py b/qlib/__init__.py index 99035e5014..4fd48f8c19 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -3,6 +3,7 @@ __version__ = "0.6.3.99" +__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version import os @@ -10,12 +11,13 @@ import logging import platform import subprocess +from pathlib import Path +from .log import get_module_logger # init qlib def init(default_conf="client", **kwargs): from .config import C - from .log import get_module_logger from .data.cache import H H.clear() @@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs): def _mount_nfs_uri(C): - from .log import get_module_logger LOG = get_module_logger("mount nfs", level=logging.INFO) @@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs): config.update(kwargs) default_conf = config.pop("default_conf", "client") init(default_conf, **config) + + +def get_project_path(config_name="config.yaml", cur_path=None) -> Path: + """ + If users are building a project follow the following pattern. + - Qlib is a sub folder in project path + - There is a file named `config.yaml` in qlib. + + For example: + If your project file system stucuture follows such a pattern + + / + - config.yaml + - ...some folders... 
+ - qlib/ + + This folder will return + + NOTE: link is not supported here. + + + This method is often used when + - user want to use a relative config path instead of hard-coding qlib config path in code + + Raises + ------ + FileNotFoundError: + If project path is not found + """ + if cur_path is None: + cur_path = Path(__file__).absolute().resolve() + while True: + if (cur_path / config_name).exists(): + return cur_path + if cur_path == cur_path.parent: + raise FileNotFoundError("We can't find the project path") + cur_path = cur_path.parent + + +def auto_init(**kwargs): + """ + This function will init qlib automatically with following priority + - Find the project configuration and init qlib + - The parsing process will be affected by the `conf_type` of the configuration file + - Init qlib with default config + """ + + try: + pp = get_project_path(cur_path=kwargs.pop("cur_path", None)) + except FileNotFoundError: + init(**kwargs) + else: + + conf_pp = pp / "config.yaml" + with conf_pp.open() as f: + conf = yaml.safe_load(f) + + conf_type = conf.get("conf_type", "origin") + if conf_type == "origin": + # The type of config is just like original qlib config + init_from_yaml_conf(conf_pp, **kwargs) + elif conf_type == "ref": + # This config type will be more convenient in following scenario + # - There is a shared configure file and you don't want to edit it inplace. + # - The shared configure may be updated later and you don't want to copy it. + # - You have some customized config. + qlib_conf_path = conf["qlib_cfg"] + qlib_conf_update = conf.get("qlib_cfg_update") + init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs) + logger = get_module_logger("Initialization") + logger.info(f"Auto load project config: {conf_pp}") diff --git a/qlib/config.py b/qlib/config.py index 75ab0fa3e8..4dedf75d06 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -33,6 +33,9 @@ def __getattr__(self, attr): raise AttributeError(f"No such {attr} in self._config") + def get(self, key, default=None): + return self.__dict__["_config"].get(key, default) + def __setitem__(self, key, value): self.__dict__["_config"][key] = value @@ -131,7 +134,7 @@ def set_conf_from_C(self, config_c): }, "loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}}, }, - # Defatult config for experiment manager + # Default config for experiment manager "exp_manager": { "class": "MLflowExpManager", "module_path": "qlib.workflow.expm", @@ -140,6 +143,11 @@ def set_conf_from_C(self, config_c): "default_exp_name": "Experiment", }, }, + # Default config for MongoDB + "mongo": { + "task_url": "mongodb://localhost:27017/", + "task_db_name": "default_task_db", + }, } MODE_CONF = { @@ -310,8 +318,22 @@ def register(self): # clean up experiment when python program ends experiment_exit_handler() + # Supporting user reset qlib version (useful when user want to connect to qlib server with old version) + self.reset_qlib_version() + self._registered = True + def reset_qlib_version(self): + import qlib + + reset_version = self.get("qlib_reset_version", None) + if reset_version is not None: + qlib.__version__ = reset_version + else: + qlib.__version__ = getattr(qlib, "__version__bak") + # Due to a bug? 
that converting __version__ to _QlibConfig__version__bak + # Using __version__bak instead of __version__ + @property def registered(self): return self._registered diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 970b032d6b..be2016ea32 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time): "fit_end_time": fit_end_time, } ) + # FIXME: the `module_path` parameter is missed. new_l.append({"class": klass.__name__, "kwargs": pkwargs}) else: new_l.append(p) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index b3eaac7a33..206561aed9 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -27,7 +27,7 @@ def __init__(self, **kwargs): - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - The data could specify the info to caculate the essential data for preparation + The data could specify the info to calculate the essential data for preparation """ self.setup_data(**kwargs) super().__init__() @@ -92,7 +92,7 @@ def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple handler : Union[dict, DataHandler] handler could be: - - insntance of `DataHandler` + - instance of `DataHandler` - config of `DataHandler`. Please refer to `DataHandler` @@ -112,8 +112,9 @@ def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple 'outsample': ("2017-01-01", "2020-08-01",), } """ - self.handler = init_instance_by_config(handler, accept_types=DataHandler) + self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + self.fetch_kwargs = {} super().__init__(**kwargs) def config(self, handler_kwargs: dict = None, **kwargs): @@ -123,7 +124,7 @@ def config(self, handler_kwargs: dict = None, **kwargs): Parameters ---------- handler_kwargs : dict - Config of DataHanlder, which could include the following arguments: + Config of DataHandler, which could include the following arguments: - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. @@ -147,11 +148,11 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs): Parameters ---------- handler_kwargs : dict - init arguments of DataHanlder, which could include the following arguments: + init arguments of DataHandler, which could include the following arguments: - init_type : Init Type of Handler - - enable_cache : wheter to enable cache + - enable_cache : whether to enable cache """ super().setup_data(**kwargs) @@ -171,7 +172,10 @@ def _prepare_seg(self, slc: slice, **kwargs): ---------- slc : slice """ - return self.handler.fetch(slc, **kwargs) + if hasattr(self, "fetch_kwargs"): + return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) + else: + return self.handler.fetch(slc, **kwargs) def prepare( self, @@ -199,6 +203,12 @@ def prepare( The data to fetch: DK_* Default is DK_I, which indicate fetching data for **inference**. + kwargs : + The parameters that kwargs may contain: + flt_col : str + It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data. + This parameter is only supported when it is an instance of TSDatasetH. 
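+                    An illustrative call: ``dataset.prepare("train", flt_col="filter_col")``, where "filter_col"
+                    is a column of True/False values provided by the data handler (the column name here is just an example).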
+ Returns ------- Union[List[pd.DataFrame], pd.DataFrame]: @@ -231,7 +241,7 @@ class TSDataSampler: (T)ime-(S)eries DataSampler This is the result of TSDatasetH - It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series + It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series dataset based on tabular data. If user have further requirements for processing data, user could process them based on `TSDataSampler` or create @@ -243,7 +253,9 @@ class TSDataSampler: """ - def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"): + def __init__( + self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None + ): """ Build a dataset which looks like torch.data.utils.Dataset. @@ -265,6 +277,11 @@ def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: s ffill with previous sample ffill+bfill: ffill with previous samples first and fill with later samples second + flt_data : pd.Series + a column of data(True or False) to filter data. + None: + kepp all data + """ self.start = start self.end = end @@ -272,23 +289,51 @@ def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: s self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) - self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! - # NOTE: append last line with full NaN for better performance in `__getitem__` - self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0) + + kwargs = {"object": self.data} + if dtype is not None: + kwargs["dtype"] = dtype + + self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values! + # NOTE: + # - append last line with full NaN for better performance in `__getitem__` + # - Keep the same dtype will result in a better performance + self.data_arr = np.append( + self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0 + ) self.nan_idx = -1 # The last line is all NaN # the data type will be changed # The index of usable data is between start_idx and end_idx - self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_df, self.idx_map = self.build_index(self.data) + self.data_index = deepcopy(self.data.index) + + if flt_data is not None: + self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1) + self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) + self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + + self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + del self.data # save memory + + @staticmethod + def flt_idx_map(flt_data, idx_map): + idx = 0 + new_idx_map = {} + for i, exist in enumerate(flt_data): + if exist: + new_idx_map[idx] = idx_map[i] + idx += 1 + return new_idx_map + def get_index(self): """ Get the pandas index of the data, it will be useful in following scenarios - Special sampler will be used (e.g. 
user want to sample day by day) """ - return self.data.index[self.start_idx : self.end_idx] + return self.data_index[self.start_idx : self.end_idx] def config(self, **kwargs): # Config the attributes @@ -432,7 +477,7 @@ class TSDatasetH(DatasetH): (T)ime-(S)eries Dataset (H)andler - Covnert the tabular data to Time-Series data + Convert the tabular data to Time-Series data Requirements analysis @@ -461,7 +506,7 @@ def setup_data(self, **kwargs): cal = sorted(cal) self.cal = cal - def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame: # Dataset decide how to slice data(Get more data for timeseries). start, end = slc.start, slc.stop start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start)) @@ -470,6 +515,25 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: # TSDatasetH will retrieve more data for complete data = super()._prepare_seg(slice(pad_start, end), **kwargs) + return data + + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + """ + split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data + """ + dtype = kwargs.pop("dtype", None) + start, end = slc.start, slc.stop + flt_col = kwargs.pop("flt_col", None) + # TSDatasetH will retrieve more data for complete + data = self._prepare_raw_seg(slc, **kwargs) + + flt_kwargs = deepcopy(kwargs) + if flt_col is not None: + flt_kwargs["col_set"] = flt_col + flt_data = self._prepare_raw_seg(slc, **flt_kwargs) + assert len(flt_data.columns) == 1 + else: + flt_data = None - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len) + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) return tsds diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index a6ca658d16..c6338832a5 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -7,7 +7,7 @@ import logging import warnings from inspect import getfullargspec -from typing import Union, Tuple, List, Iterator, Optional +from typing import Callable, Union, Tuple, List, Iterator, Optional import pandas as pd import numpy as np @@ -36,7 +36,7 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - Any order of the index level can be suported (The order will be implied in the data). + Any order of the index level can be supported (The order will be implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. Example of the data: @@ -51,6 +51,9 @@ class DataHandler(Serializable): SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + + Tips for improving the performance of datahandler + - Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc` """ def __init__( @@ -74,7 +77,7 @@ def __init__( data_loader : Union[dict, str, DataLoader] data loader to load the data. init_data : - intialize the original data in the constructor. + initialize the original data in the constructor. fetch_orig : bool Return the original data instead of copy if possible. 
""" @@ -125,7 +128,7 @@ def config(self, **kwargs): def setup_data(self, enable_cache: bool = False): """ - Set Up the data in case of running intialization for multiple time + Set Up the data in case of running initialization for multiple time It is responsible for maintaining following variable 1) self._data @@ -163,6 +166,7 @@ def fetch( level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = CS_ALL, squeeze: bool = False, + proc_func: Callable = None, ) -> pd.DataFrame: """ fetch data from underlying data source @@ -185,6 +189,14 @@ def fetch( - if isinstance(col_set, List[str]): select several sets of meaningful columns, the returned data has multiple levels + proc_func: Callable + - Give a hook for processing data before fetching + - An example to explain the necessity of the hook: + - A Dataset learned some processors to process data which is related to data segmentation + - It will apply them every time when preparing data. + - The learned processor require the dataframe remains the same format when fitting and applying + - However the data format will change according to the parameters. + - So the processors should be applied to the underlayer data. squeeze : bool whether squeeze columns and index @@ -193,8 +205,15 @@ def fetch( ------- pd.DataFrame. """ + if proc_func is None: + df = self._data + else: + # FIXME: fetching by time first will be more friendly to `proc_func` + # Copy in case of `proc_func` changing the data inplace.... + df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) + # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(self._data, col_set) + df = self._fetch_df_by_col(df, col_set) df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns @@ -261,6 +280,10 @@ def get_range_iterator( class DataHandlerLP(DataHandler): """ DataHandler with **(L)earnable (P)rocessor** + + Tips to improving the performance of data handler + - To reduce the memory cost + - `drop_raw=True`: this will modify the data inplace on raw data; """ # data key @@ -430,7 +453,7 @@ def config(self, processor_kwargs: dict = None, **kwargs): def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Set up the data in case of running intialization for multiple time + Set up the data in case of running initialization for multiple time Parameters ---------- @@ -474,6 +497,7 @@ def fetch( level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, data_key: str = DK_I, + proc_func: Callable = None, ) -> pd.DataFrame: """ fetch data from underlying data source @@ -488,12 +512,18 @@ def fetch( select a set of meaningful columns.(e.g. features, columns). data_key : str the data to fetch: DK_*. + proc_func: Callable + please refer to the doc of DataHandler.fetch Returns ------- pd.DataFrame: """ df = self._get_df_by_key(data_key) + if proc_func is not None: + # FIXME: fetch by time first will be more friendly to proc_func + # Copy incase of `proc_func` changing the data inplace.... 
+ df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) # Fetch column first will be more friendly to SepDataFrame df = self._fetch_df_by_col(df, col_set) return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 58aca1d4f7..2ad110b89d 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -13,6 +13,7 @@ from qlib.data import filter as filter_module from qlib.data.filter import BaseDFilter from qlib.utils import load_dataset, init_instance_by_config +from qlib.log import get_module_logger class DataLoader(abc.ABC): @@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader): DataLoader based on (D)ata (H)andler It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler + + TODO: What make this module not that easy to use. + - For online scenario + - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook. """ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): @@ -265,7 +270,7 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: - LOG.warning(f"instruments[{instruments}] is ignored") + get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored") if self.is_group: df = pd.concat( diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 7635a4127b..fce22ddfcf 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import abc +from typing import Union, Text import numpy as np import pandas as pd import copy @@ -14,7 +15,7 @@ EPS = 1e-12 -def get_group_columns(df: pd.DataFrame, group: str): +def get_group_columns(df: pd.DataFrame, group: Union[Text, None]): """ get a group of columns from multi-index columns DataFrame diff --git a/qlib/model/ens/__init__.py b/qlib/model/ens/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py new file mode 100644 index 0000000000..4fa6a5ec63 --- /dev/null +++ b/qlib/model/ens/ensemble.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them into an ensemble prediction. +""" + +from typing import Union +import pandas as pd +from qlib.utils import FLATTEN_TUPLE, flatten_dict + + +class Ensemble: + """Merge the ensemble_dict into an ensemble object. + + For example: {Rollinga_b: object, Rollingb_c: object} -> object + + When calling this class: + + Args: + ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging + + Returns: + object: the ensemble object + """ + + def __call__(self, ensemble_dict: dict, *args, **kwargs): + raise NotImplementedError(f"Please implement the `__call__` method.") + + +class SingleKeyEnsemble(Ensemble): + + """ + Extract the object if there is only one key and value in the dict. Make the result more readable. + {Only key: Only value} -> Only value + + If there is more than 1 key or less than 1 key, then do nothing. + Even you can run this recursively to make dict more readable. 
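+    For example (illustrative): {"A": {"B": object}} -> object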
+ + NOTE: Default runs recursively. + + When calling this class: + + Args: + ensemble_dict (dict): the dict. The key of the dict will be ignored. + + Returns: + dict: the readable dict. + """ + + def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object: + if not isinstance(ensemble_dict, dict): + return ensemble_dict + if recursion: + tmp_dict = {} + for k, v in ensemble_dict.items(): + tmp_dict[k] = self(v, recursion) + ensemble_dict = tmp_dict + keys = list(ensemble_dict.keys()) + if len(keys) == 1: + ensemble_dict = ensemble_dict[keys[0]] + return ensemble_dict + + +class RollingEnsemble(Ensemble): + + """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. + + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". + + When calling this class: + + Args: + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. + The key of the dict will be ignored. + + Returns: + pd.DataFrame: the complete result of rolling. + """ + + def __call__(self, ensemble_dict: dict) -> pd.DataFrame: + artifact_list = list(ensemble_dict.values()) + artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) + artifact = pd.concat(artifact_list) + # If there are duplicated predition, use the latest perdiction + artifact = artifact[~artifact.index.duplicated(keep="last")] + artifact = artifact.sort_index() + return artifact + + +class AverageEnsemble(Ensemble): + """ + Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble. + + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it. + + When calling this class: + + Args: + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. + The key of the dict will be ignored. + + Returns: + pd.DataFrame: the complete result of averaging and standardizing. + """ + + def __call__(self, ensemble_dict: dict) -> pd.DataFrame: + # need to flatten the nested dict + ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE) + values = list(ensemble_dict.values()) + results = pd.concat(values, axis=1) + results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std()) + results = results.mean(axis=1) + results = results.sort_index() + return results diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py new file mode 100644 index 0000000000..7f45b06a5c --- /dev/null +++ b/qlib/model/ens/group.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Group can group a set of objects based on `group_func` and change them to a dict. +After group, we provide a method to reduce them. + +For example: + +group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} +reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object} + +""" + +from qlib.model.ens.ensemble import Ensemble, RollingEnsemble +from typing import Callable, Union +from joblib import Parallel, delayed + + +class Group: + """Group the objects based on dict""" + + def __init__(self, group_func=None, ens: Ensemble = None): + """ + Init Group. + + Args: + group_func (Callable, optional): Given a dict and return the group key and one of the group elements. + + For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} + + Defaults to None. + + ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping. 
+ """ + self._group_func = group_func + self._ens_func = ens + + def group(self, *args, **kwargs) -> dict: + """ + Group a set of objects and change them to a dict. + + For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} + + Returns: + dict: grouped dict + """ + if isinstance(getattr(self, "_group_func", None), Callable): + return self._group_func(*args, **kwargs) + else: + raise NotImplementedError(f"Please specify valid `group_func`.") + + def reduce(self, *args, **kwargs) -> dict: + """ + Reduce grouped dict. + + For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object} + + Returns: + dict: reduced dict + """ + if isinstance(getattr(self, "_ens_func", None), Callable): + return self._ens_func(*args, **kwargs) + else: + raise NotImplementedError(f"Please specify valid `_ens_func`.") + + def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict: + """ + Group the ungrouped_dict into different groups. + + Args: + ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things} + + Returns: + dict: grouped_dict like {G1: object, G2: object} + n_jobs: how many progress you need. + verbose: the print mode for Parallel. + """ + + # NOTE: The multiprocessing will raise error if you use `Serializable` + # Because the `Serializable` will affect the behaviors of pickle + grouped_dict = self.group(ungrouped_dict, *args, **kwargs) + + key_l = [] + job_l = [] + for key, value in grouped_dict.items(): + key_l.append(key) + job_l.append(delayed(Group.reduce)(self, value)) + return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l))) + + +class RollingGroup(Group): + """Group the rolling dict""" + + def group(self, rolling_dict: dict) -> dict: + """Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}} + + NOTE: There is an assumption which is the rolling key is at the end of the key tuple, because the rolling results always need to be ensemble firstly. + + Args: + rolling_dict (dict): an rolling dict. If the key is not a tuple, then do nothing. + + Returns: + dict: grouped dict + """ + grouped_dict = {} + for key, values in rolling_dict.items(): + if isinstance(key, tuple): + grouped_dict.setdefault(key[:-1], {})[key[-1]] = values + return grouped_dict + + def __init__(self): + super().__init__(ens=RollingEnsemble()) diff --git a/qlib/model/task.py b/qlib/model/task.py deleted file mode 100644 index f29f513a4e..0000000000 --- a/qlib/model/task.py +++ /dev/null @@ -1,27 +0,0 @@ -import abc -import typing - - -class TaskGen(metaclass=abc.ABCMeta): - @abc.abstractmethod - def __call__(self, *args, **kwargs) -> typing.List[dict]: - """ - generate - - Parameters - ---------- - args, kwargs: - The info for generating tasks - Example 1): - input: a specific task template - output: rolling version of the tasks - Example 2): - input: a specific task template - output: a set of tasks with different losses - - Returns - ------- - typing.List[dict]: - A list of tasks - """ - pass diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index f0bc0b780a..fd76e67284 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,42 +1,446 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.utils import init_instance_by_config, flatten_dict +""" +The Trainer will train a list of tasks and return a list of model recorders. 
+There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder). + +This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training. +In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. + +``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically. +""" + +import socket +from typing import Callable, List + +from qlib.data.dataset import Dataset +from qlib.model.base import Model +from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.manage import TaskManager, run_task -def task_train(task_config: dict, experiment_name): +def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder: """ - task based training + Begin task training to start a recorder and save the task config. - Parameters - ---------- - task_config : dict - A dict describes a task setting. + Args: + task_config (dict): the config of a task + experiment_name (str): the name of experiment + recorder_name (str): the given name will be the recorder name. None for using rid. + + Returns: + Recorder: the model recorder """ + with R.start(experiment_name=experiment_name, recorder_name=recorder_name): + R.log_params(**flatten_dict(task_config)) + R.save_objects(**{"task": task_config}) # keep the original format and datatype + R.set_tags(**{"hostname": socket.gethostname()}) + recorder: Recorder = R.get_recorder() + return recorder - # model initiaiton - model = init_instance_by_config(task_config["model"]) - dataset = init_instance_by_config(task_config["dataset"]) - # start exp - with R.start(experiment_name=experiment_name): - # train model - R.log_params(**flatten_dict(task_config)) +def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: + """ + Finish task training with real model fitting and saving. + + Args: + rec (Recorder): the recorder will be resumed + experiment_name (str): the name of experiment + + Returns: + Recorder: the model recorder + """ + with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True): + task_config = R.load_object("task") + # model & dataset initiation + model: Model = init_instance_by_config(task_config["model"]) + dataset: Dataset = init_instance_by_config(task_config["dataset"]) + # model training model.fit(dataset) - recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) - + # this dataset is saved for online inference. 
So the concrete data should not be dumped + dataset.config(dump_all=False, recursive=True) + R.save_objects(**{"dataset": dataset}) # generate records: prediction, backtest, and analysis - for record in task_config["record"]: - if record["class"] == SignalRecord.__name__: - srconf = {"model": model, "dataset": dataset, "recorder": recorder} - record["kwargs"].update(srconf) - sr = init_instance_by_config(record) - sr.generate() + records = task_config.get("record", []) + if isinstance(records, dict): # prevent only one dict + records = [records] + for record in records: + cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") + if cls is SignalRecord: + rconf = {"model": model, "dataset": dataset, "recorder": rec} else: - rconf = {"recorder": recorder} - record["kwargs"].update(rconf) - ar = init_instance_by_config(record) - ar.generate() + rconf = {"recorder": rec} + r = cls(**kwargs, **rconf) + r.generate() + + return rec + + +def task_train(task_config: dict, experiment_name: str) -> Recorder: + """ + Task based training, will be divided into two steps. + + Parameters + ---------- + task_config : dict + The config of a task. + experiment_name: str + The name of experiment + + Returns + ---------- + Recorder: The instance of the recorder + """ + recorder = begin_task_train(task_config, experiment_name) + recorder = end_task_train(recorder, experiment_name) + return recorder + + +class Trainer: + """ + The trainer can train a list of models. + There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training. + """ + + def __init__(self): + self.delay = False + + def train(self, tasks: list, *args, **kwargs) -> list: + """ + Given a list of task definitions, begin training, and return the models. + + For Trainer, it finishes real training in this method. + For DelayTrainer, it only does some preparation in this method. + + Args: + tasks: a list of tasks + + Returns: + list: a list of models + """ + raise NotImplementedError(f"Please implement the `train` method.") + + def end_train(self, models: list, *args, **kwargs) -> list: + """ + Given a list of models, finished something at the end of training if you need. + The models may be Recorder, txt file, database, and so on. + + For Trainer, it does some finishing touches in this method. + For DelayTrainer, it finishes real training in this method. + + Args: + models: a list of models + + Returns: + list: a list of models + """ + # do nothing if you finished all work in `train` method + return models + + def is_delay(self) -> bool: + """ + If Trainer will delay finishing `end_train`. + + Returns: + bool: if DelayTrainer + """ + return self.delay + + +class TrainerR(Trainer): + """ + Trainer based on (R)ecorder. + It will train a list of tasks and return a list of model recorders in a linear way. + + Assumption: models were defined by `task` and the results will be saved to `Recorder`. + """ + + # Those tag will help you distinguish whether the Recorder has finished traning + STATUS_KEY = "train_status" + STATUS_BEGIN = "begin_task_train" + STATUS_END = "end_task_train" + + def __init__(self, experiment_name: str = None, train_func: Callable = task_train): + """ + Init TrainerR. + + Args: + experiment_name (str, optional): the default name of experiment. + train_func (Callable, optional): default training method. Defaults to `task_train`. 
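+
+        A minimal usage sketch (the experiment name and the ``tasks`` list below are placeholders, not values required by ``Qlib``):
+
+        .. code-block:: python
+
+            trainer = TrainerR(experiment_name="my_exp")
+            recorders = trainer.train(tasks)          # real model fitting happens here
+            recorders = trainer.end_train(recorders)  # only sets the STATUS_END tag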
+ """ + super().__init__() + self.experiment_name = experiment_name + self.train_func = train_func + + def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + + Args: + tasks (list): a list of definitions based on `task` dict + train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method. + experiment_name (str): the experiment name, None for use default name. + kwargs: the params for train_func. + + Returns: + List[Recorder]: a list of Recorders + """ + if len(tasks) == 0: + return [] + if train_func is None: + train_func = self.train_func + if experiment_name is None: + experiment_name = self.experiment_name + recs = [] + for task in tasks: + rec = train_func(task, experiment_name, **kwargs) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) + recs.append(rec) + return recs + + def end_train(self, recs: list, **kwargs) -> List[Recorder]: + """ + Set STATUS_END tag to the recorders. + + Args: + recs (list): a list of trained recorders. + + Returns: + List[Recorder]: the same list as the param. + """ + for rec in recs: + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) + return recs + + +class DelayTrainerR(TrainerR): + """ + A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. + """ + + def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train): + """ + Init TrainerRM. + + Args: + experiment_name (str): the default name of experiment. + train_func (Callable, optional): default train method. Defaults to `begin_task_train`. + end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + """ + super().__init__(experiment_name, train_func) + self.end_train_func = end_train_func + self.delay = True + + def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finish real data loading and model fitting. + + Args: + recs (list): a list of Recorder, the tasks have been saved to them + end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + experiment_name (str): the experiment name, None for use default name. + kwargs: the params for end_train_func. + + Returns: + List[Recorder]: a list of Recorders + """ + if end_train_func is None: + end_train_func = self.end_train_func + if experiment_name is None: + experiment_name = self.experiment_name + for rec in recs: + if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END: + continue + end_train_func(rec, experiment_name, **kwargs) + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) + return recs + + +class TrainerRM(Trainer): + """ + Trainer based on (R)ecorder and Task(M)anager. + It can train a list of tasks and return a list of model recorders in a multiprocessing way. 
+ + Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager + """ + + # Those tag will help you distinguish whether the Recorder has finished traning + STATUS_KEY = "train_status" + STATUS_BEGIN = "begin_task_train" + STATUS_END = "end_task_train" + + def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train): + """ + Init TrainerR. + + Args: + experiment_name (str): the default name of experiment. + task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. + train_func (Callable, optional): default training method. Defaults to `task_train`. + """ + super().__init__() + self.experiment_name = experiment_name + self.task_pool = task_pool + self.train_func = train_func + + def train( + self, + tasks: list, + train_func: Callable = None, + experiment_name: str = None, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, + **kwargs, + ) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + + This method defaults to a single process, but TaskManager offered a great way to parallel training. + Users can customize their train_func to realize multiple processes or even multiple machines. + + Args: + tasks (list): a list of definitions based on `task` dict + train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method. + experiment_name (str): the experiment name, None for use default name. + before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs: the params for train_func. + + Returns: + List[Recorder]: a list of Recorders + """ + if len(tasks) == 0: + return [] + if train_func is None: + train_func = self.train_func + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + tm = TaskManager(task_pool=task_pool) + _id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB + run_task( + train_func, + task_pool, + experiment_name=experiment_name, + before_status=before_status, + after_status=after_status, + **kwargs, + ) + + recs = [] + for _id in _id_list: + rec = tm.re_query(_id)["res"] + rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) + recs.append(rec) + return recs + + def end_train(self, recs: list, **kwargs) -> List[Recorder]: + """ + Set STATUS_END tag to the recorders. + + Args: + recs (list): a list of trained recorders. + + Returns: + List[Recorder]: the same list as the param. + """ + for rec in recs: + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) + return recs + + +class DelayTrainerRM(TrainerRM): + """ + A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. + + """ + + def __init__( + self, + experiment_name: str = None, + task_pool: str = None, + train_func=begin_task_train, + end_train_func=end_task_train, + ): + """ + Init DelayTrainerRM. + + Args: + experiment_name (str): the default name of experiment. + task_pool (str): task pool name in TaskManager. None for use same name as experiment_name. + train_func (Callable, optional): default train method. Defaults to `begin_task_train`. 
+ end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + """ + super().__init__(experiment_name, task_pool, train_func) + self.end_train_func = end_train_func + self.delay = True + + def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + """ + Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE. + + Args: + tasks (list): a list of definition based on `task` dict + train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func. + experiment_name (str): the experiment name, None for use default name. + + Returns: + List[Recorder]: a list of Recorders + """ + if len(tasks) == 0: + return [] + return super().train( + tasks, + train_func=train_func, + experiment_name=experiment_name, + after_status=TaskManager.STATUS_PART_DONE, + **kwargs, + ) + + def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finish real data loading and model fitting. + + NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``. + + Args: + recs (list): a list of Recorder, the tasks have been saved to them. + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + experiment_name (str): the experiment name, None for use default name. + kwargs: the params for end_train_func. + + Returns: + List[Recorder]: a list of Recorders + """ + + if end_train_func is None: + end_train_func = self.end_train_func + if experiment_name is None: + experiment_name = self.experiment_name + task_pool = self.task_pool + if task_pool is None: + task_pool = experiment_name + tasks = [] + for rec in recs: + tasks.append(rec.load_object("task")) + + run_task( + end_train_func, + task_pool, + query={"filter": {"$in": tasks}}, # only train these tasks + experiment_name=experiment_name, + before_status=TaskManager.STATUS_PART_DONE, + **kwargs, + ) + for rec in recs: + rec.set_tags(**{self.STATUS_KEY: self.STATUS_END}) + return recs diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 1ee6f07a1f..77857182d9 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -6,6 +6,7 @@ from __future__ import print_function import os +import pickle import re import copy import json @@ -24,7 +25,9 @@ import numpy as np import pandas as pd from pathlib import Path -from typing import Union, Tuple, Text, Optional +from typing import Union, Tuple, Any, Text, Optional +from types import ModuleType +from urllib.parse import urlparse from ..config import C from ..log import get_module_logger, set_log_with_config @@ -165,24 +168,25 @@ def parse_field(field): return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)) -def get_module_by_module_path(module_path): +def get_module_by_module_path(module_path: Union[str, ModuleType]): """Load module path :param module_path: :return: """ - - if module_path.endswith(".py"): - module_spec = importlib.util.spec_from_file_location("", module_path) - module = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(module) + if isinstance(module_path, ModuleType): + module = module_path else: - module = importlib.import_module(module_path) - + if module_path.endswith(".py"): + module_spec = 
importlib.util.spec_from_file_location("", module_path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + else: + module = importlib.import_module(module_path) return module -def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): +def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict): """ extract class and kwargs from config info @@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): config : [dict, str] similar to config - module : Python module + default_module : Python module or str It should be a python module to load the class type + This function will load class from the config['module_path'] first. + If config['module_path'] doesn't exists, it will load the class from default_module. Returns ------- @@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): the class object and it's arguments. """ if isinstance(config, dict): + module = get_module_by_module_path(config.get("module_path", default_module)) + # raise AttributeError klass = getattr(module, config["class"]) kwargs = config.get("kwargs", {}) elif isinstance(config, str): + module = get_module_by_module_path(default_module) + klass = getattr(module, config) kwargs = {} else: @@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): def init_instance_by_config( - config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs -) -> object: + config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs +) -> Any: """ get initialized instance with config @@ -227,13 +237,19 @@ def init_instance_by_config( 'model_path': path, # It is optional if module is given } str example. - "ClassName": getattr(module, config)() will be used. + 1) specify a pickle object + - path like 'file:////obj.pkl' + 2) specify a class name + - "ClassName": getattr(module, config)() will be used. object example: instance of accept_types - module : Python module + default_module : Python module Optional. It should be a python module. NOTE: the "module_path" will be override by `module` arguments + This function will load class from the config['module_path'] first. + If config['module_path'] doesn't exists, it will load the class from default_module. + accept_types: Union[type, Tuple[type]] Optional. If the config is a instance of specific type, return the config directly. This will be passed into the second parameter of isinstance. @@ -246,10 +262,14 @@ def init_instance_by_config( if isinstance(config, accept_types): return config - if module is None: - module = get_module_by_module_path(config["module_path"]) + if isinstance(config, str): + # path like 'file:////obj.pkl' + pr = urlparse(config) + if pr.scheme == "file": + with open(os.path.join(pr.netloc, pr.path), "rb") as f: + return pickle.load(f) - klass, cls_kwargs = get_cls_kwargs(config, module) + klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module) return klass(**cls_kwargs, **kwargs) @@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): return calendar -def get_date_by_shift(trading_date, shift, future=False, clip_shift=True): +def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"): """get trading date with shift bias wil cur_date e.g. 
: shift == 1, return next trading date shift == -1, return previous trading date @@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True): """ from qlib.data import D - cal = D.calendar(future=future) + cal = D.calendar(future=future, freq=freq) if pd.to_datetime(trading_date) not in list(cal): raise ValueError("{} is not trading day!".format(str(trading_date))) _index = bisect.bisect_left(cal, trading_date) @@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: return df.sort_index(axis=axis) -def flatten_dict(d, parent_key="", sep="."): - """flatten_dict. +FLATTEN_TUPLE = "_FLATTEN_TUPLE" + + +def flatten_dict(d, parent_key="", sep=".") -> dict: + """ + Flatten a nested dict. + >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}) >>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10} - Parameters - ---------- - d : - d - parent_key : - parent_key - sep : - sep + >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE) + >>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10} + + Args: + d (dict): the dict waiting for flatting + parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "". + sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting. + + Returns: + dict: flatten dict """ items = [] for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k + if sep == FLATTEN_TUPLE: + new_key = (parent_key, k) if parent_key else k + else: + new_key = parent_key + sep + k if parent_key else k if isinstance(v, collections.abc.MutableMapping): items.extend(flatten_dict(v, new_key, sep=sep).items()) else: diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 4bc57eccd5..263e632deb 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -3,16 +3,24 @@ from pathlib import Path import pickle +import typing +import dill +from typing import Union class Serializable: """ - Serializable behaves like pickle. - But it only saves the state whose name **does not** start with `_` + Serializable will change the behaviors of pickle. + - It only saves the state whose name **does not** start with `_` + It provides a syntactic sugar for distinguish the attributes which user doesn't want. + - For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk """ + pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python. 
+ default_dump_all = False # if dump all things + def __init__(self): - self._dump_all = False + self._dump_all = self.default_dump_all self._exclude = [] def __getstate__(self) -> dict: @@ -33,18 +41,86 @@ def dump_all(self): @property def exclude(self): """ - What attribute will be dumped + What attribute will not be dumped """ return getattr(self, "_exclude", []) - def config(self, dump_all: bool = None, exclude: list = None): - if dump_all is not None: - self._dump_all = dump_all + FLAG_KEY = "_qlib_serial_flag" + + def config(self, dump_all: bool = None, exclude: list = None, recursive=False): + """ + configure the serializable object + + Parameters + ---------- + dump_all : bool + will the object dump all object + exclude : list + What attribute will not be dumped + recursive : bool + will the configuration be recursive + """ + + params = {"dump_all": dump_all, "exclude": exclude} + + for k, v in params.items(): + if v is not None: + attr_name = f"_{k}" + setattr(self, attr_name, v) - if exclude is not None: - self._exclude = exclude + if recursive: + for obj in self.__dict__.values(): + # set flag to prevent endless loop + self.__dict__[self.FLAG_KEY] = True + if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__: + obj.config(**params, recursive=True) + del self.__dict__[self.FLAG_KEY] - def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None): + def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None): + """ + Dump self to a pickle file. + + Args: + path (Union[Path, str]): the path to dump + dump_all (bool, optional): if need to dump all things. Defaults to None. + exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None. + """ self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: - pickle.dump(self, f) + self.get_backend().dump(self, f) + + @classmethod + def load(cls, filepath): + """ + Load the collector from a filepath. + + Args: + filepath (str): the path of file + + Raises: + TypeError: the pickled file must be `Collector` + + Returns: + Collector: the instance of Collector + """ + with open(filepath, "rb") as f: + object = cls.get_backend().load(f) + if isinstance(object, cls): + return object + else: + raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!") + + @classmethod + def get_backend(cls): + """ + Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill". + + Returns: + module: pickle or dill module based on pickle_backend + """ + if cls.pickle_backend == "pickle": + return pickle + elif cls.pickle_backend == "dill": + return dill + else: + raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.") diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 8135bab60a..2b2535edca 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -331,7 +331,7 @@ def set_uri(self, uri: Optional[Text]): """ self.exp_manager.set_uri(uri) - def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None): + def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder: """ Method for retrieving a recorder. 
diff --git a/qlib/workflow/online/__init__.py b/qlib/workflow/online/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py new file mode 100644 index 0000000000..443cd61ad8 --- /dev/null +++ b/qlib/workflow/online/manager.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically. + +With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models. +In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated. +So this module provides a series of methods to control this process. + +This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history. +Which means you can verify your strategy or find a better one. + +There are 4 total situations for using different trainers in different situations: + + + +========================= =================================================================================== +Situations Description +========================= =================================================================================== +Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models. + +Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models + in this routine. So it is not necessary to use DelayTrainer when do a REAL routine. + +Simulation + Trainer When your models have some temporal dependence on the previous models, then you + need to consider using Trainer. This means it will REAL train your models in + every routine and prepare signals for every routine. + +Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer + for the ability to multitasking. It means all tasks in all routines + can be REAL trained at the end of simulating. The signals will be prepared well at + different time segments (based on whether or not any new model is online). +========================= =================================================================================== +""" + +import logging +from typing import Callable, Dict, List, Union + +import pandas as pd +from qlib import get_module_logger +from qlib.data.data import D +from qlib.log import set_global_logger_level +from qlib.model.ens.ensemble import AverageEnsemble +from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR +from qlib.utils import flatten_dict +from qlib.utils.serial import Serializable +from qlib.workflow.online.strategy import OnlineStrategy +from qlib.workflow.task.collect import MergeCollector + + +class OnlineManager(Serializable): + """ + OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_. + It also provides a history recording of which models are online at what time. + """ + + STATUS_SIMULATING = "simulating" # when calling `simulate` + STATUS_NORMAL = "normal" # the normal status + + def __init__( + self, + strategies: Union[OnlineStrategy, List[OnlineStrategy]], + trainer: Trainer = None, + begin_time: Union[str, pd.Timestamp] = None, + freq="day", + ): + """ + Init OnlineManager. + One OnlineManager must have at least one OnlineStrategy. 
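+
+        A minimal sketch of a daily workflow (``my_strategy`` is an assumed, already-built ``OnlineStrategy`` instance and the dates are placeholders):
+
+        .. code-block:: python
+
+            manager = OnlineManager(strategies=my_strategy, trainer=TrainerR(), begin_time="2021-01-01")
+            manager.first_train()                   # train the first tasks of every strategy
+            manager.routine(cur_time="2021-01-04")  # a typical routine: update predictions, train new tasks, prepare signals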
+ + Args: + strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy + begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date. + trainer (Trainer): the trainer to train task. None for using TrainerR. + freq (str, optional): data frequency. Defaults to "day". + """ + self.logger = get_module_logger(self.__class__.__name__) + if not isinstance(strategies, list): + strategies = [strategies] + self.strategies = strategies + self.freq = freq + if begin_time is None: + begin_time = D.calendar(freq=self.freq).max() + self.begin_time = pd.Timestamp(begin_time) + self.cur_time = self.begin_time + # OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}. + self.history = {} + if trainer is None: + trainer = TrainerR() + self.trainer = trainer + self.signals = None + self.status = self.STATUS_NORMAL + + def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}): + """ + Get tasks from every strategy's first_tasks method and train them. + If using DelayTrainer, it can finish training all together after every strategy's first_tasks. + + Args: + strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies. + model_kwargs (dict): the params for `prepare_online_models` + """ + if strategies is None: + strategies = self.strategies + for strategy in strategies: + + self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") + tasks = strategy.first_tasks() + models = self.trainer.train(tasks, experiment_name=strategy.name_id) + models = self.trainer.end_train(models, experiment_name=strategy.name_id) + self.logger.info(f"Finished training {len(models)} models.") + + online_models = strategy.prepare_online_models(models, **model_kwargs) + self.history.setdefault(self.cur_time, {})[strategy] = online_models + + def routine( + self, + cur_time: Union[str, pd.Timestamp] = None, + task_kwargs: dict = {}, + model_kwargs: dict = {}, + signal_kwargs: dict = {}, + ): + """ + Typical update process for every strategy and record the online history. + + The typical update process after a routine, such as day by day or month by month. + The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals. + + If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks. + + Args: + cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None. 
+ task_kwargs (dict): the params for `prepare_tasks` + model_kwargs (dict): the params for `prepare_online_models` + signal_kwargs (dict): the params for `prepare_signals` + """ + if cur_time is None: + cur_time = D.calendar(freq=self.freq).max() + self.cur_time = pd.Timestamp(cur_time) # None for latest date + + for strategy in self.strategies: + self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") + if self.status == self.STATUS_NORMAL: + strategy.tool.update_online_pred() + + tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) + models = self.trainer.train(tasks) + if self.status == self.STATUS_NORMAL or not self.trainer.is_delay(): + models = self.trainer.end_train(models, experiment_name=strategy.name_id) + self.logger.info(f"Finished training {len(models)} models.") + online_models = strategy.prepare_online_models(models, **model_kwargs) + self.history.setdefault(self.cur_time, {})[strategy] = online_models + + if not self.trainer.is_delay(): + self.prepare_signals(**signal_kwargs) + + def get_collector(self) -> MergeCollector: + """ + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. + This collector can be a basis as the signals preparation. + + Returns: + MergeCollector: the collector to merge other collectors. + """ + collector_dict = {} + for strategy in self.strategies: + collector_dict[strategy.name_id] = strategy.get_collector() + return MergeCollector(collector_dict, process_list=[]) + + def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]): + """ + Add some new strategies to OnlineManager. + + Args: + strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy + """ + if not isinstance(strategies, list): + strategies = [strategies] + self.first_train(strategies) + self.strategies.extend(strategies) + + def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False): + """ + After preparing the data of the last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for the next routine. + + NOTE: Given a set prediction, all signals before these prediction end times will be prepared well. + + Even if the latest signal already exists, the latest calculation result will be overwritten. + + .. note:: + + Given a prediction of a certain time, all signals before this time will be prepared well. + + Args: + prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}. + over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False. + + Returns: + pd.DataFrame: the signals. + """ + signals = prepare_func(self.get_collector()()) + old_signals = self.signals + if old_signals is not None and not over_write: + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + self.logger.info(f"Finished preparing new {len(new_signals)} signals.") + self.signals = signals + return new_signals + + def get_signals(self) -> Union[pd.Series, pd.DataFrame]: + """ + Get prepared online signals. + + Returns: + Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime. 
+ pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals. + """ + return self.signals + + SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information + SIM_LOG_NAME = "SIMULATE_INFO" + + def simulate( + self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={} + ) -> Union[pd.Series, pd.DataFrame]: + """ + Starting from the current time, this method will simulate every routine in OnlineManager until the end time. + + Considering the parallel training, the models and signals can be prepared after all routine simulating. + + The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``. + + Args: + end_time: the time the simulation will end + frequency: the calendar frequency + task_kwargs (dict): the params for `prepare_tasks` + model_kwargs (dict): the params for `prepare_online_models` + signal_kwargs (dict): the params for `prepare_signals` + + Returns: + Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime. + pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals. + """ + self.status = self.STATUS_SIMULATING + cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency) + self.first_train() + + simulate_level = self.SIM_LOG_LEVEL + set_global_logger_level(simulate_level) + logging.addLevelName(simulate_level, self.SIM_LOG_NAME) + + for cur_time in cal: + self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......") + self.routine( + cur_time, + task_kwargs=task_kwargs, + model_kwargs=model_kwargs, + signal_kwargs=signal_kwargs, + ) + # delay prepare the models and signals + if self.trainer.is_delay(): + self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs) + + # FIXME: get logging level firstly and restore it here + set_global_logger_level(logging.DEBUG) + self.logger.info(f"Finished preparing signals") + self.status = self.STATUS_NORMAL + return self.get_signals() + + def delay_prepare(self, model_kwargs={}, signal_kwargs={}): + """ + Prepare all models and signals if something is waiting for preparation. + + Args: + model_kwargs: the params for `end_train` + signal_kwargs: the params for `prepare_signals` + """ + last_models = {} + signals_time = D.calendar()[0] + need_prepare = False + for cur_time, strategy_models in self.history.items(): + self.cur_time = cur_time + + for strategy, models in strategy_models.items(): + # only new online models need to prepare + if last_models.setdefault(strategy, set()) != set(models): + models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs) + strategy.tool.reset_online_tag(models) + need_prepare = True + last_models[strategy] = set(models) + + if need_prepare: + # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way. + self.prepare_signals(**signal_kwargs) + if signals_time > cur_time: + self.logger.warn( + f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models." 
+ ) + need_prepare = False + signals_time = self.signals.index.get_level_values("datetime").max() diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py new file mode 100644 index 0000000000..a54eb32bfe --- /dev/null +++ b/qlib/workflow/online/strategy.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +OnlineStrategy module is an element of online serving. +""" + +from copy import deepcopy +from typing import List, Tuple, Union +from qlib.data.data import D +from qlib.log import get_module_logger +from qlib.model.ens.group import RollingGroup +from qlib.workflow.online.utils import OnlineTool, OnlineToolR +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.collect import Collector, RecorderCollector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.utils import TimeAdjuster + + +class OnlineStrategy: + """ + OnlineStrategy is working with `Online Manager <#Online Manager>`_, responding to how the tasks are generated, the models are updated and signals are prepared. + """ + + def __init__(self, name_id: str): + """ + Init OnlineStrategy. + This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training. + + Args: + name_id (str): a unique name or id. + trainer (Trainer, optional): a instance of Trainer. Defaults to None. + """ + self.name_id = name_id + self.logger = get_module_logger(self.__class__.__name__) + self.tool = OnlineTool() + + def prepare_tasks(self, cur_time, **kwargs) -> List[dict]: + """ + After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest).. + Return the new tasks waiting for training. + + You can find the last online models by OnlineTool.online_models. + """ + raise NotImplementedError(f"Please implement the `prepare_tasks` method.") + + def prepare_online_models(self, trained_models, cur_time=None) -> List[object]: + """ + Select some models from trained models and set them to online models. + This is a typical implementation to online all trained models, you can override it to implement the complex method. + You can find the last online models by OnlineTool.online_models if you still need them. + + NOTE: Reset all online models to trained models. If there are no trained models, then do nothing. + + Args: + models (list): a list of models. + cur_time (pd.Dataframe): current time from OnlineManger. None for the latest. + + Returns: + List[object]: a list of online models. + """ + if not trained_models: + return self.tool.online_models() + self.tool.reset_online_tag(trained_models) + return trained_models + + def first_tasks(self) -> List[dict]: + """ + Generate a series of tasks firstly and return them. + """ + raise NotImplementedError(f"Please implement the `first_tasks` method.") + + def get_collector(self) -> Collector: + """ + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy. + + For example: + 1) collect predictions in Recorder + 2) collect signals in a txt file + + Returns: + Collector + """ + raise NotImplementedError(f"Please implement the `get_collector` method.") + + +class RollingStrategy(OnlineStrategy): + + """ + This example strategy always uses the latest rolling model sas online models. + """ + + def __init__( + self, + name_id: str, + task_template: Union[dict, List[dict]], + rolling_gen: RollingGen, + ): + """ + Init RollingStrategy. 
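+
+        A minimal sketch (the name, the task template and the rolling step are placeholders; it assumes ``RollingGen`` is configured with a rolling step):
+
+        .. code-block:: python
+
+            strategy = RollingStrategy(
+                name_id="my_rolling_strategy",
+                task_template=task_template,       # a task dict defined elsewhere
+                rolling_gen=RollingGen(step=550),  # assumption: RollingGen accepts the rolling step here
+            )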
+ + Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same. + + Args: + name_id (str): a unique name or id. Will be also the name of the Experiment. + task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen. + rolling_gen (RollingGen): an instance of RollingGen + """ + super().__init__(name_id=name_id) + self.exp_name = self.name_id + if not isinstance(task_template, list): + task_template = [task_template] + self.task_template = task_template + self.rg = rolling_gen + self.tool = OnlineToolR(self.exp_name) + self.ta = TimeAdjuster() + + def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): + """ + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models. + + Assumption: the models can be distinguished based on the model name and rolling test segments. + If you do not want this assumption, please implement your method or use another rec_key_func. + + Args: + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts. + """ + + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key + + if rec_key_func is None: + rec_key_func = rec_key + + artifacts_collector = RecorderCollector( + experiment=self.exp_name, + process_list=process_list, + rec_key_func=rec_key_func, + rec_filter_func=rec_filter_func, + artifacts_key=artifacts_key, + ) + + return artifacts_collector + + def first_tasks(self) -> List[dict]: + """ + Use rolling_gen to generate different tasks based on task_template. + + Returns: + List[dict]: a list of tasks + """ + return task_generator( + tasks=self.task_template, + generators=self.rg, # generate different date segment + ) + + def prepare_tasks(self, cur_time) -> List[dict]: + """ + Prepare new tasks based on cur_time (None for the latest). + + You can find the last online models by OnlineToolR.online_models. + + Returns: + List[dict]: a list of new tasks. 
+ """ + latest_records, max_test = self._list_latest(self.tool.online_models()) + if max_test is None: + self.logger.warn(f"No latest online recorders, no new tasks.") + return [] + calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time + self.logger.info( + f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" + ) + if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: + old_tasks = [] + tasks_tmp = [] + for rec in latest_records: + task = rec.load_object("task") + old_tasks.append(deepcopy(task)) + test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] + # modify the test segment to generate new tasks + task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) + tasks_tmp.append(task) + new_tasks_tmp = task_generator(tasks_tmp, self.rg) + new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] + return new_tasks + return [] + + def _list_latest(self, rec_list: List[Recorder]): + """ + List latest recorder form rec_list + + Args: + rec_list (List[Recorder]): a list of Recorder + + Returns: + List[Recorder], pd.Timestamp: the latest recorders and their test end time + """ + if len(rec_list) == 0: + return rec_list, None + max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list) + latest_rec = [] + for rec in rec_list: + if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: + latest_rec.append(rec) + return latest_rec, max_test diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py new file mode 100644 index 0000000000..561f7e18ae --- /dev/null +++ b/qlib/workflow/online/update.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Updater is a module to update artifacts such as predictions when the stock data is updating. +""" + +from abc import ABCMeta, abstractmethod + +import pandas as pd +from qlib import get_module_logger +from qlib.data import D +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +from qlib.model import Model +from qlib.utils import get_date_by_shift +from qlib.workflow.recorder import Recorder + + +class RMDLoader: + """ + Recorder Model Dataset Loader + """ + + def __init__(self, rec: Recorder): + self.rec = rec + + def get_dataset(self, start_time, end_time, segments=None) -> DatasetH: + """ + Load, config and setup dataset. + + This dataset is for inference. 
+ + Args: + start_time : + the start_time of underlying data + end_time : + the end_time of underlying data + segments : dict + the segments config for dataset + Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time + + Returns: + DatasetH: the instance of DatasetH + + """ + if segments is None: + segments = {"test": (start_time, end_time)} + dataset: DatasetH = self.rec.load_object("dataset") + dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments) + dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}) + return dataset + + def get_model(self) -> Model: + return self.rec.load_object("params.pkl") + + +class RecordUpdater(metaclass=ABCMeta): + """ + Update a specific recorders + """ + + def __init__(self, record: Recorder, *args, **kwargs): + self.record = record + self.logger = get_module_logger(self.__class__.__name__) + + @abstractmethod + def update(self, *args, **kwargs): + """ + Update info for specific recorder + """ + ... + + +class PredUpdater(RecordUpdater): + """ + Update the prediction in the Recorder + """ + + def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"): + """ + Init PredUpdater. + + Args: + record : Recorder + to_date : + update to prediction to the `to_date` + hist_ref : int + Sometimes, the dataset will have historical depends. + Leave the problem to users to set the length of historical dependency + + .. note:: + + the start_time is not included in the hist_ref + + """ + # TODO: automate this hist_ref in the future. + super().__init__(record=record) + + self.to_date = to_date + self.hist_ref = hist_ref + self.freq = freq + self.rmdl = RMDLoader(rec=record) + + if to_date == None: + to_date = D.calendar(freq=freq)[-1] + self.to_date = pd.Timestamp(to_date) + self.old_pred = record.load_object("pred.pkl") + self.last_end = self.old_pred.index.get_level_values("datetime").max() + + def prepare_data(self) -> DatasetH: + """ + Load dataset + + Separating this function will make it easier to reuse the dataset + + Returns: + DatasetH: the instance of DatasetH + """ + start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq) + start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) + seg = {"test": (start_time, self.to_date)} + dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg) + return dataset + + def update(self, dataset: DatasetH = None): + """ + Update the prediction in a recorder. + + Args: + DatasetH: the instance of DatasetH. None for reprepare. + """ + # FIXME: the problem below is not solved + # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised + # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. + # https://github.com/pytorch/pytorch/issues/16797 + + start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) + if start_time >= self.to_date: + self.logger.info( + f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}." 
+ ) + return + + # load dataset + if dataset is None: + # For reusing the dataset + dataset = self.prepare_data() + + # Load model + model = self.rmdl.get_model() + + new_pred: pd.Series = model.predict(dataset) + + cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0) + cb_pred = cb_pred.sort_index() + + self.record.save_objects(**{"pred.pkl": cb_pred}) + + self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.") diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py new file mode 100644 index 0000000000..f3ef13aa93 --- /dev/null +++ b/qlib/workflow/online/utils.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +OnlineTool is a module to set and unset a series of `online` models. +The `online` models are some decisive models in some time points, which can be changed with the change of time. +This allows us to use efficient submodels as the market-style changing. +""" + +from typing import List, Union + +from qlib.log import get_module_logger +from qlib.workflow.online.update import PredUpdater +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.utils import list_recorders + + +class OnlineTool: + """ + OnlineTool will manage `online` models in an experiment that includes the model recorders. + """ + + ONLINE_KEY = "online_status" # the online status key in recorder + ONLINE_TAG = "online" # the 'online' model + OFFLINE_TAG = "offline" # the 'offline' model, not for online serving + + def __init__(self): + """ + Init OnlineTool. + """ + self.logger = get_module_logger(self.__class__.__name__) + + def set_online_tag(self, tag, recorder: Union[list, object]): + """ + Set `tag` to the model to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[list,object]): the model's recorder + """ + raise NotImplementedError(f"Please implement the `set_online_tag` method.") + + def get_online_tag(self, recorder: object) -> str: + """ + Given a model recorder and return its online tag. + + Args: + recorder (Object): the model's recorder + + Returns: + str: the online tag + """ + raise NotImplementedError(f"Please implement the `get_online_tag` method.") + + def reset_online_tag(self, recorder: Union[list, object]): + """ + Offline all models and set the recorders to 'online'. + + Args: + recorder (Union[list,object]): + the recorder you want to reset to 'online'. + + """ + raise NotImplementedError(f"Please implement the `reset_online_tag` method.") + + def online_models(self) -> list: + """ + Get current `online` models + + Returns: + list: a list of `online` models. + """ + raise NotImplementedError(f"Please implement the `online_models` method.") + + def update_online_pred(self, to_date=None): + """ + Update the predictions of `online` models to to_date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest. + + """ + raise NotImplementedError(f"Please implement the `update_online_pred` method.") + + +class OnlineToolR(OnlineTool): + """ + The implementation of OnlineTool based on (R)ecorder. + """ + + def __init__(self, experiment_name: str): + """ + Init OnlineToolR. + + Args: + experiment_name (str): the experiment name. + """ + super().__init__() + self.exp_name = experiment_name + + def set_online_tag(self, tag, recorder: Union[Recorder, List]): + """ + Set `tag` to the model's recorder to sign whether online. 
+ + Args: + tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder + """ + if isinstance(recorder, Recorder): + recorder = [recorder] + for rec in recorder: + rec.set_tags(**{self.ONLINE_KEY: tag}) + self.logger.info(f"Set {len(recorder)} models to '{tag}'.") + + def get_online_tag(self, recorder: Recorder) -> str: + """ + Given a model recorder and return its online tag. + + Args: + recorder (Recorder): an instance of recorder + + Returns: + str: the online tag + """ + tags = recorder.list_tags() + return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG) + + def reset_online_tag(self, recorder: Union[Recorder, List]): + """ + Offline all models and set the recorders to 'online'. + + Args: + recorder (Union[Recorder, List]): + the recorder you want to reset to 'online'. + + """ + if isinstance(recorder, Recorder): + recorder = [recorder] + recs = list_recorders(self.exp_name) + self.set_online_tag(self.OFFLINE_TAG, list(recs.values())) + self.set_online_tag(self.ONLINE_TAG, recorder) + + def online_models(self) -> list: + """ + Get current `online` models + + Returns: + list: a list of `online` models. + """ + return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) + + def update_online_pred(self, to_date=None): + """ + Update the predictions of online models to to_date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar. + """ + online_models = self.online_models() + for rec in online_models: + hist_ref = 0 + task = rec.load_object("task") + # Special treatment of historical dependencies + if task["dataset"]["class"] == "TSDatasetH": + hist_ref = task["dataset"]["kwargs"]["step_len"] + PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update() + + self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 5732c95a9e..fc71b3f9a2 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -151,6 +151,10 @@ def generate(self, **kwargs): del params["data_key"] # The backend handler should be DataHandler raw_label = self.dataset.prepare(**params) + except AttributeError: + # The data handler is initialize with `drop_raw=True`... 
+ # So raw_label is not available + raw_label = None self.recorder.save_objects(**{"label.pkl": raw_label}) self.dataset.__class__ = orig_cls @@ -236,6 +240,9 @@ def generate(self, **kwargs): pred = self.load("pred.pkl") label = self.load("label.pkl") + if label is None or not isinstance(label, pd.DataFrame) or label.empty: + logger.warn(f"Empty label.") + return ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0]) metrics = { "IC": ic.mean(), diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index b9b2fd1b36..0c9abf7318 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -39,6 +39,9 @@ def __repr__(self): def __str__(self): return str(self.info) + def __hash__(self) -> int: + return hash(self.info["id"]) + @property def info(self): output = dict() @@ -232,6 +235,14 @@ def __repr__(self): client=self.client, ) + def __hash__(self) -> int: + return hash(self.info["id"]) + + def __eq__(self, o: object) -> bool: + if isinstance(o, MLflowRecorder): + return self.info["id"] == o.info["id"] + return False + @property def uri(self): return self._uri diff --git a/qlib/workflow/task/__init__.py b/qlib/workflow/task/__init__.py new file mode 100644 index 0000000000..cc338cca4d --- /dev/null +++ b/qlib/workflow/task/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Task related workflow is implemented in this folder + +A typical task workflow + +| Step | Description | +|-----------------------+------------------------------------------------| +| TaskGen | Generating tasks. | +| TaskManager(optional) | Manage generated tasks | +| run task | retrive tasks from TaskManager and run tasks. | +""" diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py new file mode 100644 index 0000000000..9410c2b9c2 --- /dev/null +++ b/qlib/workflow/task/collect.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on. +""" + +from typing import Callable, Dict, List +from qlib.utils.serial import Serializable +from qlib.workflow import R + + +class Collector(Serializable): + """The collector to collect different results""" + + pickle_backend = "dill" # use dill to dump user method + + def __init__(self, process_list=[]): + """ + Init Collector. + + Args: + process_list (list or Callable): the list of processors or the instance of a processor to process dict. + """ + if not isinstance(process_list, list): + process_list = [process_list] + self.process_list = process_list + + def collect(self) -> dict: + """ + Collect the results and return a dict like {key: things} + + Returns: + dict: the dict after collecting. + + For example: + + {"prediction": pd.Series} + + {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}} + + ...... + """ + raise NotImplementedError(f"Please implement the `collect` method.") + + @staticmethod + def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict: + """ + Do a series of processing to the dict returned by collect and return a dict like {key: things} + For example, you can group and ensemble. + + Args: + collected_dict (dict): the dict return by `collect` + process_list (list or Callable): the list of processors or the instance of a processor to process dict. + The processor order is the same as the list order. 
+ For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] + + Returns: + dict: the dict after processing. + """ + if not isinstance(process_list, list): + process_list = [process_list] + result = {} + for artifact in collected_dict: + value = collected_dict[artifact] + for process in process_list: + if not callable(process): + raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.") + value = process(value, *args, **kwargs) + result[artifact] = value + return result + + def __call__(self, *args, **kwargs) -> dict: + """ + Do the workflow including ``collect`` and ``process_collect`` + + Returns: + dict: the dict after collecting and processing. + """ + collected = self.collect() + return self.process_collect(collected, self.process_list, *args, **kwargs) + + +class MergeCollector(Collector): + """ + A collector to collect the results of other Collectors + + For example: + + We have 2 collector, which named A and B. + A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}. + Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}} + + ...... + + """ + + def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None): + """ + Init MergeCollector. + + Args: + collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector} + process_list (List[Callable]): the list of processors or the instance of processor to process dict. + merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting. + None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")). + """ + super().__init__(process_list=process_list) + self.collector_dict = collector_dict + self.merge_func = merge_func + + def collect(self) -> dict: + """ + Collect all results of collector_dict and change the outermost key to a recombination key. + + Returns: + dict: the dict after collecting. + """ + collect_dict = {} + for collector_key, collector in self.collector_dict.items(): + tmp_dict = collector() + for key, value in tmp_dict.items(): + if self.merge_func is not None: + collect_dict[self.merge_func(collector_key, key)] = value + else: + collect_dict[(collector_key, key)] = value + return collect_dict + + +class RecorderCollector(Collector): + ART_KEY_RAW = "__raw" + + def __init__( + self, + experiment, + process_list=[], + rec_key_func=None, + rec_filter_func=None, + artifacts_path={"pred": "pred.pkl"}, + artifacts_key=None, + ): + """ + Init RecorderCollector. + + Args: + experiment (Experiment or str): an instance of an Experiment or the name of an Experiment + process_list (list or Callable): the list of processors or the instance of a processor to process dict. + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. + artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. 
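+
+        Example:
+            An illustrative sketch only; the experiment name ``my_exp`` is hypothetical, and the artifact paths follow the defaults described above.
+
+            .. code-block:: python
+
+                collector = RecorderCollector(
+                    experiment="my_exp",
+                    artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"},
+                )
+                artifacts = collector()  # {"pred": {rec_key: obj, ...}, "IC": {rec_key: obj, ...}}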
+ """ + super().__init__(process_list=process_list) + if isinstance(experiment, str): + experiment = R.get_exp(experiment_name=experiment) + self.experiment = experiment + self.artifacts_path = artifacts_path + if rec_key_func is None: + rec_key_func = lambda rec: rec.info["id"] + if artifacts_key is None: + artifacts_key = list(self.artifacts_path.keys()) + self.rec_key_func = rec_key_func + self.artifacts_key = artifacts_key + self.rec_filter_func = rec_filter_func + + def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict: + """ + Collect different artifacts based on recorder after filtering. + + Args: + artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default. + rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default. + only_exist (bool, optional): if only collect the artifacts when a recorder really has. + If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception. + + Returns: + dict: the dict after collected like {artifact: {rec_key: object}} + """ + if artifacts_key is None: + artifacts_key = self.artifacts_key + if rec_filter_func is None: + rec_filter_func = self.rec_filter_func + + if isinstance(artifacts_key, str): + artifacts_key = [artifacts_key] + + collect_dict = {} + # filter records + recs = self.experiment.list_recorders() + recs_flt = {} + for rid, rec in recs.items(): + if rec_filter_func is None or rec_filter_func(rec): + recs_flt[rid] = rec + + for _, rec in recs_flt.items(): + rec_key = self.rec_key_func(rec) + for key in artifacts_key: + if self.ART_KEY_RAW == key: + artifact = rec + else: + try: + artifact = rec.load_object(self.artifacts_path[key]) + except Exception as e: + if only_exist: + # only collect existing artifact + continue + raise e + collect_dict.setdefault(key, {})[rec_key] = artifact + + return collect_dict + + def get_exp_name(self) -> str: + """ + Get experiment name + + Returns: + str: experiment name + """ + return self.experiment.name diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py new file mode 100644 index 0000000000..cdebf50494 --- /dev/null +++ b/qlib/workflow/task/gen.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +TaskGenerator module can generate many tasks based on TaskGen and some task templates. +""" +import abc +import copy +from typing import List, Union, Callable +from .utils import TimeAdjuster + + +def task_generator(tasks, generators) -> list: + """ + Use a list of TaskGen and a list of task templates to generate different tasks. + + For examples: + + There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template. + task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks. 
+ + Parameters + ---------- + tasks : List[dict] or dict + a list of task templates or a single task + generators : List[TaskGen] or TaskGen + a list of TaskGen or a single TaskGen + + Returns + ------- + list + a list of tasks + """ + + if isinstance(tasks, dict): + tasks = [tasks] + if isinstance(generators, TaskGen): + generators = [generators] + + # generate gen_task_list + for gen in generators: + new_task_list = [] + for task in tasks: + new_task_list.extend(gen.generate(task)) + tasks = new_task_list + + return tasks + + +class TaskGen(metaclass=abc.ABCMeta): + """ + The base class for generating different tasks + + Example 1: + + input: a specific task template and rolling steps + + output: rolling version of the tasks + + Example 2: + + input: a specific task template and losses list + + output: a set of tasks with different losses + + """ + + @abc.abstractmethod + def generate(self, task: dict) -> List[dict]: + """ + Generate different tasks based on a task template + + Parameters + ---------- + task: dict + a task template + + Returns + ------- + typing.List[dict]: + A list of tasks + """ + pass + + def __call__(self, *args, **kwargs): + """ + This is just a syntactic sugar for generate + """ + return self.generate(*args, **kwargs) + + +def handler_mod(task: dict, rolling_gen): + """ + Help to modify the handler end time when using RollingGen + + Args: + task (dict): a task template + rg (RollingGen): an instance of RollingGen + """ + try: + interval = rolling_gen.ta.cal_interval( + task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], + task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1], + ) + # if end_time < the end of test_segments, then change end_time to allow load more data + if interval < 0: + task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy( + task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1] + ) + except KeyError: + # Maybe dataset do not have handler, then do nothing. + pass + + +class RollingGen(TaskGen): + ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date + ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date + + def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod): + """ + Generate tasks for rolling + + Parameters + ---------- + step : int + step to rolling + rtype : str + rolling type (expanding, sliding) + ds_extra_mod_func: Callable + A method like: handler_mod(task: dict, rg: RollingGen) + Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset. + """ + self.step = step + self.rtype = rtype + self.ds_extra_mod_func = ds_extra_mod_func + self.ta = TimeAdjuster(future=True) + + self.test_key = "test" + self.train_key = "train" + + def generate(self, task: dict) -> List[dict]: + """ + Converting the task into a rolling task. + + Parameters + ---------- + task: dict + A dict describing a task. For example. + + .. 
code-block:: python + + DEFAULT_TASK = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi100", + }, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, + "record": [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + ] + } + + Returns + ---------- + List[dict]: a list of tasks + """ + res = [] + + prev_seg = None + test_end = None + while True: + t = copy.deepcopy(task) + + # calculate segments + if prev_seg is None: + # First rolling + # 1) prepare the end point + segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) + test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1] + # 2) and init test segments + test_start_idx = self.ta.align_idx(segments[self.test_key][0]) + segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) + else: + segments = {} + try: + for k, seg in prev_seg.items(): + # decide how to shift + # expanding only for train data, the segments size of test data and valid data won't change + if k == self.train_key and self.rtype == self.ROLL_EX: + rtype = self.ta.SHIFT_EX + else: + rtype = self.ta.SHIFT_SD + # shift the segments data + segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) + if segments[self.test_key][0] > test_end: + break + except KeyError: + # We reach the end of tasks + # No more rolling + break + + # update segments of this task + t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) + prev_seg = segments + if self.ds_extra_mod_func is not None: + self.ds_extra_mod_func(t, self) + res.append(t) + return res diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py new file mode 100644 index 0000000000..658eec4d6e --- /dev/null +++ b/qlib/workflow/task/manage.py @@ -0,0 +1,493 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling. +These features can run tasks concurrently and ensure every task will be used only once. +Task Manager will store all tasks in `MongoDB `_. +Users **MUST** finished the configuration of `MongoDB `_ when using this module. + +A task in TaskManager consists of 3 parts +- tasks description: the desc will define the task +- tasks status: the status of the task +- tasks result: A user can get the task with the task description and task result. +""" +import concurrent +import pickle +import time +from contextlib import contextmanager +from typing import Callable, List + +import fire +import pymongo +from bson.binary import Binary +from bson.objectid import ObjectId +from pymongo.errors import InvalidDocument +from qlib import auto_init, get_module_logger +from tqdm.cli import tqdm + +from .utils import get_mongodb + + +class TaskManager: + """ + TaskManager + + Here is what will a task looks like when it created by TaskManager + + .. 
code-block:: python + + { + 'def': pickle serialized task definition. using pickle will make it easier + 'filter': json-like data. This is for filtering the tasks. + 'status': 'waiting' | 'running' | 'done' + 'res': pickle serialized task result, + } + + The tasks manager assumes that you will only update the tasks you fetched. + The mongo fetch one and update will make it date updating secure. + + .. note:: + + Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded + + Here are four status which are: + + STATUS_WAITING: waiting for training + + STATUS_RUNNING: training + + STATUS_PART_DONE: finished some step and waiting for next step + + STATUS_DONE: all work done + """ + + STATUS_WAITING = "waiting" + STATUS_RUNNING = "running" + STATUS_DONE = "done" + STATUS_PART_DONE = "part_done" + + ENCODE_FIELDS_PREFIX = ["def", "res"] + + def __init__(self, task_pool: str = None): + """ + Init Task Manager, remember to make the statement of MongoDB url and database name firstly. + + Parameters + ---------- + task_pool: str + the name of Collection in MongoDB + """ + self.mdb = get_mongodb() + if task_pool is not None: + self.task_pool = getattr(self.mdb, task_pool) + self.logger = get_module_logger(self.__class__.__name__) + + def list(self) -> list: + """ + List the all collection(task_pool) of the db + + Returns: + list + """ + return self.mdb.list_collection_names() + + def _encode_task(self, task): + for prefix in self.ENCODE_FIELDS_PREFIX: + for k in list(task.keys()): + if k.startswith(prefix): + task[k] = Binary(pickle.dumps(task[k])) + return task + + def _decode_task(self, task): + for prefix in self.ENCODE_FIELDS_PREFIX: + for k in list(task.keys()): + if k.startswith(prefix): + task[k] = pickle.loads(task[k]) + return task + + def _dict_to_str(self, flt): + return {k: str(v) for k, v in flt.items()} + + def replace_task(self, task, new_task): + """ + Use a new task to replace a old one + + Args: + task: old task + new_task: new task + """ + new_task = self._encode_task(new_task) + query = {"_id": ObjectId(task["_id"])} + try: + self.task_pool.replace_one(query, new_task) + except InvalidDocument: + task["filter"] = self._dict_to_str(task["filter"]) + self.task_pool.replace_one(query, new_task) + + def insert_task(self, task): + """ + Insert a task. + + Args: + task: the task waiting for insert + + Returns: + pymongo.results.InsertOneResult + """ + try: + insert_result = self.task_pool.insert_one(task) + except InvalidDocument: + task["filter"] = self._dict_to_str(task["filter"]) + insert_result = self.task_pool.insert_one(task) + return insert_result + + def insert_task_def(self, task_def): + """ + Insert a task to task_pool + + Parameters + ---------- + task_def: dict + the task definition + + Returns + ------- + pymongo.results.InsertOneResult + """ + task = self._encode_task( + { + "def": task_def, + "filter": task_def, # FIXME: catch the raised error + "status": self.STATUS_WAITING, + } + ) + insert_result = self.insert_task(task) + return insert_result + + def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]: + """ + If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id. + If a task is not new, then just query its _id. 
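+
+        An illustrative sketch (``tasks`` is a list of task dicts, e.g. from ``task_generator``; the pool name ``rolling_tasks`` is hypothetical, and ``C["mongo"]`` must already be configured):
+
+        .. code-block:: python
+
+            tm = TaskManager(task_pool="rolling_tasks")
+            _ids = tm.create_task(tasks)  # only the task definitions not already in the pool are inserted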
+ + Parameters + ---------- + task_def_l: list + a list of task + dry_run: bool + if insert those new tasks to task pool + print_nt: bool + if print new task + + Returns + ------- + List[str] + a list of the _id of task_def_l + """ + new_tasks = [] + _id_list = [] + for t in task_def_l: + try: + r = self.task_pool.find_one({"filter": t}) + except InvalidDocument: + r = self.task_pool.find_one({"filter": self._dict_to_str(t)}) + if r is None: + new_tasks.append(t) + if not dry_run: + insert_result = self.insert_task_def(t) + _id_list.append(insert_result.inserted_id) + else: + _id_list.append(None) + else: + _id_list.append(self._decode_task(r)["_id"]) + + self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}") + + if print_nt: # print new task + for t in new_tasks: + print(t) + + if dry_run: + return [] + + return _id_list + + def fetch_task(self, query={}, status=STATUS_WAITING) -> dict: + """ + Use query to fetch tasks. + + Args: + query (dict, optional): query dict. Defaults to {}. + status (str, optional): [description]. Defaults to STATUS_WAITING. + + Returns: + dict: a task(document in collection) after decoding + """ + query = query.copy() + if "_id" in query: + query["_id"] = ObjectId(query["_id"]) + query.update({"status": status}) + task = self.task_pool.find_one_and_update( + query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] + ) + # null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority + if task is None: + return None + task["status"] = self.STATUS_RUNNING + return self._decode_task(task) + + @contextmanager + def safe_fetch_task(self, query={}, status=STATUS_WAITING): + """ + Fetch task from task_pool using query with contextmanager + + Parameters + ---------- + query: dict + the dict of query + + Returns + ------- + dict: a task(document in collection) after decoding + """ + task = self.fetch_task(query=query, status=status) + try: + yield task + except Exception: + if task is not None: + self.logger.info("Returning task before raising error") + self.return_task(task) + self.logger.info("Task returned") + raise + + def task_fetcher_iter(self, query={}): + while True: + with self.safe_fetch_task(query=query) as task: + if task is None: + break + yield task + + def query(self, query={}, decode=True): + """ + Query task in collection. + This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator + + Parameters + ---------- + query: dict + the dict of query + decode: bool + + Returns + ------- + dict: a task(document in collection) after decoding + """ + query = query.copy() + if "_id" in query: + query["_id"] = ObjectId(query["_id"]) + for t in self.task_pool.find(query): + yield self._decode_task(t) + + def re_query(self, _id): + """ + Use _id to query task. + + Args: + _id (str): _id of a document + + Returns: + dict: a task(document in collection) after decoding + """ + t = self.task_pool.find_one({"_id": ObjectId(_id)}) + return self._decode_task(t) + + def commit_task_res(self, task, res, status=STATUS_DONE): + """ + Commit the result to task['res']. + + Args: + task ([type]): [description] + res (object): the result you want to save + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE. + """ + # A workaround to use the class attribute. 
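+        # If the caller explicitly passes ``status=None``, fall back to the class-level default (``STATUS_DONE``).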
+ if status is None: + status = TaskManager.STATUS_DONE + self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) + + def return_task(self, task, status=STATUS_WAITING): + """ + Return a task to status. Alway using in error handling. + + Args: + task ([type]): [description] + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING. + """ + if status is None: + status = TaskManager.STATUS_WAITING + update_dict = {"$set": {"status": status}} + self.task_pool.update_one({"_id": task["_id"]}, update_dict) + + def remove(self, query={}): + """ + Remove the task using query + + Parameters + ---------- + query: dict + the dict of query + + """ + query = query.copy() + if "_id" in query: + query["_id"] = ObjectId(query["_id"]) + self.task_pool.delete_many(query) + + def task_stat(self, query={}) -> dict: + """ + Count the tasks in every status. + + Args: + query (dict, optional): the query dict. Defaults to {}. + + Returns: + dict + """ + query = query.copy() + if "_id" in query: + query["_id"] = ObjectId(query["_id"]) + tasks = self.query(query=query, decode=False) + status_stat = {} + for t in tasks: + status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1 + return status_stat + + def reset_waiting(self, query={}): + """ + Reset all running task into waiting status. Can be used when some running task exit unexpected. + + Args: + query (dict, optional): the query dict. Defaults to {}. + """ + query = query.copy() + # default query + if "status" not in query: + query["status"] = self.STATUS_RUNNING + return self.reset_status(query=query, status=self.STATUS_WAITING) + + def reset_status(self, query, status): + query = query.copy() + if "_id" in query: + query["_id"] = ObjectId(query["_id"]) + print(self.task_pool.update_many(query, {"$set": {"status": status}})) + + def prioritize(self, task, priority: int): + """ + Set priority for task + + Parameters + ---------- + task : dict + The task query from the database + priority : int + the target priority + """ + update_dict = {"$set": {"priority": priority}} + self.task_pool.update_one({"_id": task["_id"]}, update_dict) + + def _get_undone_n(self, task_stat): + return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0) + + def _get_total(self, task_stat): + return sum(task_stat.values()) + + def wait(self, query={}): + task_stat = self.task_stat(query) + total = self._get_total(task_stat) + last_undone_n = self._get_undone_n(task_stat) + with tqdm(total=total, initial=total - last_undone_n) as pbar: + while True: + time.sleep(10) + undone_n = self._get_undone_n(self.task_stat(query)) + pbar.update(last_undone_n - undone_n) + last_undone_n = undone_n + if undone_n == 0: + break + + def __str__(self): + return f"TaskManager({self.task_pool})" + + +def run_task( + task_func: Callable, + task_pool: str, + query: dict = {}, + force_release: bool = False, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, + **kwargs, +): + """ + While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool + + After running this method, here are 4 situations (before_status -> after_status): + + STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param + + STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param + + STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param + + STATUS_PART_DONE -> 
STATUS_DONE: use task["res"] as `task_func` param + + Parameters + ---------- + task_func : Callable + def (task_def, **kwargs) -> + the function to run the task + task_pool : str + the name of the task pool (Collection in MongoDB) + query: dict + will use this dict to query task_pool when fetching task + force_release : bool + will the program force to release the resource + before_status : str: + the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status : str: + the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs + the params for `task_func` + """ + tm = TaskManager(task_pool) + + ever_run = False + + while True: + with tm.safe_fetch_task(status=before_status, query=query) as task: + if task is None: + break + get_module_logger("run_task").info(task["def"]) + # when fetching `WAITING` task, use task["def"] to train + if before_status == TaskManager.STATUS_WAITING: + param = task["def"] + # when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"] + elif before_status == TaskManager.STATUS_PART_DONE: + param = task["res"] + else: + raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!") + if force_release: + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + res = executor.submit(task_func, param, **kwargs).result() + else: + res = task_func(param, **kwargs) + tm.commit_task_res(task, res, status=after_status) + ever_run = True + + return ever_run + + +if __name__ == "__main__": + # This is for using it in cmd + # E.g. : `python -m qlib.workflow.task.manage list` + auto_init() + fire.Fire(TaskManager) diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py new file mode 100644 index 0000000000..174b4b9bfc --- /dev/null +++ b/qlib/workflow/task/utils.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Some tools for task management. +""" + +import bisect +import pandas as pd +from qlib.data import D +from qlib.workflow import R +from qlib.config import C +from qlib.log import get_module_logger +from pymongo import MongoClient +from pymongo.database import Database +from typing import Union + + +def get_mongodb() -> Database: + + """ + Get database in MongoDB, which means you need to declare the address and the name of a database at first. + + For example: + + Using qlib.init(): + + mongo_conf = { + "task_url": task_url, # your MongoDB url + "task_db_name": task_db_name, # database name + } + qlib.init(..., mongo=mongo_conf) + + After qlib.init(): + + C["mongo"] = { + "task_url" : "mongodb://localhost:27017/", + "task_db_name" : "rolling_db" + } + + Returns: + Database: the Database instance + """ + try: + cfg = C["mongo"] + except KeyError: + get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager") + raise + + client = MongoClient(cfg["task_url"]) + return client.get_database(name=cfg["task_db_name"]) + + +def list_recorders(experiment, rec_filter_func=None): + """ + List all recorders which can pass the filter in an experiment. + + Args: + experiment (str or Experiment): the name of an Experiment or an instance + rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None. + + Returns: + dict: a dict {rid: recorder} after filtering. 
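+
+    Example (a sketch only; the experiment name and the tag used in the filter are hypothetical):
+
+    .. code-block:: python
+
+        recs = list_recorders(
+            "my_exp",
+            rec_filter_func=lambda rec: rec.list_tags().get("online_status", "offline") == "online",
+        )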
+ """ + if isinstance(experiment, str): + experiment = R.get_exp(experiment_name=experiment) + recs = experiment.list_recorders() + recs_flt = {} + for rid, rec in recs.items(): + if rec_filter_func is None or rec_filter_func(rec): + recs_flt[rid] = rec + + return recs_flt + + +class TimeAdjuster: + """ + Find appropriate date and adjust date. + """ + + def __init__(self, future=True, end_time=None): + self._future = future + self.cals = D.calendar(future=future, end_time=end_time) + + def set_end_time(self, end_time=None): + """ + Set end time. None for use calendar's end time. + + Args: + end_time + """ + self.cals = D.calendar(future=self._future, end_time=end_time) + + def get(self, idx: int): + """ + Get datetime by index. + + Parameters + ---------- + idx : int + index of the calendar + """ + if idx >= len(self.cals): + return None + return self.cals[idx] + + def max(self) -> pd.Timestamp: + """ + Return the max calendar datetime + """ + return max(self.cals) + + def align_idx(self, time_point, tp_type="start") -> int: + """ + Align the index of time_point in the calendar. + + Parameters + ---------- + time_point + tp_type : str + + Returns + ------- + index : int + """ + time_point = pd.Timestamp(time_point) + if tp_type == "start": + idx = bisect.bisect_left(self.cals, time_point) + elif tp_type == "end": + idx = bisect.bisect_right(self.cals, time_point) - 1 + else: + raise NotImplementedError(f"This type of input is not supported") + return idx + + def cal_interval(self, time_point_A, time_point_B) -> int: + """ + Calculate the trading day interval (time_point_A - time_point_B) + + Args: + time_point_A : time_point_A + time_point_B : time_point_B (is the past of time_point_A) + + Returns: + int: the interval between A and B + """ + return self.align_idx(time_point_A) - self.align_idx(time_point_B) + + def align_time(self, time_point, tp_type="start") -> pd.Timestamp: + """ + Align time_point to trade date of calendar + + Args: + time_point + Time point + tp_type : str + time point type (`"start"`, `"end"`) + + Returns: + pd.Timestamp + """ + return self.cals[self.align_idx(time_point, tp_type=tp_type)] + + def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]: + """ + Align the given date to the trade date + + for example: + + .. code-block:: python + + input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')} + + output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), + 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), + 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))} + + Parameters + ---------- + segment + + Returns + ------- + Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date. 
+ """ + if isinstance(segment, dict): + return {k: self.align_seg(seg) for k, seg in segment.items()} + elif isinstance(segment, tuple) or isinstance(segment, list): + return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end") + else: + raise NotImplementedError(f"This type of input is not supported") + + def truncate(self, segment: tuple, test_start, days: int) -> tuple: + """ + Truncate the segment based on the test_start date + + Parameters + ---------- + segment : tuple + time segment + test_start + days : int + The trading days to be truncated + the data in this segment may need 'days' data + + Returns + --------- + tuple: new segment + """ + test_idx = self.align_idx(test_start) + if isinstance(segment, tuple): + new_seg = [] + for time_point in segment: + tp_idx = min(self.align_idx(time_point), test_idx - days) + assert tp_idx > 0 + new_seg.append(self.get(tp_idx)) + return tuple(new_seg) + else: + raise NotImplementedError(f"This type of input is not supported") + + SHIFT_SD = "sliding" + SHIFT_EX = "expanding" + + def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: + """ + Shift the datatime of segment + + Parameters + ---------- + seg : + datetime segment + step : int + rolling step + rtype : str + rolling type ("sliding" or "expanding") + + Returns + -------- + tuple: new segment + + Raises + ------ + KeyError: + shift will raise error if the index(both start and end) is out of self.cal + """ + if isinstance(seg, tuple): + start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end") + if rtype == self.SHIFT_SD: + start_idx += step + end_idx += step + elif rtype == self.SHIFT_EX: + end_idx += step + else: + raise NotImplementedError(f"This type of input is not supported") + if start_idx > len(self.cals): + raise KeyError("The segment is out of valid calendar") + return self.get(start_idx), self.get(end_idx) + else: + raise NotImplementedError(f"This type of input is not supported") diff --git a/setup.py b/setup.py index 747d885f4f..92c9ccc0cc 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,9 @@ "tornado", "joblib>=0.17.0", "ruamel.yaml>=0.16.12", + "pymongo==3.7.2", # For task management "scikit-learn>=0.22", + "dill", ] # Numpy include