Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine RL todos #1332

Merged
merged 12 commits into from
Nov 10, 2022
11 changes: 5 additions & 6 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import copy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
from typing import Dict, TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union

import pandas as pd

Expand Down Expand Up @@ -223,7 +223,7 @@ def backtest(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[Dict[str, Tuple[pd.DataFrame, dict]], Dict[str, Tuple[pd.DataFrame, Indicator]]]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution

Expand Down Expand Up @@ -256,9 +256,9 @@ def backtest(

Returns
-------
portfolio_metrics_dict: Dict[PortfolioMetrics]
portfolio_dict: Dict[str, Tuple[pd.DataFrame, dict]]
it records the trading portfolio_metrics information
indicator_dict: Dict[Indicator]
indicator_dict: Dict[str, Tuple[pd.DataFrame, Indicator]]
it computes the trading indicator
It is organized in a dict format

Expand All @@ -273,8 +273,7 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return portfolio_metrics, indicator
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)


def collect_data(
Expand Down
42 changes: 23 additions & 19 deletions qlib/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast

import pandas as pd

from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.report import Indicator, PortfolioMetrics
from qlib.backtest.report import Indicator

if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
Expand All @@ -24,25 +24,26 @@ def backtest_loop(
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[Dict[str, Tuple[pd.DataFrame, dict]], Dict[str, Tuple[pd.DataFrame, Indicator]]]:
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution

please refer to the docs of `collect_data_loop`

Returns
-------
portfolio_metrics: PortfolioMetrics
portfolio_dict: Dict[str, Tuple[pd.DataFrame, dict]]
it records the trading portfolio_metrics information
indicator: Indicator
indicator_dict: Dict[str, Tuple[pd.DataFrame, Indicator]]
it computes the trading indicator
"""
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass

portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
portfolio_dict = cast(Dict[str, Tuple[pd.DataFrame, dict]], return_value.get("portfolio_dict"))
indicator_dict = cast(Dict[str, Tuple[pd.DataFrame, Indicator]], return_value.get("indicator_dict"))

return portfolio_dict, indicator_dict


def collect_data_loop(
Expand Down Expand Up @@ -89,14 +90,17 @@ def collect_data_loop(

if return_value is not None:
all_executors = trade_executor.get_all_executors()
all_portfolio_metrics = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})

portfolio_dict: Dict[str, Tuple[pd.DataFrame, dict]] = {}
indicator_dict: Dict[str, Tuple[pd.DataFrame, Indicator]] = {}

for executor in all_executors:
key = "{}{}".format(*Freq.parse(executor.time_per_step))
if executor.trade_account.is_port_metr_enabled():
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()

indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
indicator_obj = executor.trade_account.get_trade_indicator()
indicator_dict[key] = (indicator_df, indicator_obj)

return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})
75 changes: 31 additions & 44 deletions qlib/rl/contrib/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,21 @@
import pickle
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union, cast

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from qlib.typehint import Literal
from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest import Indicator, collect_data_loop, get_strategy_executor
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.backtest.executor import SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper


def _get_multi_level_executor_config(
Expand Down Expand Up @@ -61,15 +59,6 @@ def _get_multi_level_executor_config(
return executor_config


def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
if isinstance(executor, NestedExecutor):
if hasattr(executor.inner_strategy, "set_env"):
env = CollectDataEnvWrapper()
env.reset()
executor.inner_strategy.set_env(env)
_set_env_for_all_strategy(executor.inner_executor)


def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
record_list = []
for time, value_dict in indicator.items():
Expand All @@ -94,9 +83,10 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
return records


# TODO: there should be richer annotation for the input (e.g. report) and the returned report
# TODO: For example, @ dataclass with typed fields and detailed docstrings.
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
def _generate_report(
decisions: List[BaseTradeDecision],
report_indicators: List[Dict[str, Tuple[pd.DataFrame, Indicator]]],
) -> Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Generate backtest reports

Parameters
Expand All @@ -109,28 +99,25 @@ def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List
-------

"""
indicator_dict = defaultdict(list)
indicator_his = defaultdict(list)
indicator_dict: Dict[str, List[pd.DataFrame]] = defaultdict(list)
indicator_his: Dict[str, List[dict]] = defaultdict(list)

for report_indicator in report_indicators:
for key, value in report_indicator.items():
if key.endswith("_obj"):
indicator_his[key].append(value.order_indicator_his)
else:
indicator_dict[key].append(value)
for key, (indicator_df, indicator_obj) in report_indicator.items():
indicator_dict[key].append(indicator_df)
indicator_his[key].append(indicator_obj.order_indicator_his)

report = {}
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
for key in ["1min", "5min", "30min", "1day"]:
if key not in indicator_dict:
continue

report[key] = pd.concat(indicator_dict[key])
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])

for key in indicator_dict:
cur_dict = pd.concat(indicator_dict[key])
cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]])
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
if len(cur_details) > 0:
cur_details.pop("freq")
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
cur_his = cur_his.join(cur_details, how="outer")

