Ai
1 Star 0 Fork 0

ysl/DQN-tensorflow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 2.07 KB
一键复制 编辑 原始数据 按行查看 历史
from __future__ import print_function
import random
import tensorflow as tf
from dqn.agent import Agent
from dqn.environment import GymEnvironment, SimpleGymEnvironment
from config import get_config
flags = tf.app.flags

# --- Model selection -------------------------------------------------------
flags.DEFINE_string(
    'model',
    'm1',
    'Type of model')
flags.DEFINE_boolean(
    'dueling',
    False,
    'Whether to use dueling deep q-network')
flags.DEFINE_boolean(
    'double_q',
    False,
    'Whether to use double q-learning')

# --- Gym environment -------------------------------------------------------
flags.DEFINE_string(
    'env_name',
    'Breakout-v0',
    'The name of gym environment to use')
flags.DEFINE_integer(
    'action_repeat',
    4,
    'The number of action to be repeated')

# --- Runtime / miscellaneous -----------------------------------------------
flags.DEFINE_boolean(
    'use_gpu',
    True,
    'Whether to use gpu or not')
flags.DEFINE_string(
    'gpu_fraction',
    '1/1',
    'idx / # of gpu fraction e.g. 1/3, 2/3, 3/3')
flags.DEFINE_boolean(
    'display',
    False,
    'Whether to do display the game screen or not')
flags.DEFINE_boolean(
    'is_train',
    True,
    'Whether to do training or testing')
flags.DEFINE_integer(
    'random_seed',
    123,
    'Value of random seed')

# Handle through which the rest of the module reads the flag values.
FLAGS = flags.FLAGS
# Seed Python's and TensorFlow's random number generators from the same flag.
random.seed(FLAGS.random_seed)
tf.set_random_seed(FLAGS.random_seed)

# Reject an empty --gpu_fraction up front, at module load time.
if FLAGS.gpu_fraction == '':
    raise ValueError("--gpu_fraction should be defined")
def calc_gpu_fraction(fraction_string):
    """Convert an ``idx/num`` string into a GPU memory fraction.

    The fraction is computed as ``1 / (num - idx + 1)``; for example
    ``'1/3'`` gives 1/3, ``'2/3'`` gives 1/2 and ``'3/3'`` gives 1.
    The computed value is also printed.

    Args:
        fraction_string: Text of the form ``'<idx>/<num>'`` where both parts
            can be parsed by ``float``.

    Returns:
        The computed fraction as a float.

    Raises:
        ValueError: If the string does not contain exactly one ``/``, if a
            part is not numeric, or if ``num - idx + 1`` is not positive
            (which would otherwise divide by zero or yield a negative
            fraction).
    """
    parts = fraction_string.split('/')
    if len(parts) != 2:
        raise ValueError(
            "--gpu_fraction should look like 'idx/num' (e.g. 1/3), got %r"
            % (fraction_string,))
    idx, num = float(parts[0]), float(parts[1])

    denominator = num - idx + 1
    # `not (> 0)` rather than `<= 0` so that NaN is rejected as well.
    if not denominator > 0:
        raise ValueError(
            "--gpu_fraction %r is out of range: idx must be smaller than "
            "num + 1" % (fraction_string,))

    fraction = 1 / denominator
    print(" [*] GPU : %.4f" % fraction)
    return fraction
def main(_):
    """Create the environment and agent in a TF session, then train or play.

    The positional argument is ignored.

    Raises:
        Exception: If --use_gpu is set but TensorFlow reports no available
            GPU.
    """
    memory_fraction = calc_gpu_fraction(FLAGS.gpu_fraction)
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=memory_fraction))

    with tf.Session(config=session_config) as sess:
        # Use the raw flags when get_config returns nothing truthy.
        config = get_config(FLAGS) or FLAGS

        env_class = (SimpleGymEnvironment if config.env_type == 'simple'
                     else GymEnvironment)
        env = env_class(config)

        # Query availability first, exactly as before, regardless of the flag.
        gpu_available = tf.test.is_gpu_available()
        if FLAGS.use_gpu:
            if not gpu_available:
                raise Exception(
                    "use_gpu flag is true when no GPUs are available")
        else:
            # Without a GPU the config is switched to the NHWC layout.
            config.cnn_format = 'NHWC'

        agent = Agent(config, env, sess)

        run_agent = agent.train if FLAGS.is_train else agent.play
        run_agent()
if __name__ == '__main__':
    # Pass the entry point explicitly; tf.app.run would otherwise look up
    # `main` on the __main__ module, which is this same function.
    tf.app.run(main=main)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/publiagent123/DQN-tensorflow.git
git@gitee.com:publiagent123/DQN-tensorflow.git
publiagent123
DQN-tensorflow
DQN-tensorflow
master

搜索帮助