1 Star 0 Fork 16

aaron/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
optimize_sac.py 3.59 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 +08:00 . 1.6
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
from stock_env import StockEnv
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
import torch as th
import numpy as np
import optuna
from stable_baselines3.common.evaluation import evaluate_policy
# Training budget: handed to ``model.learn`` as ``total_timesteps`` for each trial.
LEARN_TIMES = 10_000

# Output locations. Neither is referenced by the active code visible in this
# script; presumably a model-checkpoint prefix and a TensorBoard log
# directory -- confirm before relying on them.
MODEL_PATH = "./model/sac"
TB_LOG_PATH = "../tb_log"
def make_env(rank, seed=0):
    """
    Build a zero-argument environment factory for one vectorized-env worker.

    :param rank: (int) index of the subprocess; added to ``seed`` so that each
        worker seeds its environment with a different value
    :param seed: (int) base seed for the RNG
    :return: (callable) function that creates a ``Monitor``-wrapped ``StockEnv``
        built from ``range(2011, 2021)`` and seeded with ``seed + rank``
    """
    # Called once here, while the factory is being built, not inside the
    # factory itself.
    set_random_seed(seed)

    def _build():
        # presumably the years 2011-2020 used for training -- confirm against StockEnv
        wrapped = Monitor(StockEnv(range(2011, 2021)))
        wrapped.seed(seed + rank)
        return wrapped

    return _build
def optimize_params(trial, actions):
    """
    Sample one SAC hyperparameter configuration from an Optuna trial.

    :param trial: (optuna.Trial) trial object every value is drawn from
    :param actions: (int) size of the action space; currently unused and kept
        only so that existing callers keep working
    :return: (dict) keyword arguments for the ``SAC`` constructor: ``gamma``,
        ``batch_size``, ``buffer_size``, ``learning_starts``, ``learning_rate``,
        ``tau``, ``train_freq`` and ``policy_kwargs`` (holding ``activation_fn``
        and ``net_arch``)
    """
    layer_sizes = [32, 64, 128, 256, 512, 768, 1024, 1280, 1536, 1792, 2048]

    # Number of hidden layers, then one categorical width per layer. Each
    # layer's parameter is named after its index ("0", "1", ...).
    na_num = trial.suggest_int('na_num', 4, 48)
    net_arch = [
        trial.suggest_categorical(str(i), layer_sizes) for i in range(na_num)
    ]

    # Candidate activations. Every entry must be constructible with no
    # arguments and must keep the feature dimension unchanged, because the
    # policy MLP places one instance between consecutive Linear layers.
    # GLU (halves the last dimension) and MultiheadAttention (requires
    # constructor arguments) do not satisfy this and are deliberately absent:
    # trials that drew them could only fail.
    all_fn = [
        th.nn.ReLU,
        th.nn.RReLU,
        th.nn.Hardtanh,
        th.nn.ReLU6,
        th.nn.Sigmoid,
        th.nn.Hardsigmoid,
        th.nn.Tanh,
        th.nn.SiLU,
        th.nn.Mish,
        th.nn.Hardswish,
        th.nn.ELU,
        th.nn.CELU,
        th.nn.SELU,
        th.nn.GELU,
        th.nn.Hardshrink,
        th.nn.LeakyReLU,
        th.nn.LogSigmoid,
        th.nn.Softplus,
        th.nn.Softshrink,
        th.nn.PReLU,
        th.nn.Softsign,
        th.nn.Tanhshrink,
        th.nn.Softmin,
        th.nn.Softmax,
        th.nn.LogSoftmax,
    ]
    fn_index = trial.suggest_int('fn_index', 0, len(all_fn) - 1)

    # suggest_float(..., log=True) is the non-deprecated spelling of
    # suggest_loguniform and samples from the same distribution.
    return {
        'gamma': trial.suggest_float('gamma', 0.8, 0.99, log=True),
        'batch_size': trial.suggest_categorical(
            "batch_size", [16, 32, 64, 128, 256, 512, 1024, 2048]
        ),
        'buffer_size': trial.suggest_categorical(
            "buffer_size", [500000, 1000000, 2000000]
        ),
        'learning_starts': trial.suggest_categorical(
            "learning_starts", [1, 10, 100, 200, 2000]
        ),
        'learning_rate': trial.suggest_float(
            'learning_rate', 1e-5, 1e-4, log=True
        ),
        'tau': trial.suggest_categorical(
            "tau", [0.001, 0.005, 0.01, 0.02, 0.05, 0.08, 0.1, 0.2]
        ),
        'train_freq': trial.suggest_categorical("train_freq", [1, 4, 8, 16]),
        'policy_kwargs': dict(
            activation_fn=all_fn[fn_index],
            net_arch=net_arch,
        ),
    }
def _close_env(env):
    """Close ``env`` if it was created; report, but never raise, on failure."""
    if env is None:
        return
    try:
        env.close()
    except Exception as close_error:
        # Cleanup must not replace the trial's result with an exception.
        print("env close failed:", close_error)


def optimize_agent(trial, num_cpu=64):
    """
    Optuna objective: train SAC with sampled hyperparameters and score it.

    A ``SubprocVecEnv`` with ``num_cpu`` workers is built from ``make_env``,
    a SAC model is trained on it for ``LEARN_TIMES`` timesteps using the
    parameters drawn by ``optimize_params``, and the model is then evaluated
    with ``evaluate_policy`` on ``Monitor(StockEnv([2022]))``. The trained
    model is not saved.

    :param trial: (optuna.Trial) trial supplying the hyperparameters
    :param num_cpu: (int) number of environment subprocesses; defaults to 64
    :return: (float) mean evaluation reward, or ``-10000`` if anything raised
        during environment creation, training or evaluation
    """
    import traceback  # only needed on the failure path

    train_env = None
    eval_env = None
    try:
        # Create the vectorized environment
        train_env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
        model_params = optimize_params(trial, train_env.action_space.shape[-1])
        model = SAC('MlpPolicy', train_env, **model_params)
        model.learn(total_timesteps=LEARN_TIMES)
        eval_env = Monitor(StockEnv([2022]))
        mean_reward, std_reward = evaluate_policy(model, eval_env)
        print("mean_reward", mean_reward, std_reward)
        return mean_reward
    except Exception as e:
        # Best-effort: a failed configuration gets a very low score so the
        # study keeps running. The traceback is printed so that genuine bugs
        # are not hidden behind a one-line message.
        print(e)
        traceback.print_exc()
        return -10000
    finally:
        # Always release the worker subprocesses and the evaluation env;
        # otherwise every trial would leave its workers behind.
        _close_env(eval_env)
        _close_env(train_env)
def _run_study():
    """Run a 100-trial maximising Optuna study and print the best parameters."""
    sac_study = optuna.create_study(direction='maximize')
    sac_study.optimize(
        optimize_agent,
        n_trials=100,
        gc_after_trial=True,
    )
    print(sac_study.best_params)


if __name__ == '__main__':
    _run_study()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/alpha_aaron/stock_robot.git
git@gitee.com:alpha_aaron/stock_robot.git
alpha_aaron
stock_robot
stock_robot
master

搜索帮助