Skip to content

Commit

Permalink
Refine RL todos (microsoft#1332)
Browse files Browse the repository at this point in the history
* Refine several todos

* CI issues

* Remove Dropna limitation of `quote_df` in Exchange  (microsoft#1334)

* Remove Dropna limitation of `quote_df` of Exchange

* Improve docstring

* Fix type error when expression is specified (microsoft#1335)

* Refine fill_missing_data()

* Remove several TODO comments

* Add back env for interpreters

* Change Literal import

* Resolve PR comments

* Move  to SAOEState

* Add Trainer.get_policy_state_dict()

* Mypy issue

Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
  • Loading branch information
lihuoran and you-n-g authored Nov 10, 2022
1 parent 6d3265b commit 8307be5
Show file tree
Hide file tree
Showing 20 changed files with 461 additions and 530 deletions.
12 changes: 5 additions & 7 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pandas as pd

from .account import Account
from .report import Indicator, PortfolioMetrics

if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
Expand All @@ -20,7 +19,7 @@
from ..config import C
from ..log import get_module_logger
from ..utils import init_instance_by_config
from .backtest import backtest_loop, collect_data_loop
from .backtest import INDICATOR_METRIC, PORT_METRIC, backtest_loop, collect_data_loop
from .decision import Order
from .exchange import Exchange
from .utils import CommonInfrastructure
Expand Down Expand Up @@ -223,7 +222,7 @@ def backtest(
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution
Expand Down Expand Up @@ -256,9 +255,9 @@ def backtest(
Returns
-------
portfolio_metrics_dict: Dict[PortfolioMetrics]
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator_dict: Dict[Indicator]
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
It is organized in a dict format
Expand All @@ -273,8 +272,7 @@ def backtest(
exchange_kwargs,
pos_type=pos_type,
)
portfolio_metrics, indicator = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return portfolio_metrics, indicator
return backtest_loop(start_time, end_time, trade_strategy, trade_executor)


def collect_data(
Expand Down
46 changes: 27 additions & 19 deletions qlib/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union, cast
from typing import Dict, TYPE_CHECKING, Generator, Optional, Tuple, Union, cast

import pandas as pd

from qlib.backtest.decision import BaseTradeDecision
from qlib.backtest.report import Indicator, PortfolioMetrics
from qlib.backtest.report import Indicator

if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
Expand All @@ -19,30 +19,35 @@
from ..utils.time import Freq


PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]


def backtest_loop(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
trade_strategy: BaseStrategy,
trade_executor: BaseExecutor,
) -> Tuple[PortfolioMetrics, Indicator]:
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
please refer to the docs of `collect_data_loop`
Returns
-------
portfolio_metrics: PortfolioMetrics
portfolio_dict: PORT_METRIC
it records the trading portfolio_metrics information
indicator: Indicator
indicator_dict: INDICATOR_METRIC
it computes the trading indicator
"""
return_value: dict = {}
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
pass

portfolio_metrics = cast(PortfolioMetrics, return_value.get("portfolio_metrics"))
indicator = cast(Indicator, return_value.get("indicator"))
return portfolio_metrics, indicator
portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict"))
indicator_dict = cast(INDICATOR_METRIC, return_value.get("indicator_dict"))

return portfolio_dict, indicator_dict


def collect_data_loop(
Expand Down Expand Up @@ -89,14 +94,17 @@ def collect_data_loop(

if return_value is not None:
all_executors = trade_executor.get_all_executors()
all_portfolio_metrics = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return_value.update({"portfolio_metrics": all_portfolio_metrics, "indicator": all_indicators})

portfolio_dict: PORT_METRIC = {}
indicator_dict: INDICATOR_METRIC = {}

for executor in all_executors:
key = "{}{}".format(*Freq.parse(executor.time_per_step))
if executor.trade_account.is_port_metr_enabled():
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()

indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
indicator_obj = executor.trade_account.get_trade_indicator()
indicator_dict[key] = (indicator_df, indicator_obj)

return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict})
63 changes: 51 additions & 12 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@


class Exchange:
# `quote_df` is a pd.DataFrame class that contains basic information for backtesting
# After some processing, the data will later be maintained by the `quote_cls` object for faster data retrieval.
# Some conventions for `quote_df`
# - $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# - $factor is for rounding to the trading unit;
# - if any $factor is missing when $close exists, trading unit rounding will be disabled
quote_df: pd.DataFrame

def __init__(
self,
freq: str = "day",
Expand Down Expand Up @@ -159,6 +168,7 @@ def __init__(
self.codes = codes
# Necessary fields
# $close is for calculating the total value at end of each day.
# - if $close is None, the stock on that day is regarded as suspended.
# $factor is for rounding to the trading unit
# $change is for calculating the limit of the stock

Expand Down Expand Up @@ -199,7 +209,7 @@ def get_quote_from_qlib(self) -> None:
self.end_time,
freq=self.freq,
disk_cache=True,
).dropna(subset=["$close"])
)
self.quote_df.columns = self.all_fields

# check buy_price data and sell_price data
Expand All @@ -209,7 +219,7 @@ def get_quote_from_qlib(self) -> None:
self.logger.warning("{} field data contains nan.".format(pstr))

# update trade_w_adj_price
if self.quote_df["$factor"].isna().any():
if (self.quote_df["$factor"].isna() & ~self.quote_df["$close"].isna()).any():
# The 'factor.day.bin' file does not exist, and the `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
Expand Down Expand Up @@ -245,9 +255,9 @@ def get_quote_from_qlib(self) -> None:
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)

LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
LT_NONE = "none" # none
LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression.
LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold`
LT_NONE = "none" # none: there is no trading limitation

def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
"""get limit type"""
Expand All @@ -261,20 +271,25 @@ def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")

def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
# $close may contain NaN; a NaN indicates that the stock is not tradable at that timestamp
suspended = self.quote_df["$close"].isna()
# check limit_threshold
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = False
self.quote_df["limit_sell"] = False
self.quote_df["limit_buy"] = suspended
self.quote_df["limit_sell"] = suspended
elif limit_type == self.LT_TP_EXP:
# set limit
limit_threshold = cast(tuple, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
# astype bool is necessary, because quote_df is an expression and could be float
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended
elif limit_type == self.LT_FLT:
limit_threshold = cast(float, limit_threshold)
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended
self.quote_df["limit_sell"] = (
self.quote_df["$change"].le(-limit_threshold) | suspended
) # pylint: disable=E1130

@staticmethod
def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]:
Expand Down Expand Up @@ -338,8 +353,18 @@ def check_stock_limit(
- if direction is None, check if tradable for buying and selling.
- if direction == Order.BUY, check if tradable for buying
- if direction == Order.SELL, check the sell limit for selling.
Returns
-------
True: the trading of the stock is limited (maybe hit the highest/lowest price), hence the stock is not tradable
False: the trading of the stock is not limited, hence the stock may be tradable
"""
# NOTE:
# **all** is used when checking limitation.
# For example, the stock trading is limited in a day if every minute is limited.
if direction is None:
# The trading limitation is related to the trading direction
# if the direction is not provided, then any limitation from buy or sell will result in trading limitation
buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")
return bool(buy_limit or sell_limit)
Expand All @@ -356,10 +381,24 @@ def check_stock_suspended(
start_time: pd.Timestamp,
end_time: pd.Timestamp,
) -> bool:
"""if stock is suspended(hence not tradable), True will be returned"""
# is suspended
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
# suspended stocks are represented by None $close stock
# The $close may contain NaN,
close = self.quote.get_data(stock_id, start_time, end_time, "$close")
if close is None:
# if no close record exists
return True
elif isinstance(close, IndexData):
# **any** non-NaN $close indicates that a trading opportunity may exist
# if all returned is nan, then the stock is suspended
return cast(bool, cast(IndexData, close).isna().all())
else:
# it is a single value; make sure it is not NaN
return np.isnan(close)
else:
# if the stock is not in the stock list, then it is not tradable and regarded as suspended
return True

def is_stock_tradable(
Expand Down
Loading

0 comments on commit 8307be5

Please sign in to comment.