代码拉取完成,页面将自动刷新
同步操作将从 lightning-trader/stock_robot 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import gym
import math
from gym import spaces
import numpy as np
import cvt_helper as cvt
from stock_info import StockSceneInfo
from stock_info import StockLearnInfo
import const as cst
import random
from data_center import DataCenter
class StockEnv(gym.Env):
    """Gym environment that rewards accurate next-bar high/low predictions.

    Each episode walks through the data of one randomly chosen
    (stock code, year) pair.  At every step the agent sees a window of
    ``cst.HISTORY_DATA_COUNT`` consecutive scene records and emits two
    values in ``[0, 1]``; ``_take_action`` converts them to prices with
    ``cvt.rtov`` and scores them against the ``High`` / ``Low`` of the
    record that follows the window.

    Args:
        all_year: Indexable collection of years; one entry is picked at
            random on every ``reset()`` and passed to
            ``DataCenter.get_stock_info``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, all_year):
        super().__init__()
        # Two continuous outputs in [0, 1]: index 0 is the predicted high,
        # index 1 the predicted low (see _take_action).
        self.action_space = spaces.Box(
            low=np.array([np.float32(0), np.float32(0)]),
            high=np.array([np.float32(1), np.float32(1)]),
            shape=(2,),
            dtype=np.float32,
        )
        # Flat vector: STOCK_FIELD_COUNT values for each of the
        # HISTORY_DATA_COUNT records in the current window.
        self.observation_space = spaces.Box(
            low=np.float32(0),
            high=np.float32(1e7),
            shape=(cst.STOCK_FIELD_COUNT * cst.HISTORY_DATA_COUNT,),
            dtype=np.float32,
        )
        self.data_center = DataCenter()
        self.all_stock = self.data_center.query_all_stock()
        self.current_step = cst.HISTORY_DATA_COUNT
        self.all_year = all_year
        # Populated by reset(); defined here so the instance always carries
        # these attributes, even before the first episode starts.
        self.current_stock_data = []
        self.frame_stock_data = None

    def step(self, action):
        """Score ``action`` against the current frame and advance one record.

        Args:
            action: Indexable pair ``(high, low)`` as described by
                ``action_space``.

        Returns:
            Tuple ``(observation, reward, done, info)``.  When the episode
            is done the observation is an all-zero vector of the
            observation-space shape and the frame is not advanced; ``info``
            is always an empty dict.
        """
        reward = self._take_action(action)
        self.current_step += 1
        done = self._is_done()
        if done:
            observation = np.zeros(
                shape=(cst.HISTORY_DATA_COUNT * cst.STOCK_FIELD_COUNT,),
                dtype=np.float32,
            )
        else:
            self.frame_stock_data = self._get_frame_data()
            observation = cvt.get_obs(self.frame_stock_data.SceneInfo)
        return observation, reward, done, {}

    def reset(self):
        """Start a new episode on a random (stock, year) data set.

        Data sets are redrawn until one is long enough that the episode is
        not already finished at the initial step.

        Returns:
            The observation for the first window of the chosen data set.

        Raises:
            ValueError: If there are no stocks or no years to sample from.
        """
        self.current_step = cst.HISTORY_DATA_COUNT
        self.current_stock_data = []
        if len(self.all_stock) == 0 or len(self.all_year) == 0:
            raise ValueError(
                "cannot reset StockEnv: no stocks or no years to sample from"
            )
        # With the empty list assigned above _is_done() is expected to hold
        # on entry, so at least one data set is drawn.
        # NOTE(review): this loops forever if no (stock, year) pair has
        # enough records - confirm the data guarantees one exists.
        while self._is_done():
            code = self.all_stock[random.randrange(len(self.all_stock))]
            year = self.all_year[random.randrange(len(self.all_year))]
            self.current_stock_data = self.data_center.get_stock_info(code, year)
        self.frame_stock_data = self._get_frame_data()
        return cvt.get_obs(self.frame_stock_data.SceneInfo)

    def _get_frame_data(self):
        """Build the learning frame for ``self.current_step``.

        Returns:
            A ``StockLearnInfo`` whose ``SceneInfo`` holds the
            ``HISTORY_DATA_COUNT`` records preceding ``current_step`` and
            whose ``NextHigh`` / ``NextLow`` / ``NextStandard`` come from
            the record at ``current_step`` itself.
        """
        learn_data = StockLearnInfo()
        window_start = self.current_step - cst.HISTORY_DATA_COUNT
        for offset in range(cst.HISTORY_DATA_COUNT):
            scene: StockSceneInfo = self.current_stock_data[window_start + offset]
            learn_data.SceneInfo.append(scene)
        upcoming: StockSceneInfo = self.current_stock_data[self.current_step]
        learn_data.NextHigh = upcoming.High
        learn_data.NextLow = upcoming.Low
        learn_data.NextStandard = upcoming.Standard
        return learn_data

    def render(self, mode='human'):
        """Rendering is not implemented; this is a no-op."""
        pass

    def close(self):
        """Nothing to release; this is a no-op."""
        pass

    def _is_done(self):
        """Return True once ``current_step`` is within two records of the end."""
        return self.current_step >= len(self.current_stock_data) - 2

    def _take_action(self, action):
        """Execute the given action and return its reward.

        ``action[0]`` and ``action[1]`` are converted to a high and a low
        price via ``cvt.rtov`` relative to the frame's ``NextStandard``
        (presumably a ratio-to-value mapping - confirm in cvt_helper).  The
        reward is ``2 * cvt.PRICE_CHANGE_LIMIT`` minus the sum of the two
        relative absolute errors against ``NextHigh`` and ``NextLow``.

        NOTE(review): a ``NextHigh`` or ``NextLow`` of zero would raise
        ZeroDivisionError - confirm the data never contains zero prices.
        """
        high = action[0]
        low = action[1]
        stock: StockLearnInfo = self.frame_stock_data
        high_price = cvt.rtov(high, stock.NextStandard)
        low_price = cvt.rtov(low, stock.NextStandard)
        high_variance = math.fabs(high_price - stock.NextHigh) / stock.NextHigh
        low_variance = math.fabs(low_price - stock.NextLow) / stock.NextLow
        return 2 * cvt.PRICE_CHANGE_LIMIT - (high_variance + low_variance)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。