
Commit

Merge pull request #5 from ChenglongChen/dev
move policy
ChenglongChen authored Mar 11, 2023
2 parents eda38d1 + a3ff5d9 commit 541bdb0
Showing 5 changed files with 103 additions and 112 deletions.
108 changes: 0 additions & 108 deletions longcapital/rl/order_execution/policy/discrete/meta_ppo.py

This file was deleted.

102 changes: 101 additions & 1 deletion longcapital/rl/order_execution/policy/discrete/ppo.py
@@ -1,10 +1,13 @@
 from pathlib import Path
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import gym
+import numpy as np
 import torch
 from longcapital.rl.utils.net.common import MetaNet
+from longcapital.rl.utils.net.discrete import MetaActor, MetaCritic
 from qlib.rl.order_execution.policy import Trainer, auto_device, chain_dedup, set_weight
+from tianshou.data import Batch
 from tianshou.policy import PPOPolicy
 from tianshou.utils.net.discrete import Actor, Critic
 
@@ -66,3 +69,100 @@ def __init__(
         )
         if weight_file is not None:
             set_weight(self, Trainer.get_policy_state_dict(weight_file))
+
+
+class MetaPPO(PPOPolicy):
+    def __init__(
+        self,
+        obs_space: gym.Space,
+        action_space: gym.Space,
+        softmax_output: bool = False,
+        sigmoid_output: bool = True,
+        hidden_sizes: List[int] = [32, 16, 8],
+        lr: float = 1e-4,
+        weight_decay: float = 0.0,
+        discount_factor: float = 1.0,
+        max_grad_norm: float = 100.0,
+        reward_normalization: bool = True,
+        eps_clip: float = 0.3,
+        value_clip: bool = True,
+        vf_coef: float = 1.0,
+        gae_lambda: float = 1.0,
+        max_batch_size: int = 256,
+        deterministic_eval: bool = True,
+        weight_file: Optional[Path] = None,
+    ) -> None:
+        assert not (softmax_output and sigmoid_output)
+
+        net = MetaNet(obs_space.shape, hidden_sizes=hidden_sizes, self_attn=True)
+        actor = MetaActor(
+            net,
+            action_space.shape,
+            softmax_output=softmax_output,
+            sigmoid_output=sigmoid_output,
+            device=auto_device(net),
+        ).to(auto_device(net))
+
+        net = MetaNet(
+            obs_space.shape,
+            action_space.shape,
+            hidden_sizes=hidden_sizes,
+            attn_pooling=True,
+        )
+        critic = MetaCritic(net, device=auto_device(net)).to(auto_device(net))
+
+        optimizer = torch.optim.Adam(
+            chain_dedup(actor.parameters(), critic.parameters()),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+
+        super().__init__(
+            actor,
+            critic,
+            optimizer,
+            torch.distributions.Bernoulli,
+            discount_factor=discount_factor,
+            max_grad_norm=max_grad_norm,
+            reward_normalization=reward_normalization,
+            eps_clip=eps_clip,
+            value_clip=value_clip,
+            vf_coef=vf_coef,
+            gae_lambda=gae_lambda,
+            max_batchsize=max_batch_size,
+            deterministic_eval=deterministic_eval,
+            observation_space=obs_space,
+            action_space=action_space,
+        )
+        if weight_file is not None:
+            set_weight(self, Trainer.get_policy_state_dict(weight_file))
+
+    def forward(
+        self,
+        batch: Batch,
+        state: Optional[Union[dict, Batch, np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Batch:
+        """Compute action over the given batch data.
+        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
+            * ``act`` the action.
+            * ``logits`` the network's raw output.
+            * ``dist`` the action distribution.
+            * ``state`` the hidden state.
+        .. seealso::
+            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
+            more detailed explanation.
+        """
+        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
+        if isinstance(logits, tuple):
+            dist = self.dist_fn(*logits)
+        else:
+            dist = self.dist_fn(logits)
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = (logits > 0.5).float()
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
+        return Batch(logits=logits, act=act, state=hidden, dist=dist)
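
Note: the snippet below is not part of this commit. It is a minimal usage sketch of the relocated MetaPPO policy, constructed from its new import path; the observation and action space shapes are illustrative assumptions, not values taken from the repository.

# Minimal sketch (assumptions: 20 instruments, 10 features each).
import gym
import numpy as np

from longcapital.rl.order_execution.policy.discrete.ppo import MetaPPO

# Hypothetical spaces: one feature vector per instrument, one select/skip flag per instrument.
obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(20, 10), dtype=np.float32)
action_space = gym.spaces.MultiBinary(20)

policy = MetaPPO(
    obs_space=obs_space,
    action_space=action_space,
    sigmoid_output=True,  # actor outputs per-instrument probabilities for the Bernoulli distribution
    lr=1e-4,
)

# With deterministic_eval=True, evaluation mode thresholds the probabilities at 0.5
# instead of sampling from the Bernoulli distribution (see forward above).
policy.eval()

The import path reflects the move performed in this commit: MetaPPO now lives next to PPO in policy/discrete/ppo.py instead of its own meta_ppo.py module.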
5 changes: 2 additions & 3 deletions longcapital/rl/order_execution/strategy.py
@@ -12,9 +12,8 @@
     WeightStrategyAction,
     WeightStrategyActionInterpreter,
 )
-from longcapital.rl.order_execution.policy.continuous.meta_td3 import MetaTD3
-from longcapital.rl.order_execution.policy.discrete.meta_ppo import MetaPPO
-from longcapital.rl.order_execution.policy.discrete.ppo import PPO
+from longcapital.rl.order_execution.policy.continuous.td3 import MetaTD3
+from longcapital.rl.order_execution.policy.discrete.ppo import PPO, MetaPPO
 from longcapital.rl.order_execution.state import TradeStrategyState
 from qlib.backtest.decision import TradeDecisionWO
 from qlib.backtest.position import Position