report[key] = (cur_dict, cur_his)

return report

Expand Down Expand Up @@ -209,25 +196,25 @@ def single_with_simulator(
exchange_config=exchange_config,
qlib_config=None,
cash_limit=None,
backtest_mode=True,
)

reports.append(simulator.report_dict)
decisions += simulator.decisions

indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator)
indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()

if generate_report:
report = _generate_report(decisions, [report["indicator"] for report in reports])
_report = _generate_report(decisions, [report["indicator"] for report in reports])

if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: report}
report = {day: _report}

return records, report
else:
Expand Down Expand Up @@ -312,22 +299,22 @@ def single_with_collect_data_loop(
exchange_kwargs=exchange_config,
pos_type="Position" if cash_limit is not None else "InfPosition",
)
_set_env_for_all_strategy(executor=executor)

report_dict: dict = {}
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))

records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
indicator_dict = cast(Dict[str, Tuple[pd.DataFrame, Indicator]], report_dict.get("indicator_dict"))
records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his)
assert records is None or not np.isnan(records["ffr"]).any()

if generate_report:
report = _generate_report(decisions, [report_dict["indicator"]])
_report = _generate_report(decisions, [indicator_dict])
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
report = {stock_id: _report}
else:
day = orders.iloc[0].datetime
report = {day: report}
report = {day: _report}
return records, report
else:
return records
Expand All @@ -337,7 +324,7 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram
order_df = read_order_file(backtest_config["order_file"])

cash_limit = backtest_config["exchange"].pop("cash_limit")
generate_report = backtest_config["exchange"].pop("generate_report")
generate_report = backtest_config.pop("generate_report")

stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()
Expand Down
8 changes: 5 additions & 3 deletions qlib/rl/contrib/naive_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import yaml


DELETE_KEY = "_delete_"


def merge_a_into_b(a: dict, b: dict) -> dict:
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b:
v.pop("_delete_", False) # TODO: make this more elegant
v.pop(DELETE_KEY, False)
b[k] = merge_a_into_b(v, b[k])
else:
b[k] = v
Expand Down Expand Up @@ -86,7 +89,6 @@ def get_backtest_config_fromfile(path: str) -> dict:
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
"generate_report": False,
}
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
Expand All @@ -97,7 +99,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
# "runtime": {},
"generate_report": False,
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

Expand Down
23 changes: 12 additions & 11 deletions qlib/rl/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@

from __future__ import annotations

from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
from typing import Any, Generic, TypeVar

import gym
import numpy as np
from gym import spaces

from qlib.typehint import final
from .simulator import ActType, StateType

if TYPE_CHECKING:
from .utils.env_wrapper import BaseEnvWrapper

import gym
from gym import spaces

ObsType = TypeVar("ObsType")
PolicyActType = TypeVar("PolicyActType")

Expand All @@ -35,12 +31,19 @@ class Interpreter:
states by calling ``self.env.register_state()``, but it's not planned for first iteration.
"""

def __init__(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is weird to maintain these states in the Interpreter;
we may have a discussion later about whether there is a better design.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is more reasonable to get the current step from the state.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I think it's better to have them in EnvWrapperStatus because it's designed for this purpose.
But it would require adding DummyEnv back.

self.cur_step = 0

def reset(self) -> None:
self.cur_step = 0

def step(self) -> None:
self.cur_step += 1


class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""

env: Optional[BaseEnvWrapper] = None

@property
def observation_space(self) -> gym.Space:
raise NotImplementedError()
Expand Down Expand Up @@ -73,8 +76,6 @@ def interpret(self, simulator_state: StateType) -> ObsType:
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""

env: Optional[BaseEnvWrapper] = None

@property
def action_space(self) -> gym.Space:
raise NotImplementedError()
Expand Down
Loading