1 Star 0 Fork 16

aaron/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
stock_env.py 3.72 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 +08:00 . 1.6
import gym
import math
from gym import spaces
import numpy as np
import cvt_helper as cvt
from stock_info import StockSceneInfo
from stock_info import StockLearnInfo
import const as cst
import random
from data_center import DataCenter
class StockEnv(gym.Env):
    """A single-stock high/low price prediction environment for OpenAI Gym.

    Each episode samples a random (stock, year) pair from ``DataCenter`` and
    walks forward through its rows one step at a time. The observation is
    built from the last ``cst.HISTORY_DATA_COUNT`` rows. The action is a pair
    of values ``[high, low]`` in ``[0, 1]`` that is converted (via
    ``cvt.rtov``) into predicted High/Low prices for the next row. The reward
    shrinks as the relative error of those predictions grows.

    Uses the legacy Gym API: ``reset()`` returns only the observation and
    ``step()`` returns the 4-tuple ``(observation, reward, done, info)``.
    """

    metadata = {'render.modes': ['human']}

    # Upper bound on random (stock, year) draws in reset(). Without it, reset()
    # would loop forever when no pair has enough rows to build a frame.
    max_reset_attempts = 10000

    def __init__(self, all_year):
        """Create the environment and load the stock universe.

        Args:
            all_year: Sequence of years to sample episodes from. Each element is
                passed unchanged to ``DataCenter.get_stock_info`` (presumably
                ints or year strings -- TODO confirm against DataCenter).
        """
        super().__init__()
        # Action: two values in [0, 1] -> [predicted next High, predicted next Low].
        # _take_action converts them to prices relative to the next row's Standard.
        self.action_space = spaces.Box(
            low=np.zeros(2, dtype=np.float32),
            high=np.ones(2, dtype=np.float32),
            shape=(2,),
            dtype=np.float32,
        )
        # Observation: a flat vector sized HISTORY_DATA_COUNT rows x STOCK_FIELD_COUNT fields.
        self.observation_space = spaces.Box(
            low=np.float32(0),
            high=np.float32(1e7),
            shape=(cst.STOCK_FIELD_COUNT * cst.HISTORY_DATA_COUNT,),
            dtype=np.float32,
        )
        self.data_center = DataCenter()
        self.all_stock = self.data_center.query_all_stock()
        self.all_year = all_year
        self.current_step = cst.HISTORY_DATA_COUNT
        # Defined up front so that _is_done()/step() fail cleanly instead of
        # raising AttributeError if they are used before the first reset().
        self.current_stock_data = []
        self.frame_stock_data = None

    def step(self, action):
        """Score ``action`` against the current frame, then advance one row.

        Args:
            action: Sequence whose first two items are the high/low values
                (see ``_take_action``).

        Returns:
            ``(observation, reward, done, info)``. Once the episode is done the
            observation is an all-zero float32 vector, because no further frame
            can be built; ``info`` is always an empty dict.
        """
        reward = self._take_action(action)
        self.current_step += 1
        done = self._is_done()
        if done:
            observation = np.zeros(
                shape=(cst.HISTORY_DATA_COUNT * cst.STOCK_FIELD_COUNT,),
                dtype=np.float32,
            )
        else:
            self.frame_stock_data = self._get_frame_data()
            observation = cvt.get_obs(self.frame_stock_data.SceneInfo)
        return observation, reward, done, {}

    def reset(self):
        """Start a new episode on a random (stock, year) pair.

        Pairs are re-drawn until one has enough rows for at least one frame,
        i.e. until ``_is_done()`` is False at the initial step.

        Returns:
            The first observation, built by ``cvt.get_obs`` from the first frame.

        Raises:
            ValueError: If there are no stocks or no years to sample from.
            RuntimeError: If no usable pair is found within
                ``max_reset_attempts`` draws.
        """
        # len() rather than truthiness: the containers' concrete types are not
        # known here (e.g. numpy arrays have ambiguous truth values).
        if len(self.all_stock) == 0 or len(self.all_year) == 0:
            raise ValueError('StockEnv.reset: no stocks or no years to sample from')
        self.current_step = cst.HISTORY_DATA_COUNT
        self.current_stock_data = []
        attempts = 0
        while self._is_done():
            if attempts >= self.max_reset_attempts:
                raise RuntimeError(
                    f'StockEnv.reset: no (stock, year) pair with more than '
                    f'{cst.HISTORY_DATA_COUNT + 2} rows found in '
                    f'{self.max_reset_attempts} attempts'
                )
            attempts += 1
            # Draw the stock first, then the year (same RNG order as before).
            stock_index = random.randint(0, len(self.all_stock) - 1)
            year_index = random.randint(0, len(self.all_year) - 1)
            code = self.all_stock[stock_index]
            year = self.all_year[year_index]
            self.current_stock_data = self.data_center.get_stock_info(code, year)
        self.frame_stock_data = self._get_frame_data()
        return cvt.get_obs(self.frame_stock_data.SceneInfo)

    def _get_frame_data(self):
        """Build the ``StockLearnInfo`` frame for ``current_step``.

        ``SceneInfo`` receives the ``HISTORY_DATA_COUNT`` rows that end just
        before ``current_step``. ``NextHigh``, ``NextLow`` and ``NextStandard``
        are copied from the row at ``current_step``, which is the prediction
        target.
        """
        learn_data = StockLearnInfo()
        first = self.current_step - cst.HISTORY_DATA_COUNT
        for idx in range(first, self.current_step):
            scene: StockSceneInfo = self.current_stock_data[idx]
            learn_data.SceneInfo.append(scene)
        target_scene: StockSceneInfo = self.current_stock_data[self.current_step]
        learn_data.NextHigh = target_scene.High
        learn_data.NextLow = target_scene.Low
        learn_data.NextStandard = target_scene.Standard
        return learn_data

    def render(self, mode='human'):
        """Rendering is not implemented; this is a no-op."""
        pass

    def close(self):
        """Nothing to release; this is a no-op."""
        pass

    def _is_done(self):
        """Return True when no further frame can be built for this episode.

        The cut-off is ``len - 2``, so the final two rows of the data are never
        used as a prediction target. NOTE(review): confirm this margin is
        intentional; ``_get_frame_data`` itself would only need ``len - 1``.
        """
        return self.current_step >= len(self.current_stock_data) - 2

    def _take_action(self, action):
        """Return the reward for a [high, low] prediction on the current frame.

        ``action[0]`` and ``action[1]`` are converted to prices by ``cvt.rtov``
        relative to the frame's ``NextStandard`` (presumably a ratio -> price
        mapping; see cvt_helper). The reward is ``2 * cvt.PRICE_CHANGE_LIMIT``
        minus the sum of the relative errors of the predicted High and Low.

        NOTE(review): this divides by ``NextHigh`` and ``NextLow``. A zero price
        in the data would raise ZeroDivisionError (or produce inf for numpy
        scalars). Confirm that such rows cannot occur.
        """
        predicted_high = action[0]
        predicted_low = action[1]
        frame: StockLearnInfo = self.frame_stock_data
        high_price = cvt.rtov(predicted_high, frame.NextStandard)
        low_price = cvt.rtov(predicted_low, frame.NextStandard)
        high_error = math.fabs(high_price - frame.NextHigh) / frame.NextHigh
        low_error = math.fabs(low_price - frame.NextLow) / frame.NextLow
        return 2 * cvt.PRICE_CHANGE_LIMIT - (high_error + low_error)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/alpha_aaron/stock_robot.git
git@gitee.com:alpha_aaron/stock_robot.git
alpha_aaron
stock_robot
stock_robot
master

搜索帮助