1 Star 0 Fork 0

张志阳/brainjs-DQN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
controller.js 13.47 KB
一键复制 编辑 原始数据 按行查看 历史
张志阳 提交于 2024-11-16 13:29 +08:00 . 初始化项目
/**
* 游戏控制器类
* 负责协调游戏、AI和UI之间的交互
*/
class GameController {
/**
* 构造函数
* 等待DOM加载完成后初始化控制器
*/
constructor() {
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', () => this.init());
} else {
this.init();
}
}
/**
* 初始化控制器
* 创建游戏、AI和存储实例,设置初始状态
*/
init() {
this.game = new Game();
this.ai = new GameAI();
this.storage = new ModelStorage();
this.isTraining = false; // 训练状态标志
this.isTesting = false; // 测试状态标志
this.isPaused = false; // 暂停状态标志
this.frameCount = 0; // 帧计数器
this.episodeCount = 0; // 训练回合计数
this.lossHistory = []; // 损失历史记录
this.scoreHistory = []; // 得分历史记录
this.lastScore = 0; // 上一次的得分
this.setupCharts(); // 设置图表
this.bindEvents(); // 绑定事件
this.updateButtonStates(); // 更新按钮状态
console.log('Controller initialized');
}
/**
* 设置图表
* 创建损失和得分的可视化图表
*/
setupCharts() {
try {
// 设置损失图表
const lossCtx = document.getElementById('lossChart').getContext('2d');
this.lossChart = new Chart(lossCtx, {
type: 'line',
data: {
labels: [],
datasets: [{
label: '训练损失',
data: [],
borderColor: 'rgb(75, 192, 192)',
tension: 0.1
}]
},
options: {
responsive: true,
maintainAspectRatio: false
}
});
// 设置得分图表
const scoreCtx = document.getElementById('scoreChart').getContext('2d');
this.scoreChart = new Chart(scoreCtx, {
type: 'line',
data: {
labels: [],
datasets: [{
label: '游戏得分',
data: [],
borderColor: 'rgb(255, 99, 132)',
tension: 0.1
}]
},
options: {
responsive: true,
maintainAspectRatio: false
}
});
console.log('Charts setup completed');
} catch (error) {
console.error('Charts setup failed:', error);
}
}
/**
* 绑定事件处理器
* 为所有按钮添加点击事件监听
*/
bindEvents() {
try {
document.getElementById('startTraining').onclick = () => this.startTraining();
document.getElementById('startTesting').onclick = () => this.startTesting();
document.getElementById('speedUpTraining').onclick = () => this.speedUpTraining();
document.getElementById('pauseTraining').onclick = () => this.pauseTraining();
document.getElementById('resumeTraining').onclick = () => this.resumeTraining();
document.getElementById('saveModel').onclick = () => this.saveModel();
document.getElementById('loadModel').onclick = () => this.showModelLoadDialog();
console.log('Events bound successfully');
} catch (error) {
console.error('Event binding failed:', error);
}
}
/**
* 更新按钮状态
* 根据当前游戏状态启用/禁用相应按钮
*/
updateButtonStates() {
const isRunning = this.isTraining || this.isTesting;
const hasModel = localStorage.getItem(`${this.storage.storageKey}_latest`) ||
localStorage.getItem(`${this.storage.storageKey}_best`);
try {
// 更新各个按钮的状态
document.getElementById('startTraining').disabled = isRunning;
document.getElementById('startTesting').disabled = isRunning || !hasModel;
document.getElementById('speedUpTraining').disabled = !isRunning;
document.getElementById('pauseTraining').disabled = !isRunning || this.isPaused;
document.getElementById('resumeTraining').disabled = !isRunning || !this.isPaused;
document.getElementById('saveModel').disabled = !this.isTraining || this.isPaused || this.isTesting;
document.getElementById('loadModel').disabled = isRunning;
// 更新按钮样式
document.querySelectorAll('.button-group button').forEach(button => {
if (button.disabled) {
button.classList.add('disabled');
} else {
button.classList.remove('disabled');
}
});
} catch (error) {
console.error('更新按钮状态失败:', error);
}
}
/**
* 游戏主循环
* 处理游戏状态更新、AI决策和训练
*/
async gameLoop() {
if ((!this.isTraining && !this.isTesting) || this.isPaused) return;
try {
const state = this.game.getState();
let action;
// 根据模式选择动作
if (this.isTesting) {
action = await this.ai.predict(state);
} else {
if (Math.random() < this.ai.epsilon) {
action = Math.floor(Math.random() * 4);
} else {
action = await this.ai.predict(state);
}
}
// 执行动作并更新游戏状态
this.game.move(action);
this.game.checkCollisions();
// 训练模式下进行学习
if (!this.isTesting) {
const reward = this.calculateReward();
const nextState = this.game.getState();
const done = this.isEpisodeEnd();
this.ai.remember(state, action, reward, nextState, done);
if (this.frameCount % 4 === 0) {
const history = await this.ai.train();
if (history) {
this.updateVisualization(history);
}
}
}
// 更新显示
this.game.draw();
this.updateProgress();
// 检查回合是否结束
if (this.isEpisodeEnd()) {
this.handleEpisodeEnd();
}
this.frameCount++;
requestAnimationFrame(() => this.gameLoop());
} catch (error) {
console.error('游戏循环错误:', error);
this.handleError(error);
}
}
/**
* 计算奖励值
* @returns {number} 当前状态的奖励值
*/
calculateReward() {
let reward = 0;
reward += this.game.score - this.lastScore; // 分数变化
this.lastScore = this.game.score;
reward += 0.1; // 存活奖励
return reward;
}
/**
* 判断当前回合是否结束
* @returns {boolean} 是否结束
*/
isEpisodeEnd() {
return this.game.score < -50 || this.frameCount > 1000;
}
/**
* 处理回合结束
* 保存模型、更新状态、重置游戏
*/
handleEpisodeEnd() {
this.episodeCount++;
this.scoreHistory.push(this.game.score);
// 保存模型
this.storage.saveModel(this.ai, 'latest');
// 如果是最高分,保存最佳模型
if (this.game.score > this.ai.bestScore) {
this.ai.bestScore = this.game.score;
const currentAI = this.ai;
this.storage.saveModel(currentAI, 'best').then(() => {
console.log('最佳模型已保存,分数:', this.game.score);
});
}
// 重置状态
this.game = new Game();
this.ai = new GameAI();
this.frameCount = 0;
this.lastScore = 0;
}
/**
* 开始训练模式
*/
async startTraining() {
if (!this.isTraining) {
this.isTraining = true;
this.isTesting = false;
this.isPaused = false;
this.updateButtonStates();
this.gameLoop();
}
}
/**
* 开始测试模式
*/
async startTesting() {
if (!this.isTesting) {
this.isTesting = true;
this.isTraining = false;
this.isPaused = false;
this.ai.epsilon = 0; // 测试模式下不使用探索
this.updateButtonStates();
this.gameLoop();
}
}
/**
* 加速训练/测试
*/
speedUpTraining() {
this.game.gameSpeed = this.game.gameSpeed >= 8 ? 1 : this.game.gameSpeed * 2;
this.updateButtonStates();
}
/**
* 暂停训练/测试
*/
pauseTraining() {
this.isPaused = true;
this.updateButtonStates();
}
/**
* 恢复训练/测试
*/
resumeTraining() {
if (this.isPaused) {
this.isPaused = false;
this.updateButtonStates();
this.gameLoop();
}
}
/**
* 保存当前模型
*/
async saveModel() {
try {
await this.storage.saveModel(this.ai, 'latest');
if (this.game.score >= this.ai.bestScore) {
await this.storage.saveModel(this.ai, 'best');
console.log('保存为最佳模型,分数:', this.game.score);
}
alert('模型保存成功!');
} catch (error) {
console.error('保存模型失败:', error);
alert('保存模型失败:' + error.message);
}
}
/**
* 显示模型加载对话框
*/
showModelLoadDialog() {
const dialog = document.createElement('div');
dialog.className = 'model-load-dialog';
dialog.innerHTML = `
<div class="dialog-content">
<h3>选择要加载的模型</h3>
<button onclick="controller.loadModel('latest')">加载最新模型</button>
<button onclick="controller.loadModel('best')">加载最佳模型</button>
<button onclick="this.parentElement.parentElement.remove()">取消</button>
</div>
`;
document.body.appendChild(dialog);
}
/**
* 加载模型
* @param {string} type - 模型类型('latest'或'best')
*/
async loadModel(type) {
try {
const modelData = await this.storage.loadModel(type);
this.ai = new GameAI();
this.ai.fromJSON(modelData);
this.game = new Game();
this.frameCount = 0;
this.lastScore = 0;
// 重置所有状态
this.isTraining = false;
this.isTesting = false;
this.isPaused = false;
this.updateButtonStates();
alert('模型加载成功!');
} catch (error) {
console.error('加载模型失败:', error);
alert('加载模型失败:' + error.message);
}
document.querySelector('.model-load-dialog').remove();
}
/**
* 更新可视化图表
* @param {Object} history - 训练历史数据
*/
updateVisualization(history) {
if (!history) return;
// 更新损失图表
const loss = history.error || 0;
this.lossHistory.push(loss);
this.lossChart.data.labels = this.lossHistory.map((_, i) => i);
this.lossChart.data.datasets[0].data = this.lossHistory;
this.lossChart.update();
// 更新得分图表
this.scoreChart.data.labels = this.scoreHistory.map((_, i) => i);
this.scoreChart.data.datasets[0].data = this.scoreHistory;
this.scoreChart.update();
}
/**
* 更新进度显示
*/
updateProgress() {
try {
document.getElementById('episodeCount').textContent = this.episodeCount;
document.getElementById('currentScore').textContent = this.game.score;
document.getElementById('bestScore').textContent = this.ai.bestScore;
document.getElementById('epsilon').textContent = this.ai.epsilon.toFixed(3);
document.getElementById('gameSpeed').textContent = this.game.gameSpeed + 'x';
document.getElementById('avgLoss').textContent =
this.lossHistory.length > 0
? (this.lossHistory.reduce((a, b) => a + b, 0) / this.lossHistory.length).toFixed(4)
: '0.0000';
document.getElementById('memorySize').textContent = this.ai.memory.length;
document.getElementById('trainingCount').textContent = this.ai.trainingCount;
} catch (error) {
console.error('更新进度显示失败:', error);
}
}
/**
* 处理错误
* @param {Error} error - 错误对象
*/
handleError(error) {
this.isPaused = true;
this.updateButtonStates();
alert('训练过程出现错误,已自动暂停。错误信息: ' + error.message);
}
}
// 创建控制器实例
window.addEventListener('load', () => {
window.controller = new GameController();
});
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhang-zhiyang/brainjs-dqn.git
git@gitee.com:zhang-zhiyang/brainjs-dqn.git
zhang-zhiyang
brainjs-dqn
brainjs-DQN
master

搜索帮助