diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index aca45f4e1e..ec0725230a 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -10,7 +10,6 @@ import pandas as pd from .account import Account -from .report import Indicator, PortfolioMetrics if TYPE_CHECKING: from ..strategy.base import BaseStrategy @@ -20,7 +19,7 @@ from ..config import C from ..log import get_module_logger from ..utils import init_instance_by_config -from .backtest import backtest_loop, collect_data_loop +from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop from .decision import Order from .exchange import Exchange from .utils import CommonInfrastructure @@ -223,7 +222,7 @@ def backtest( account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", -) -> Tuple[PortfolioMetrics, Indicator]: +) -> Tuple[PORT_METRIC, INDICATOR_METRIC]: """initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution @@ -256,9 +255,9 @@ def backtest( Returns ------- - portfolio_metrics_dict: Dict[PortfolioMetrics] + portfolio_dict: PORT_METRIC it records the trading portfolio_metrics information - indicator_dict: Dict[Indicator] + indicator_dict: INDICATOR_METRIC it computes the trading indicator It is organized in a dict format @@ -273,8 +272,7 @@ def backtest( exchange_kwargs, pos_type=pos_type, ) - portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor) - return portfolio_metrics, indicator + return backtest_loop(start_time, end_time, trade_strategy, trade_executor) def collect_data( diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index f79622bff6..cf0a3a5786 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -3,12 +3,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast 
+from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast import pandas as pd from qlib.backtest.decision import BaseTradeDecision -from qlib.backtest.report import Indicator, PortfolioMetrics +from qlib.backtest.report import Indicator if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy @@ -19,30 +19,35 @@ from ..utils.time import Freq +PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]] +INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]] + + def backtest_loop( start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], trade_strategy: BaseStrategy, trade_executor: BaseExecutor, -) -> Tuple[PortfolioMetrics, Indicator]: +) -> Tuple[PORT_METRIC, INDICATOR_METRIC]: """backtest function for the interaction of the outermost strategy and executor in the nested decision execution please refer to the docs of `collect_data_loop` Returns ------- - portfolio_metrics: PortfolioMetrics + portfolio_dict: PORT_METRIC it records the trading portfolio_metrics information - indicator: Indicator + indicator_dict: INDICATOR_METRIC it computes the trading indicator """ return_value: dict = {} for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): pass - portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics")) - indicator = cast(Indicator, return_value.get("indicator")) - return portfolio_metrics, indicator + portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict")) + indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict")) + + return portfolio_dict, indicator_dict def collect_data_loop( @@ -89,14 +94,17 @@ def collect_data_loop( if return_value is not None: all_executors = trade_executor.get_all_executors() - all_portfolio_metrics = { - "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics() - for _executor in all_executors - if _executor.trade_account.is_port_metr_enabled() - } - 
all_indicators = {} - for _executor in all_executors: - key = "{}{}".format(*Freq.parse(_executor.time_per_step)) - all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() - all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator() - return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators}) + + portfolio_dict: PORT_METRIC = {} + indicator_dict: INDICATOR_METRIC = {} + + for executor in all_executors: + key = "{}{}".format(*Freq.parse(executor.time_per_step)) + if executor.trade_account.is_port_metr_enabled(): + portfolio_dict[key] = executor.trade_account.get_portfolio_metrics() + + indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + indicator_obj = executor.trade_account.get_trade_indicator() + indicator_dict[key] = (indicator_df, indicator_obj) + + return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict}) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 16cd8815f9..cc760be44d 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -26,6 +26,15 @@ class Exchange: + # `quote_df` is a pd.DataFrame class that contains basic information for backtesting + # After some processing, the data will later be maintained by `quote_cls` object for faster data retrieving. + # Some conventions for `quote_df` + # - $close is for calculating the total value at end of each day. + # - if $close is None, the stock on that day is regarded as suspended. + # - $factor is for rounding to the trading unit; + # - if any $factor is missing when $close exists, trading unit rounding will be disabled + quote_df: pd.DataFrame + def __init__( self, freq: str = "day", @@ -159,6 +168,7 @@ def __init__( self.codes = codes # Necessary fields # $close is for calculating the total value at end of each day. + # - if $close is None, the stock on that day is regarded as suspended. 
# $factor is for rounding to the trading unit # $change is for calculating the limit of the stock @@ -199,7 +209,7 @@ def get_quote_from_qlib(self) -> None: self.end_time, freq=self.freq, disk_cache=True, - ).dropna(subset=["$close"]) + ) self.quote_df.columns = self.all_fields # check buy_price data and sell_price data @@ -209,7 +219,7 @@ def get_quote_from_qlib(self) -> None: self.logger.warning("{} field data contains nan.".format(pstr)) # update trade_w_adj_price - if self.quote_df["$factor"].isna().any(): + if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any(): # The 'factor.day.bin' file not exists, and `factor` field contains `nan` # Use adjusted price self.trade_w_adj_price = True @@ -245,9 +255,9 @@ def get_quote_from_qlib(self) -> None: assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"} self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0) - LT_TP_EXP = "(exp)" # Tuple[str, str] - LT_FLT = "float" # float - LT_NONE = "none" # none + LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression. 
+ LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold` + LT_NONE = "none" # none: there is no trading limitation def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str: """get limit type""" @@ -261,20 +271,25 @@ def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str: raise NotImplementedError(f"This type of `limit_threshold` is not supported") def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None: + # $close may contain NaN; a NaN indicates that the stock is not tradable at that timestamp + suspended = self.quote_df["$close"].isna() # check limit_threshold limit_type = self._get_limit_type(limit_threshold) if limit_type == self.LT_NONE: - self.quote_df["limit_buy"] = False - self.quote_df["limit_sell"] = False + self.quote_df["limit_buy"] = suspended + self.quote_df["limit_sell"] = suspended elif limit_type == self.LT_TP_EXP: # set limit limit_threshold = cast(tuple, limit_threshold) - self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]] - self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]] + # astype bool is necessary, because quote_df is an expression and could be float + self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended + self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended elif limit_type == self.LT_FLT: limit_threshold = cast(float, limit_threshold) - self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) - self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130 + self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended + self.quote_df["limit_sell"] = ( + self.quote_df["$change"].le(-limit_threshold) | suspended + ) # pylint: disable=E1130 @staticmethod def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]: 
@@ -338,8 +353,18 @@ def check_stock_limit( - if direction is None, check if tradable for buying and selling. - if direction == Order.BUY, check the if tradable for buying - if direction == Order.SELL, check the sell limit for selling. + + Returns + ------- + True: the trading of the stock is limited (maybe hit the highest/lowest price), hence the stock is not tradable + False: the trading of the stock is not limited, hence the stock may be tradable """ + # NOTE: + # **all** is used when checking limitation. + # For example, the stock trading is limited in a day if every minute in that day is limited. if direction is None: + # The trading limitation is related to the trading direction + # if the direction is not provided, then any limitation from buy or sell will result in trading limitation buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") return bool(buy_limit or sell_limit) @@ -356,10 +381,24 @@ def check_stock_suspended( start_time: pd.Timestamp, end_time: pd.Timestamp, ) -> bool: + """if stock is suspended(hence not tradable), True will be returned""" # is suspended if stock_id in self.quote.get_all_stock(): - return self.quote.get_data(stock_id, start_time, end_time, "$close") is None + # suspended stocks are represented by None $close stock + # The $close may contain NaN, + close = self.quote.get_data(stock_id, start_time, end_time, "$close") + if close is None: + # if no close record exists + return True + elif isinstance(close, IndexData): + # **any** non-NaN $close means a trading opportunity may exist + # if all returned is nan, then the stock is suspended + return cast(bool, cast(IndexData, close).isna().all()) + else: + # it is a single value, make sure it is not NaN + return np.isnan(close) else: + # if the stock is not in the stock list, then it is not tradable and regarded as 
suspended return True def is_stock_tradable( diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 695c13d2ed..4d1eae46db 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -8,23 +8,22 @@ import pickle from collections import defaultdict from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, cast import numpy as np import pandas as pd import torch from joblib import Parallel, delayed -from qlib.typehint import Literal -from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest import INDICATOR_METRIC, collect_data_loop, get_strategy_executor from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime -from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor +from qlib.backtest.executor import SimulatorExecutor from qlib.backtest.high_performance_ds import BaseOrderIndicator from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile from qlib.rl.contrib.utils import read_order_file from qlib.rl.data.integration import init_qlib from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution -from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper +from qlib.typehint import Literal def _get_multi_level_executor_config( @@ -61,15 +60,6 @@ def _get_multi_level_executor_config( return executor_config -def _set_env_for_all_strategy(executor: BaseExecutor) -> None: - if isinstance(executor, NestedExecutor): - if hasattr(executor.inner_strategy, "set_env"): - env = CollectDataEnvWrapper() - env.reset() - executor.inner_strategy.set_env(env) - _set_env_for_all_strategy(executor.inner_executor) - - def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: record_list = [] for time, value_dict in indicator.items(): @@ -94,9 +84,10 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: 
return records -# TODO: there should be richer annotation for the input (e.g. report) and the returned report -# TODO: For example, @ dataclass with typed fields and detailed docstrings. -def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict: +def _generate_report( + decisions: List[BaseTradeDecision], + report_indicators: List[INDICATOR_METRIC], +) -> Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]: """Generate backtest reports Parameters @@ -109,28 +100,25 @@ def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List ------- """ - indicator_dict = defaultdict(list) - indicator_his = defaultdict(list) + indicator_dict: Dict[str, List[pd.DataFrame]] = defaultdict(list) + indicator_his: Dict[str, List[dict]] = defaultdict(list) + for report_indicator in report_indicators: - for key, value in report_indicator.items(): - if key.endswith("_obj"): - indicator_his[key].append(value.order_indicator_his) - else: - indicator_dict[key].append(value) + for key, (indicator_df, indicator_obj) in report_indicator.items(): + indicator_dict[key].append(indicator_df) + indicator_his[key].append(indicator_obj.order_indicator_his) report = {} decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")]) - for key in ["1min", "5min", "30min", "1day"]: - if key not in indicator_dict: - continue - - report[key] = pd.concat(indicator_dict[key]) - report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]]) - + for key in indicator_dict: + cur_dict = pd.concat(indicator_dict[key]) + cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]]) cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"]) if len(cur_details) > 0: cur_details.pop("freq") - report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") + cur_his = cur_his.join(cur_details, how="outer") + 
+ report[key] = (cur_dict, cur_his) return report @@ -209,25 +197,25 @@ def single_with_simulator( exchange_config=exchange_config, qlib_config=None, cash_limit=None, - backtest_mode=True, ) reports.append(simulator.report_dict) decisions += simulator.decisions - indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()} - records = _convert_indicator_to_dataframe(indicator) + indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports] + indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()} + records = _convert_indicator_to_dataframe(indicator_info) assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, [report["indicator"] for report in reports]) + _report = _generate_report(decisions, [report["indicator"] for report in reports]) if split == "stock": stock_id = orders.iloc[0].instrument - report = {stock_id: report} + report = {stock_id: _report} else: day = orders.iloc[0].datetime - report = {day: report} + report = {day: _report} return records, report else: @@ -312,22 +300,22 @@ def single_with_collect_data_loop( exchange_kwargs=exchange_config, pos_type="Position" if cash_limit is not None else "InfPosition", ) - _set_env_for_all_strategy(executor=executor) report_dict: dict = {} decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict)) - records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his) + indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict")) + records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his) assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - report = _generate_report(decisions, [report_dict["indicator"]]) + _report = _generate_report(decisions, [indicator_dict]) if split == "stock": stock_id = 
orders.iloc[0].instrument - report = {stock_id: report} + report = {stock_id: _report} else: day = orders.iloc[0].datetime - report = {day: report} + report = {day: _report} return records, report else: return records @@ -337,7 +325,7 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram order_df = read_order_file(backtest_config["order_file"]) cash_limit = backtest_config["exchange"].pop("cash_limit") - generate_report = backtest_config["exchange"].pop("generate_report") + generate_report = backtest_config.pop("generate_report") stock_pool = order_df["instrument"].unique().tolist() stock_pool.sort() @@ -382,9 +370,19 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") + parser.add_argument( + "--n_jobs", + type=int, + required=False, + help="The number of jobs for running backtest parallely(1 for single process)", + ) args = parser.parse_args() + config = get_backtest_config_fromfile(args.config_path) + if args.n_jobs is not None: + config["concurrency"] = args.n_jobs + backtest( - backtest_config=get_backtest_config_fromfile(args.config_path), + backtest_config=config, with_simulator=args.use_simulator, ) diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index 3f3d2eeadc..ab5e953596 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -11,11 +11,14 @@ import yaml +DELETE_KEY = "_delete_" + + def merge_a_into_b(a: dict, b: dict) -> dict: b = b.copy() for k, v in a.items(): if isinstance(v, dict) and k in b: - v.pop("_delete_", False) # TODO: make this more elegant + v.pop(DELETE_KEY, False) b[k] = merge_a_into_b(v, b[k]) else: b[k] = v @@ -86,7 +89,6 @@ def 
get_backtest_config_fromfile(path: str) -> dict: "min_cost": 5.0, "trade_unit": 100.0, "cash_limit": None, - "generate_report": False, } backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default) backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"]) @@ -97,7 +99,7 @@ def get_backtest_config_fromfile(path: str) -> dict: "concurrency": -1, "multiplier": 1.0, "output_dir": "outputs/", - # "runtime": {}, + "generate_report": False, } backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index 9417534f86..f09d909bc8 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -13,7 +13,6 @@ from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider from .integration import fetch_features -from ...data import D class IntradayBacktestData(BaseIntradayBacktestData): @@ -81,17 +80,7 @@ def load_backtest_data( trade_exchange: Exchange, trade_range: TradeRange, ) -> IntradayBacktestData: - # TODO: making exchange return data without missing will make it more elegant. Fix this in the future. 
- tmp_data = D.features( - trade_exchange.codes, - trade_exchange.all_fields, - trade_exchange.start_time, - trade_exchange.end_time, - freq=trade_exchange.freq, - disk_cache=True, - ) - - ticks_index = pd.DatetimeIndex(tmp_data.reset_index()["datetime"]) + ticks_index = pd.DatetimeIndex(trade_exchange.quote_df.reset_index()["datetime"]) ticks_index = ticks_index[order.start_time <= ticks_index] ticks_index = ticks_index[ticks_index <= order.end_time] diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index d2d81f81cd..5c9cc26c4e 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,19 +3,15 @@ from __future__ import annotations -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, TypeVar +import gym import numpy as np +from gym import spaces from qlib.typehint import final from .simulator import ActType, StateType -if TYPE_CHECKING: - from .utils.env_wrapper import BaseEnvWrapper - -import gym -from gym import spaces - ObsType = TypeVar("ObsType") PolicyActType = TypeVar("PolicyActType") @@ -39,8 +35,6 @@ class Interpreter: class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - env: Optional[BaseEnvWrapper] = None - @property def observation_space(self) -> gym.Space: raise NotImplementedError() @@ -73,8 +67,6 @@ def interpret(self, simulator_state: StateType) -> ObsType: class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - env: Optional[BaseEnvWrapper] = None - @property def action_space(self) -> gym.Space: raise NotImplementedError() diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 0b89977491..0d45624bda 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -69,8 +69,6 @@ class 
FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): Provider of the processed data. """ - # TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case. - # TODO: So it should be redesigned after the data interface is well-designed. def __init__( self, max_step: int, @@ -78,6 +76,8 @@ def __init__( data_dim: int, processed_data_provider: dict | ProcessedDataProvider, ) -> None: + super().__init__() + self.max_step = max_step self.data_ticks = data_ticks self.data_dim = data_dim @@ -87,10 +87,6 @@ def __init__( ) def interpret(self, state: SAOEState) -> FullHistoryObs: - # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running - # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant - # way to decompose interpreter and EnvWrapper in the future. - processed = self.processed_data_provider.get_data( stock_id=state.order.stock_id, date=pd.Timestamp(state.order.start_time.date()), @@ -102,8 +98,6 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: position_history[0] = state.order.amount position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() - assert self.env is not None - # The min, slice here are to make sure that indices fit into the range, # even after the final step of the simulator (in the done step), # to make network in policy happy. 
@@ -115,7 +109,7 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: "data_processed_prev": np.array(processed.yesterday), "acquiring": _to_int32(state.order.direction == state.order.BUY), "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), - "cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)), + "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)), "num_step": _to_int32(self.max_step), "target": _to_float32(state.order.amount), "position": _to_float32(state.position), @@ -163,6 +157,8 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]): """ def __init__(self, max_step: int) -> None: + super().__init__() + self.max_step = max_step @property @@ -177,15 +173,10 @@ def observation_space(self) -> spaces.Dict: return spaces.Dict(space) def interpret(self, state: SAOEState) -> CurrentStateObs: - # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running - # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant - # way to decompose interpreter and EnvWrapper in the future. 
- - assert self.env is not None - assert self.env.status["cur_step"] <= self.max_step + assert state.cur_step <= self.max_step obs = CurrentStateObs( acquiring=state.order.direction == state.order.BUY, - cur_step=self.env.status["cur_step"], + cur_step=state.cur_step, num_step=self.max_step, target=state.order.amount, position=state.position, @@ -208,6 +199,8 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): """ def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None: + super().__init__() + if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values @@ -218,13 +211,8 @@ def action_space(self) -> spaces.Discrete: return spaces.Discrete(len(self.action_values)) def interpret(self, state: SAOEState, action: int) -> float: - # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running - # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant - # way to decompose interpreter and EnvWrapper in the future. - assert 0 <= action < len(self.action_values) - assert self.env is not None - if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1: + if self.max_step is not None and state.cur_step >= self.max_step - 1: return state.position else: return min(state.position, state.order.amount * self.action_values[action]) @@ -244,13 +232,8 @@ def action_space(self) -> spaces.Box: return spaces.Box(0, np.inf, shape=(), dtype=np.float32) def interpret(self, state: SAOEState, action: float) -> float: - # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running - # backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant - # way to decompose interpreter and EnvWrapper in the future. 
- - assert self.env is not None estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step) - twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"]) + twap_volume = state.position / (estimated_total_steps - state.cur_step) return min(state.position, twap_volume * action) diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py index 7f7a98e9a7..598e6b589a 100644 --- a/qlib/rl/order_execution/policy.py +++ b/qlib/rl/order_execution/policy.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast +from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast import gym import numpy as np @@ -14,6 +14,8 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.policy import BasePolicy, PPOPolicy +from qlib.rl.trainer.trainer import Trainer + __all__ = ["AllOne", "PPO"] @@ -148,7 +150,7 @@ def __init__( action_space=action_space, ) if weight_file is not None: - load_weight(self, weight_file) + set_weight(self, Trainer.get_policy_state_dict(weight_file)) # utilities: these should be put in a separate (common) file. # @@ -160,15 +162,7 @@ def auto_device(module: nn.Module) -> torch.device: return torch.device("cpu") # fallback to cpu -def load_weight(policy: nn.Module, path: Path) -> None: - assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight." - loaded_weight = torch.load(path, map_location="cpu") - - # TODO: this should be handled by whoever calls load_weight. - # TODO: For example, when the outer class receives a weight, it should first unpack it, - # TODO: and send the corresponding part to individual component. 
- if "vessel" in loaded_weight: - loaded_weight = loaded_weight["vessel"]["policy"] +def set_weight(policy: nn.Module, loaded_weight: OrderedDict) -> None: try: policy.load_state_dict(loaded_weight) except RuntimeError: diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index c9702b1e48..610a0c0bd5 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -9,12 +9,11 @@ from qlib.backtest import collect_data_loop, get_strategy_executor from qlib.backtest.decision import BaseTradeDecision, Order, TradeRangeByTime -from qlib.backtest.executor import BaseExecutor, NestedExecutor +from qlib.backtest.executor import NestedExecutor from qlib.rl.data.integration import init_qlib from qlib.rl.simulator import Simulator -from .state import SAOEState, SAOEStateAdapter -from .strategy import SAOEStrategy -from ..utils.env_wrapper import CollectDataEnvWrapper +from .state import SAOEState +from .strategy import SAOEStateAdapter, SAOEStrategy class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): @@ -32,8 +31,6 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): Configuration used to initialize Qlib. If it is None, Qlib will not be initialized. cash_limit: Cash limit. - backtest_mode - Whether the simulator is under backtest mode. 
""" def __init__( @@ -43,7 +40,6 @@ def __init__( exchange_config: dict, qlib_config: dict = None, cash_limit: Optional[float] = None, - backtest_mode: bool = False, ) -> None: super().__init__(initial=order) @@ -59,7 +55,7 @@ def __init__( } self._collect_data_loop: Optional[Generator] = None - self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit, backtest_mode) + self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit) def reset( self, @@ -69,7 +65,6 @@ def reset( exchange_config: dict, qlib_config: dict = None, cash_limit: Optional[float] = None, - backtest_mode: bool = False, ) -> None: if qlib_config is not None: init_qlib(qlib_config, part="skip") @@ -98,16 +93,6 @@ def reset( ) assert isinstance(self._collect_data_loop, Generator) - # TODO: backtest_mode is not a necessary parameter if we carefully design it. - # TODO: It should disappear with CollectDataEnvWrapper in the future. - if backtest_mode: - executor: BaseExecutor = self._executor - while isinstance(executor, NestedExecutor): - if hasattr(executor.inner_strategy, "set_env"): - executor.inner_strategy.set_env(CollectDataEnvWrapper()) - executor = executor.inner_executor - - # Call `step()` with None action to initialize the internal generator. self.step(action=None) self._order = order diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 17efb4b093..9086e6047c 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -16,8 +16,6 @@ from .state import SAOEMetrics, SAOEState -# TODO: Integrating Qlib's native data with simulator_simple - __all__ = ["SingleAssetOrderExecutionSimple"] @@ -98,6 +96,7 @@ def __init__( self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time) self.cur_time = self.ticks_for_order[0] + self.cur_step = 0 # NOTE: astype(float) is necessary in some systems. 
# this will align the precision with `.to_numpy()` in `_split_exec_vol` self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean()) @@ -194,11 +193,13 @@ def step(self, amount: float) -> None: self.env.logger.add_any(key, value) self.cur_time = self._next_time() + self.cur_step += 1 def get_state(self) -> SAOEState: return SAOEState( order=self.order, cur_time=self.cur_time, + cur_step=self.cur_step, position=self.position, history_exec=self.history_exec, history_steps=self.history_steps, diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index f417173e52..315735eaf8 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -4,290 +4,15 @@ from __future__ import annotations import typing -from typing import cast, Callable, List, NamedTuple, Optional, Tuple +from typing import NamedTuple, Optional import numpy as np import pandas as pd -from qlib.backtest import Exchange, Order -from qlib.backtest.executor import BaseExecutor -from qlib.constant import EPS, ONE_MIN, REG_CN -from qlib.rl.order_execution.utils import dataframe_append, price_advantage +from qlib.backtest import Order from qlib.typehint import TypedDict -from qlib.utils.index_data import IndexData -from qlib.utils.time import get_day_min_idx_range if typing.TYPE_CHECKING: from qlib.rl.data.base import BaseIntradayBacktestData - from qlib.rl.data.native import IntradayBacktestData - - -def _get_all_timestamps( - start: pd.Timestamp, - end: pd.Timestamp, - granularity: pd.Timedelta = ONE_MIN, - include_end: bool = True, -) -> pd.DatetimeIndex: - ret = [] - while start <= end: - ret.append(start) - start += granularity - - if ret[-1] > end: - ret.pop() - if ret[-1] == end and not include_end: - ret.pop() - return pd.DatetimeIndex(ret) - - -def fill_missing_data( - original_data: np.ndarray, - total_time_list: List[pd.Timestamp], - found_time_list: List[pd.Timestamp], - fill_method: Callable = 
np.median, -) -> np.ndarray: - """Fill missing data. We need this function to deal with data that have missing values in some minutes. - - TODO: making exchange return data without missing will make it more elegant. Fix this in the future. - - Parameters - ---------- - original_data - Original data without missing values. - total_time_list - All timestamps that required. - found_time_list - Timestamps found in the original data. - fill_method - Method used to fill the missing data. - - Returns - ------- - The filled data. - """ - assert len(original_data) == len(found_time_list) - tmp = dict(zip(found_time_list, original_data)) - fill_val = fill_method(original_data) - return np.array([tmp.get(t, fill_val) for t in total_time_list]) - - -class SAOEStateAdapter: - """ - Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state - according to the execution results with additional information acquired from executors & exchange. For example, - it gets the dealt order amount from execution results, and get the corresponding market price / volume from - exchange. - - Example usage:: - - adapter = SAOEStateAdapter(...) - adapter.update(...) 
- state = adapter.saoe_state - """ - - def __init__( - self, - order: Order, - executor: BaseExecutor, - exchange: Exchange, - ticks_per_step: int, - backtest_data: IntradayBacktestData, - ) -> None: - self.position = order.amount - self.order = order - self.executor = executor - self.exchange = exchange - self.backtest_data = backtest_data - - self.twap_price = self.backtest_data.get_deal_price().mean() - - metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member - self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") - self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") - self.metrics: Optional[SAOEMetrics] = None - - self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time) - self.ticks_per_step = ticks_per_step - - def _next_time(self) -> pd.Timestamp: - current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time) - next_loc = current_loc + self.ticks_per_step - next_loc = next_loc - next_loc % self.ticks_per_step - if ( - next_loc < len(self.backtest_data.ticks_index) - and self.backtest_data.ticks_index[next_loc] < self.order.end_time - ): - return self.backtest_data.ticks_index[next_loc] - else: - return self.order.end_time - - def update( - self, - execute_result: list, - last_step_range: Tuple[int, int], - ) -> None: - last_step_size = last_step_range[1] - last_step_range[0] + 1 - start_time = self.backtest_data.ticks_index[last_step_range[0]] - end_time = self.backtest_data.ticks_index[last_step_range[1]] - - exec_vol = np.zeros(last_step_size) - for order, _, __, ___ in execute_result: - idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN) - exec_vol[idx - last_step_range[0]] = order.deal_amount - - if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: - assert exec_vol.sum() < self.position + 1, f"{exec_vol} too large" - exec_vol *= self.position / (exec_vol.sum()) - - market_volume = cast( - IndexData, - 
self.exchange.get_volume( - self.order.stock_id, - pd.Timestamp(start_time), - pd.Timestamp(end_time), - method=None, - ), - ) - market_price = cast( - IndexData, - self.exchange.get_deal_price( - self.order.stock_id, - pd.Timestamp(start_time), - pd.Timestamp(end_time), - method=None, - direction=self.order.direction, - ), - ) - found_time_list = [pd.Timestamp(e) for e in list(market_volume.index)] - total_time_list = _get_all_timestamps(start_time, end_time) - market_price = fill_missing_data(np.array(market_price).reshape(-1), total_time_list, found_time_list) - market_volume = fill_missing_data(np.array(market_volume).reshape(-1), total_time_list, found_time_list) - - assert market_price.shape == market_volume.shape == exec_vol.shape - - # Get data from the current level executor's indicator - current_trade_account = self.executor.trade_account - current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe() - self.history_exec = dataframe_append( - self.history_exec, - self._collect_multi_order_metric( - order=self.order, - datetime=_get_all_timestamps(start_time, end_time, include_end=True), - market_vol=market_volume, - market_price=market_price, - exec_vol=exec_vol, - pa=current_df.iloc[-1]["pa"], - ), - ) - - self.history_steps = dataframe_append( - self.history_steps, - [ - self._collect_single_order_metric( - self.order, - self.cur_time, - market_volume, - market_price, - exec_vol.sum(), - exec_vol, - ), - ], - ) - - # TODO: check whether we need this. Can we get this information from Account? 
- # Do this at the end - self.position -= exec_vol.sum() - - self.cur_time = self._next_time() - - def generate_metrics_after_done(self) -> None: - """Generate metrics once the upper level execution is done""" - - self.metrics = self._collect_single_order_metric( - self.order, - self.backtest_data.ticks_index[0], # start time - self.history_exec["market_volume"], - self.history_exec["market_price"], - self.history_steps["amount"].sum(), - self.history_exec["deal_amount"], - ) - - def _collect_multi_order_metric( - self, - order: Order, - datetime: pd.DatetimeIndex, - market_vol: np.ndarray, - market_price: np.ndarray, - exec_vol: np.ndarray, - pa: float, - ) -> SAOEMetrics: - return SAOEMetrics( - # It should have the same keys with SAOEMetrics, - # but the values do not necessarily have the annotated type. - # Some values could be vectorized (e.g., exec_vol). - stock_id=order.stock_id, - datetime=datetime, - direction=order.direction, - market_volume=market_vol, - market_price=market_price, - amount=exec_vol, - inner_amount=exec_vol, - deal_amount=exec_vol, - trade_price=market_price, - trade_value=market_price * exec_vol, - position=self.position - np.cumsum(exec_vol), - ffr=exec_vol / order.amount, - pa=pa, - ) - - def _collect_single_order_metric( - self, - order: Order, - datetime: pd.Timestamp, - market_vol: np.ndarray, - market_price: np.ndarray, - amount: float, # intended to trade such amount - exec_vol: np.ndarray, - ) -> SAOEMetrics: - assert len(market_vol) == len(market_price) == len(exec_vol) - - if np.abs(np.sum(exec_vol)) < EPS: - exec_avg_price = 0.0 - else: - exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan - if hasattr(exec_avg_price, "item"): # could be numpy scalar - exec_avg_price = exec_avg_price.item() # type: ignore - - exec_sum = exec_vol.sum() - return SAOEMetrics( - stock_id=order.stock_id, - datetime=datetime, - direction=order.direction, - market_volume=market_vol.sum(), - 
market_price=market_price.mean() if len(market_price) > 0 else np.nan, - amount=amount, - inner_amount=exec_sum, - deal_amount=exec_sum, # in this simulator, there's no other restrictions - trade_price=exec_avg_price, - trade_value=float(np.sum(market_price * exec_vol)), - position=self.position - exec_sum, - ffr=float(exec_sum / order.amount), - pa=price_advantage(exec_avg_price, self.twap_price, order.direction), - ) - - @property - def saoe_state(self) -> SAOEState: - return SAOEState( - order=self.order, - cur_time=self.cur_time, - position=self.position, - history_exec=self.history_exec, - history_steps=self.history_steps, - metrics=self.metrics, - backtest_data=self.backtest_data, - ticks_per_step=self.ticks_per_step, - ticks_index=self.backtest_data.ticks_index, - ticks_for_order=self.backtest_data.ticks_for_order, - ) class SAOEMetrics(TypedDict): @@ -302,7 +27,7 @@ class SAOEMetrics(TypedDict): stock_id: str """Stock ID of this record.""" - datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this + datetime: pd.Timestamp | pd.DatetimeIndex """Datetime of this record (this is index in the dataframe).""" direction: int """Direction of the order. 
0 for sell, 1 for buy.""" @@ -349,6 +74,8 @@ class SAOEState(NamedTuple): """The order we are dealing with.""" cur_time: pd.Timestamp """Current time, e.g., 9:30.""" + cur_step: int + """Current step, e.g., 0.""" position: float """Current remaining volume to execute.""" history_exec: pd.DataFrame diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index 663b8e8ff4..0102b9e57f 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -5,7 +5,7 @@ import collections from types import GeneratorType -from typing import Any, cast, Dict, Generator, List, Optional, Union +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -15,14 +15,276 @@ from qlib.backtest import CommonInfrastructure, Order from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange -from qlib.backtest.utils import LevelInfrastructure -from qlib.constant import ONE_MIN -from qlib.rl.data.native import load_backtest_data +from qlib.backtest.exchange import Exchange +from qlib.backtest.executor import BaseExecutor +from qlib.backtest.utils import LevelInfrastructure, get_start_end_idx +from qlib.constant import EPS, ONE_MIN, REG_CN +from qlib.rl.data.native import IntradayBacktestData, load_backtest_data from qlib.rl.interpreter import ActionInterpreter, StateInterpreter -from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter -from qlib.rl.utils.env_wrapper import BaseEnvWrapper +from qlib.rl.order_execution.state import SAOEMetrics, SAOEState +from qlib.rl.order_execution.utils import dataframe_append, price_advantage from qlib.strategy.base import RLStrategy from qlib.utils import init_instance_by_config +from qlib.utils.index_data import IndexData +from qlib.utils.time import get_day_min_idx_range + + +def _get_all_timestamps( + start: pd.Timestamp, + end: pd.Timestamp, + granularity: 
pd.Timedelta = ONE_MIN, + include_end: bool = True, +) -> pd.DatetimeIndex: + ret = [] + while start <= end: + ret.append(start) + start += granularity + + if ret[-1] > end: + ret.pop() + if ret[-1] == end and not include_end: + ret.pop() + return pd.DatetimeIndex(ret) + + +def fill_missing_data( + original_data: np.ndarray, + fill_method: Callable = np.nanmedian, +) -> np.ndarray: + """Fill missing data. + + Parameters + ---------- + original_data + Original data without missing values. + fill_method + Method used to fill the missing data. + + Returns + ------- + The filled data. + """ + return np.nan_to_num(original_data, nan=fill_method(original_data)) + + +class SAOEStateAdapter: + """ + Maintain states of the environment. SAOEStateAdapter accepts execution results and update its internal state + according to the execution results with additional information acquired from executors & exchange. For example, + it gets the dealt order amount from execution results, and get the corresponding market price / volume from + exchange. + + Example usage:: + + adapter = SAOEStateAdapter(...) + adapter.update(...) 
+ state = adapter.saoe_state + """ + + def __init__( + self, + order: Order, + trade_decision: BaseTradeDecision, + executor: BaseExecutor, + exchange: Exchange, + ticks_per_step: int, + backtest_data: IntradayBacktestData, + ) -> None: + self.position = order.amount + self.order = order + self.executor = executor + self.exchange = exchange + self.backtest_data = backtest_data + self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision) + + self.twap_price = self.backtest_data.get_deal_price().mean() + + metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") + self.metrics: Optional[SAOEMetrics] = None + + self.cur_time = max(backtest_data.ticks_for_order[0], order.start_time) + self.ticks_per_step = ticks_per_step + + def _next_time(self) -> pd.Timestamp: + current_loc = self.backtest_data.ticks_index.get_loc(self.cur_time) + next_loc = current_loc + self.ticks_per_step + next_loc = next_loc - next_loc % self.ticks_per_step + if ( + next_loc < len(self.backtest_data.ticks_index) + and self.backtest_data.ticks_index[next_loc] < self.order.end_time + ): + return self.backtest_data.ticks_index[next_loc] + else: + return self.order.end_time + + def update( + self, + execute_result: list, + last_step_range: Tuple[int, int], + ) -> None: + last_step_size = last_step_range[1] - last_step_range[0] + 1 + start_time = self.backtest_data.ticks_index[last_step_range[0]] + end_time = self.backtest_data.ticks_index[last_step_range[1]] + + exec_vol = np.zeros(last_step_size) + for order, _, __, ___ in execute_result: + idx, _ = get_day_min_idx_range(order.start_time, order.end_time, "1min", REG_CN) + exec_vol[idx - last_step_range[0]] = order.deal_amount + + if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: + assert exec_vol.sum() < self.position + 1, 
f"{exec_vol} too large" + exec_vol *= self.position / (exec_vol.sum()) + + market_volume = cast( + IndexData, + self.exchange.get_volume( + self.order.stock_id, + pd.Timestamp(start_time), + pd.Timestamp(end_time), + method=None, + ), + ) + market_price = cast( + IndexData, + self.exchange.get_deal_price( + self.order.stock_id, + pd.Timestamp(start_time), + pd.Timestamp(end_time), + method=None, + direction=self.order.direction, + ), + ) + market_price = fill_missing_data(np.array(market_price, dtype=float).reshape(-1)) + market_volume = fill_missing_data(np.array(market_volume, dtype=float).reshape(-1)) + + assert market_price.shape == market_volume.shape == exec_vol.shape + + # Get data from the current level executor's indicator + current_trade_account = self.executor.trade_account + current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + self.history_exec = dataframe_append( + self.history_exec, + self._collect_multi_order_metric( + order=self.order, + datetime=_get_all_timestamps(start_time, end_time, include_end=True), + market_vol=market_volume, + market_price=market_price, + exec_vol=exec_vol, + pa=current_df.iloc[-1]["pa"], + ), + ) + + self.history_steps = dataframe_append( + self.history_steps, + [ + self._collect_single_order_metric( + self.order, + self.cur_time, + market_volume, + market_price, + exec_vol.sum(), + exec_vol, + ), + ], + ) + + # Do this at the end + self.position -= exec_vol.sum() + + self.cur_time = self._next_time() + + def generate_metrics_after_done(self) -> None: + """Generate metrics once the upper level execution is done""" + + self.metrics = self._collect_single_order_metric( + self.order, + self.backtest_data.ticks_index[0], # start time + self.history_exec["market_volume"], + self.history_exec["market_price"], + self.history_steps["amount"].sum(), + self.history_exec["deal_amount"], + ) + + def _collect_multi_order_metric( + self, + order: Order, + datetime: pd.DatetimeIndex, + 
market_vol: np.ndarray, + market_price: np.ndarray, + exec_vol: np.ndarray, + pa: float, + ) -> SAOEMetrics: + return SAOEMetrics( + # It should have the same keys with SAOEMetrics, + # but the values do not necessarily have the annotated type. + # Some values could be vectorized (e.g., exec_vol). + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol, + market_price=market_price, + amount=exec_vol, + inner_amount=exec_vol, + deal_amount=exec_vol, + trade_price=market_price, + trade_value=market_price * exec_vol, + position=self.position - np.cumsum(exec_vol), + ffr=exec_vol / order.amount, + pa=pa, + ) + + def _collect_single_order_metric( + self, + order: Order, + datetime: pd.Timestamp, + market_vol: np.ndarray, + market_price: np.ndarray, + amount: float, # intended to trade such amount + exec_vol: np.ndarray, + ) -> SAOEMetrics: + assert len(market_vol) == len(market_price) == len(exec_vol) + + if np.abs(np.sum(exec_vol)) < EPS: + exec_avg_price = 0.0 + else: + exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + if hasattr(exec_avg_price, "item"): # could be numpy scalar + exec_avg_price = exec_avg_price.item() # type: ignore + + exec_sum = exec_vol.sum() + return SAOEMetrics( + stock_id=order.stock_id, + datetime=datetime, + direction=order.direction, + market_volume=market_vol.sum(), + market_price=market_price.mean() if len(market_price) > 0 else np.nan, + amount=amount, + inner_amount=exec_sum, + deal_amount=exec_sum, # in this simulator, there's no other restrictions + trade_price=exec_avg_price, + trade_value=float(np.sum(market_price * exec_vol)), + position=self.position - exec_sum, + ffr=float(exec_sum / order.amount), + pa=price_advantage(exec_avg_price, self.twap_price, order.direction), + ) + + @property + def saoe_state(self) -> SAOEState: + return SAOEState( + order=self.order, + cur_time=self.cur_time, + cur_step=self.executor.trade_calendar.get_trade_step() - 
self.start_idx, + position=self.position, + history_exec=self.history_exec, + history_steps=self.history_steps, + metrics=self.metrics, + backtest_data=self.backtest_data, + ticks_per_step=self.ticks_per_step, + ticks_index=self.backtest_data.ticks_index, + ticks_for_order=self.backtest_data.ticks_for_order, + ) class SAOEStrategy(RLStrategy): @@ -30,7 +292,7 @@ class SAOEStrategy(RLStrategy): def __init__( self, - policy: object, # TODO: add accurate typehint later. + policy: BasePolicy, outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, @@ -47,11 +309,17 @@ def __init__( self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {} self._last_step_range = (0, 0) - def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter: + def _create_qlib_backtest_adapter( + self, + order: Order, + trade_decision: BaseTradeDecision, + trade_range: TradeRange, + ) -> SAOEStateAdapter: backtest_data = load_backtest_data(order, self.trade_exchange, trade_range) return SAOEStateAdapter( order=order, + trade_decision=trade_decision, executor=self.executor, exchange=self.trade_exchange, ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN), @@ -71,7 +339,9 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) - self.adapter_dict = {} for decision in outer_trade_decision.get_decision(): order = cast(Order, decision) - self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range) + self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter( + order, outer_trade_decision, trade_range + ) def get_saoe_state_by_order(self, order: Order) -> SAOEState: return self.adapter_dict[order.key_by_day].saoe_state @@ -166,11 +436,10 @@ def __init__( policy: dict | BasePolicy, state_interpreter: dict | StateInterpreter, action_interpreter: dict | ActionInterpreter, - network: object = None, # 
TODO: add accurate typehint later. + network: dict | torch.nn.Module | None = None, outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, - backtest: bool = False, **kwargs: Any, ) -> None: super(SAOEIntStrategy, self).__init__( @@ -181,8 +450,6 @@ def __init__( **kwargs, ) - self._backtest = backtest - self._state_interpreter: StateInterpreter = init_instance_by_config( state_interpreter, accept_types=StateInterpreter, @@ -221,21 +488,9 @@ def __init__( if self._policy is not None: self._policy.eval() - def set_env(self, env: BaseEnvWrapper) -> None: - # TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper. - # We should decompose the interpreters with EnvWrapper in the future and we should remove this method - # after that. - - self._env = env - self._state_interpreter.env = self._action_interpreter.env = self._env - def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) - # In backtest, env.reset() needs to be manually called since there is no outer trainer to call it - if self._backtest: - self._env.reset() - def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: assert hasattr(self.outer_trade_decision, "order_list") @@ -268,10 +523,6 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)] - # In backtest, env.step() needs to be manually called since there is no outer trainer to call it - if self._backtest: - self._env.step(None) - oh = self.trade_exchange.get_order_helper() order_list = [] for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols): diff --git a/qlib/rl/trainer/trainer.py 
b/qlib/rl/trainer/trainer.py index 66a185447d..7573b33911 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -7,7 +7,7 @@ import copy from contextlib import AbstractContextManager, contextmanager from pathlib import Path -from typing import Any, Dict, Iterable, List, Sequence, TypeVar, cast +from typing import Any, Dict, Iterable, List, OrderedDict, Sequence, TypeVar, cast import torch @@ -152,6 +152,13 @@ def state_dict(self) -> dict: "metrics": self.metrics, } + @staticmethod + def get_policy_state_dict(ckpt_path: Path) -> OrderedDict: + state_dict = torch.load(ckpt_path, map_location="cpu") + if "vessel" in state_dict: + state_dict = state_dict["vessel"]["policy"] + return state_dict + def load_state_dict(self, state_dict: dict) -> None: """Load all states into current trainer.""" self.vessel.load_state_dict(state_dict["vessel"]) diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index f082f3b013..e0c009b7bd 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -48,24 +48,9 @@ class EnvWrapperStatus(TypedDict): reward_history: list -class BaseEnvWrapper( +class EnvWrapper( gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType], -): - """Base env wrapper for RL environments. It has two implementations: - - EnvWrapper: Qlib-based RL environment used in training. - - CollectDataEnvWrapper: Dummy environment used in collect_data_loop. - """ - - def __init__(self) -> None: - self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None) - - def render(self, mode: str = "human") -> None: - raise NotImplementedError("Render is not implemented in BaseEnvWrapper.") - - -class EnvWrapper( - BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType], ): """Qlib-based RL environment, subclassing ``gym.Env``. A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. @@ -129,8 +114,6 @@ def __init__( # 3. 
Avoid circular reference. # 4. When the components get serialized, we can throw away the env without any burden. # (though this part is not implemented yet) - super().__init__() - for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]: if obj is not None: obj.env = weakref.proxy(self) # type: ignore @@ -263,19 +246,5 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, fl info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info) return obs, rew, done, info_dict - -class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]): - """Dummy EnvWrapper for collect_data_loop. It only has minimum interfaces to support the collect_data_loop.""" - - def reset(self, **kwargs: Any) -> None: - self.status = EnvWrapperStatus( - cur_step=0, - done=False, - initial_state=None, - obs_history=[], - action_history=[], - reward_history=[], - ) - - def step(self, policy_action: Any = None, **kwargs: Any) -> None: - self.status["cur_step"] += 1 + def render(self, mode: str = "human") -> None: + raise NotImplementedError("Render is not implemented in EnvWrapper.") diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 55ede19b9b..5f62e77589 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -473,7 +473,8 @@ def _generate(self, **kwargs): self.save(**{f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal}) + self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]}) + self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]}) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -511,7 +512,7 @@ def _generate(self, **kwargs): if _analysis_freq not in indicator_dict: warnings.warn(f"the freq {_analysis_freq} indicator is not found") else: - 
indicators_normal = indicator_dict.get(_analysis_freq) + indicators_normal = indicator_dict.get(_analysis_freq)[0] if self.indicator_analysis_method is None: analysis_df = indicator_analysis(indicators_normal) else: diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py index f0497bc91f..2e30f1a3cb 100644 --- a/tests/backtest/test_file_strategy.py +++ b/tests/backtest/test_file_strategy.py @@ -107,7 +107,7 @@ def test_file_str(self): ) # ffr valid - ffr_dict = indicator_dict["1day"]["ffr"].to_dict() + ffr_dict = indicator_dict["1day"][0]["ffr"].to_dict() ffr_dict = {str(date).split()[0]: ffr_dict[date] for date in ffr_dict} assert np.isclose(ffr_dict["2020-01-03"], dealt_num_for_1000 / 1000) assert np.isclose(ffr_dict["2020-01-06"], 0) diff --git a/tests/backtest/test_high_freq_trading.py b/tests/backtest/test_high_freq_trading.py index 21bc4e0d47..fd934914d8 100644 --- a/tests/backtest/test_high_freq_trading.py +++ b/tests/backtest/test_high_freq_trading.py @@ -125,7 +125,7 @@ def test_trading(self): # NOTE: please refer to the docs of format_decisions # NOTE: `"track_data": True,` is very NECESSARY for collecting the decision!!!!! 
f_dec = format_decisions(decisions) - print(indicator["1day"]) + print(indicator["1day"][0]) if __name__ == "__main__": diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index 92ad9c0583..382609e5e1 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -7,11 +7,11 @@ import pandas as pd import pytest -from qlib.backtest.decision import Order, OrderDir, TradeRangeByTime + +from qlib.backtest.decision import Order, OrderDir from qlib.backtest.executor import SimulatorExecutor from qlib.rl.order_execution import CategoricalActionInterpreter from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution -from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper TOTAL_POSITION = 2100.0 @@ -183,8 +183,6 @@ def test_interpreter() -> None: order = get_order() simulator = get_simulator(order) interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION) - interpreter_action.env = CollectDataEnvWrapper() - interpreter_action.env.reset() NUM_STEPS = 7 state = simulator.get_state() diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 22bd039096..32d6b4d6e4 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -20,7 +20,6 @@ from qlib.rl.order_execution import * from qlib.rl.trainer import backtest, train from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus -from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") @@ -186,10 +185,6 @@ class EmulateEnvWrapper(NamedTuple): assert np.sum(obs["data_processed"][60:]) == 0 # second step: action - interpreter_action.env = CollectDataEnvWrapper() - interpreter_action_twap.env = CollectDataEnvWrapper() - interpreter_action.env.reset() - interpreter_action_twap.env.reset() action = interpreter_action(simulator.get_state(), 1) assert action == 15 / 20 @@ -260,8 +255,6 
@@ def test_twap_strategy(finite_env_type): state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)) action_interp = TwapRelativeActionInterpreter() - action_interp.env = CollectDataEnvWrapper() - action_interp.env.reset() policy = AllOne(state_interp.observation_space, action_interp.action_space) csv_writer = CsvWriter(Path(__file__).parent / ".output") @@ -291,8 +284,6 @@ def test_cn_ppo_strategy(): state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)) action_interp = CategoricalActionInterpreter(4) - action_interp.env = CollectDataEnvWrapper() - action_interp.env.reset() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu")) @@ -324,8 +315,6 @@ def test_ppo_train(): state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)) action_interp = CategoricalActionInterpreter(4) - action_interp.env = CollectDataEnvWrapper() - action_interp.env.reset() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)