
chenglijie1015/ICM-PPO-implementation

ActorCritic.py 1.72 KB
Stepan-Makarenko committed on 2020-05-22 02:26 +08:00 · GPU
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, n_latent_var, activation=nn.Tanh(), device='cpu'):
        super(ActorCritic, self).__init__()
        self.device = device

        # Shared two-layer MLP trunk used by both the actor and critic heads
        self.body = nn.Sequential(
            nn.Linear(state_dim, n_latent_var),
            activation,
            nn.Linear(n_latent_var, n_latent_var),
            activation
        ).to(self.device)

        # Actor head: outputs a categorical distribution over discrete actions
        self.action_layer = nn.Sequential(
            self.body,
            nn.Linear(n_latent_var, action_dim),
            nn.Softmax(dim=-1)
        ).to(self.device)

        # Critic head: outputs a scalar state-value estimate
        self.value_layer = nn.Sequential(
            self.body,
            nn.Linear(n_latent_var, 1)
        ).to(self.device)

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        # Receives a numpy array, samples an action from the current policy,
        # and stores the transition (state, action, log-prob) in the rollout memory
        state = torch.from_numpy(state).float().to(self.device)
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)
        action = dist.sample()

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))

        # Return a numpy array so the caller can pass it directly to the environment
        return action.cpu().numpy()

    def evaluate(self, state, action):
        # Re-evaluates stored states/actions for the PPO update: log-probs under
        # the current policy, state values from the critic, and distribution entropy
        action_probs = self.action_layer(state.to(self.device))
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.value_layer(state.to(self.device))

        return action_logprobs, torch.squeeze(state_value), dist_entropy
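
A minimal usage sketch (not part of the repository file): it assumes a simple Memory container with states, actions, and logprobs lists, mirroring the attributes that act() fills, plus placeholder dimensions; environment interaction is stubbed out rather than using a real Gym environment.

import numpy as np
import torch

class Memory:
    # Hypothetical rollout buffer matching what act() expects;
    # the actual project may define its own version elsewhere.
    def __init__(self):
        self.states = []
        self.actions = []
        self.logprobs = []

policy = ActorCritic(state_dim=4, action_dim=2, n_latent_var=64, device='cpu')
memory = Memory()

state = np.zeros(4, dtype=np.float32)   # stand-in for env.reset()
action = policy.act(state, memory)      # numpy action suitable for env.step(action)

# During a PPO update, the stored rollout is re-evaluated in a batch:
old_states = torch.stack(memory.states)
old_actions = torch.stack(memory.actions)
logprobs, state_values, dist_entropy = policy.evaluate(old_states, old_actions)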