From a3ff5d908bab7f9ae30269c56c1ee06c364611a6 Mon Sep 17 00:00:00 2001
From: Chenglong Chen
Date: Sat, 11 Mar 2023 14:18:36 +0800
Subject: [PATCH] move policy

---
 .../continuous/{meta_ddpg.py => ddpg.py}       |   0
 .../policy/continuous/{meta_td3.py => td3.py}  |   0
 .../policy/discrete/meta_ppo.py                | 108 ------------------
 .../rl/order_execution/policy/discrete/ppo.py  | 102 ++++++++++++++++-
 longcapital/rl/order_execution/strategy.py     |   5 +-
 5 files changed, 103 insertions(+), 112 deletions(-)
 rename longcapital/rl/order_execution/policy/continuous/{meta_ddpg.py => ddpg.py} (100%)
 rename longcapital/rl/order_execution/policy/continuous/{meta_td3.py => td3.py} (100%)
 delete mode 100644 longcapital/rl/order_execution/policy/discrete/meta_ppo.py

diff --git a/longcapital/rl/order_execution/policy/continuous/meta_ddpg.py b/longcapital/rl/order_execution/policy/continuous/ddpg.py
similarity index 100%
rename from longcapital/rl/order_execution/policy/continuous/meta_ddpg.py
rename to longcapital/rl/order_execution/policy/continuous/ddpg.py
diff --git a/longcapital/rl/order_execution/policy/continuous/meta_td3.py b/longcapital/rl/order_execution/policy/continuous/td3.py
similarity index 100%
rename from longcapital/rl/order_execution/policy/continuous/meta_td3.py
rename to longcapital/rl/order_execution/policy/continuous/td3.py
diff --git a/longcapital/rl/order_execution/policy/discrete/meta_ppo.py b/longcapital/rl/order_execution/policy/discrete/meta_ppo.py
deleted file mode 100644
index 4b36ed3..0000000
--- a/longcapital/rl/order_execution/policy/discrete/meta_ppo.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from pathlib import Path
-from typing import Any, List, Optional, Union
-
-import gym
-import numpy as np
-import torch
-from longcapital.rl.utils.net.common import MetaNet
-from longcapital.rl.utils.net.discrete import MetaActor, MetaCritic
-from qlib.rl.order_execution.policy import Trainer, auto_device, chain_dedup, set_weight
-from tianshou.data import Batch
-from tianshou.policy import PPOPolicy
-
-
-class MetaPPO(PPOPolicy):
-    def __init__(
-        self,
-        obs_space: gym.Space,
-        action_space: gym.Space,
-        softmax_output: bool = False,
-        sigmoid_output: bool = True,
-        hidden_sizes: List[int] = [32, 16, 8],
-        lr: float = 1e-4,
-        weight_decay: float = 0.0,
-        discount_factor: float = 1.0,
-        max_grad_norm: float = 100.0,
-        reward_normalization: bool = True,
-        eps_clip: float = 0.3,
-        value_clip: bool = True,
-        vf_coef: float = 1.0,
-        gae_lambda: float = 1.0,
-        max_batch_size: int = 256,
-        deterministic_eval: bool = True,
-        weight_file: Optional[Path] = None,
-    ) -> None:
-        assert not (softmax_output and sigmoid_output)
-
-        net = MetaNet(obs_space.shape, hidden_sizes=hidden_sizes, self_attn=True)
-        actor = MetaActor(
-            net,
-            action_space.shape,
-            softmax_output=softmax_output,
-            sigmoid_output=sigmoid_output,
-            device=auto_device(net),
-        ).to(auto_device(net))
-
-        net = MetaNet(
-            obs_space.shape,
-            action_space.shape,
-            hidden_sizes=hidden_sizes,
-            attn_pooling=True,
-        )
-        critic = MetaCritic(net, device=auto_device(net)).to(auto_device(net))
-
-        optimizer = torch.optim.Adam(
-            chain_dedup(actor.parameters(), critic.parameters()),
-            lr=lr,
-            weight_decay=weight_decay,
-        )
-
-        super().__init__(
-            actor,
-            critic,
-            optimizer,
-            torch.distributions.Bernoulli,
-            discount_factor=discount_factor,
-            max_grad_norm=max_grad_norm,
-            reward_normalization=reward_normalization,
-            eps_clip=eps_clip,
-            value_clip=value_clip,
-            vf_coef=vf_coef,
-            gae_lambda=gae_lambda,
-            max_batchsize=max_batch_size,
-            deterministic_eval=deterministic_eval,
-            observation_space=obs_space,
-            action_space=action_space,
-        )
-        if weight_file is not None:
-            set_weight(self, Trainer.get_policy_state_dict(weight_file))
-
-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        **kwargs: Any,
-    ) -> Batch:
-        """Compute action over the given batch data.
-        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
-        * ``act`` the action.
-        * ``logits`` the network's raw output.
-        * ``dist`` the action distribution.
-        * ``state`` the hidden state.
-        .. seealso::
-            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
-            more detailed explanation.
-        """
-        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
-        if isinstance(logits, tuple):
-            dist = self.dist_fn(*logits)
-        else:
-            dist = self.dist_fn(logits)
-        if self._deterministic_eval and not self.training:
-            if self.action_type == "discrete":
-                act = (logits > 0.5).float()
-            elif self.action_type == "continuous":
-                act = logits[0]
-            else:
-                act = dist.sample()
-        return Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/longcapital/rl/order_execution/policy/discrete/ppo.py b/longcapital/rl/order_execution/policy/discrete/ppo.py
index 0ffa298..fcc85ce 100644
--- a/longcapital/rl/order_execution/policy/discrete/ppo.py
+++ b/longcapital/rl/order_execution/policy/discrete/ppo.py
@@ -1,10 +1,13 @@
 from pathlib import Path
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import gym
+import numpy as np
 import torch
 from longcapital.rl.utils.net.common import MetaNet
+from longcapital.rl.utils.net.discrete import MetaActor, MetaCritic
 from qlib.rl.order_execution.policy import Trainer, auto_device, chain_dedup, set_weight
+from tianshou.data import Batch
 from tianshou.policy import PPOPolicy
 from tianshou.utils.net.discrete import Actor, Critic
 
@@ -66,3 +69,100 @@ def __init__(
         )
         if weight_file is not None:
             set_weight(self, Trainer.get_policy_state_dict(weight_file))
+
+
+class MetaPPO(PPOPolicy):
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        action_space: gym.Space,
+        softmax_output: bool = False,
+        sigmoid_output: bool = True,
+        hidden_sizes: List[int] = [32, 16, 8],
+        lr: float = 1e-4,
+        weight_decay: float = 0.0,
+        discount_factor: float = 1.0,
+        max_grad_norm: float = 100.0,
+        reward_normalization: bool = True,
+        eps_clip: float = 0.3,
+        value_clip: bool = True,
+        vf_coef: float = 1.0,
+        gae_lambda: float = 1.0,
+        max_batch_size: int = 256,
+        deterministic_eval: bool = True,
+        weight_file: Optional[Path] = None,
+    ) -> None:
+        assert not (softmax_output and sigmoid_output)
+
+        net = MetaNet(obs_space.shape, hidden_sizes=hidden_sizes, self_attn=True)
+        actor = MetaActor(
+            net,
+            action_space.shape,
+            softmax_output=softmax_output,
+            sigmoid_output=sigmoid_output,
+            device=auto_device(net),
+        ).to(auto_device(net))
+
+        net = MetaNet(
+            obs_space.shape,
+            action_space.shape,
+            hidden_sizes=hidden_sizes,
+            attn_pooling=True,
+        )
+        critic = MetaCritic(net, device=auto_device(net)).to(auto_device(net))
+
+        optimizer = torch.optim.Adam(
+            chain_dedup(actor.parameters(), critic.parameters()),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+
+        super().__init__(
+            actor,
+            critic,
+            optimizer,
+            torch.distributions.Bernoulli,
+            discount_factor=discount_factor,
+            max_grad_norm=max_grad_norm,
+            reward_normalization=reward_normalization,
+            eps_clip=eps_clip,
+            value_clip=value_clip,
+            vf_coef=vf_coef,
+            gae_lambda=gae_lambda,
+            max_batchsize=max_batch_size,
+            deterministic_eval=deterministic_eval,
+            observation_space=obs_space,
+            action_space=action_space,
+        )
+        if weight_file is not None:
+            set_weight(self, Trainer.get_policy_state_dict(weight_file))
+
+    def forward(
+        self,
+        batch: Batch,
+        state: Optional[Union[dict, Batch, np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Batch:
+        """Compute action over the given batch data.
+        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+        * ``act`` the action.
+        * ``logits`` the network's raw output.
+        * ``dist`` the action distribution.
+        * ``state`` the hidden state.
+        .. seealso::
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+            more detailed explanation.
+        """
+        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
+        if isinstance(logits, tuple):
+            dist = self.dist_fn(*logits)
+        else:
+            dist = self.dist_fn(logits)
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = (logits > 0.5).float()
+            elif self.action_type == "continuous":
+                act = logits[0]
+            else:
+                act = dist.sample()
+        return Batch(logits=logits, act=act, state=hidden, dist=dist)
diff --git a/longcapital/rl/order_execution/strategy.py b/longcapital/rl/order_execution/strategy.py
index 1b79893..e35017a 100644
--- a/longcapital/rl/order_execution/strategy.py
+++ b/longcapital/rl/order_execution/strategy.py
@@ -12,9 +12,8 @@
     WeightStrategyAction,
     WeightStrategyActionInterpreter,
 )
-from longcapital.rl.order_execution.policy.continuous.meta_td3 import MetaTD3
-from longcapital.rl.order_execution.policy.discrete.meta_ppo import MetaPPO
-from longcapital.rl.order_execution.policy.discrete.ppo import PPO
+from longcapital.rl.order_execution.policy.continuous.td3 import MetaTD3
+from longcapital.rl.order_execution.policy.discrete.ppo import PPO, MetaPPO
 from longcapital.rl.order_execution.state import TradeStrategyState
 from qlib.backtest.decision import TradeDecisionWO
 from qlib.backtest.position import Position
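
Note for downstream callers: the net effect of this patch is the import-path change shown in the strategy.py hunk above; the MetaPPO class body and constructor signature are moved unchanged. A minimal sketch of the updated imports follows, assuming the longcapital package and its qlib/tianshou dependencies are installed; the module paths are taken from the diff itself, everything else is illustrative.

    # Old -> new locations after this patch:
    #   ...policy.discrete.meta_ppo.MetaPPO   -> ...policy.discrete.ppo.MetaPPO
    #   ...policy.continuous.meta_td3.MetaTD3 -> ...policy.continuous.td3.MetaTD3
    from longcapital.rl.order_execution.policy.continuous.td3 import MetaTD3
    from longcapital.rl.order_execution.policy.discrete.ppo import PPO, MetaPPO

    # Construction is unchanged, e.g. MetaPPO(obs_space=..., action_space=...),
    # with the gym spaces supplied by the caller as before.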