# NOTE(review): removed Gitee web-page boilerplate ("code pull complete / force-sync
# warning") that was accidentally captured with the source and broke Python syntax.
import torch
import torch.nn as nn
class AttrProxy(object):
    """Expose a module's numbered sub-attributes through index syntax.

    Given an object carrying attributes named ``<prefix>0``, ``<prefix>1``,
    ..., this proxy lets callers write ``proxy[i]`` instead of spelling out
    ``getattr`` with a string-built name.
    """

    def __init__(self, module, prefix):
        # Hold references only — nothing is copied or validated up front.
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        # Resolve "<prefix><i>" lazily on every lookup.
        return getattr(self.module, "{}{}".format(self.prefix, i))
class gruCell(nn.Module):
    """A GRU-style cell mapping (feature, hidden) -> (output, new hidden).

    NOTE(review): the sub-module named ``resetGate`` actually produces the
    interpolation (update) coefficient ``z`` in ``forward``, while
    ``updateGate`` produces the candidate-gating coefficient ``r`` — the
    names look swapped relative to the textbook GRU. The learned behavior
    is equivalent either way, and renaming the modules would change
    state_dict keys, so the naming is preserved here; confirm before fixing.
    """

    def __init__(self, opt):
        super(gruCell, self).__init__()
        self.dimFeature = opt.dimFeature  # d: per-node feature size
        self.dimHidden = opt.dimHidden    # D: hidden state size
        # Both gates and the candidate transform consume [hidden ++ feature].
        jointDim = self.dimHidden + self.dimFeature
        self.resetGate = nn.Sequential(
            nn.Linear(jointDim, self.dimHidden),
            nn.Sigmoid(),
        )
        self.updateGate = nn.Sequential(
            nn.Linear(jointDim, self.dimHidden),
            nn.Sigmoid(),
        )
        self.transform = nn.Sequential(
            nn.Linear(jointDim, self.dimHidden),
            nn.Tanh(),
        )
        # Projects the new hidden state back to feature space.
        self.output = nn.Linear(self.dimHidden, self.dimFeature)

    def forward(self, x, hState):
        """One cell step.

        x:      [batch, dimFeature] node features
        hState: [batch, dimHidden] previous hidden state
        Returns (output [batch, dimFeature], new hidden [batch, dimHidden]).
        """
        joint = torch.cat((hState, x), 1)
        z = self.resetGate(joint)   # interpolation coefficient (see class NOTE)
        r = self.updateGate(joint)  # gates the previous hidden state
        candidate = self.transform(torch.cat((r * hState, x), 1))
        newHidden = (1 - z) * hState + z * candidate
        return self.output(newHidden), newHidden
class Propogator(nn.Module):
    """Gated propagator for the GRNN: one independent GRU cell per node.

    Hidden states are first mixed across nodes through the transfer
    matrix ``A``, then each node's cell is stepped once.
    """

    def __init__(self, opt):
        super(Propogator, self).__init__()
        self.batchSize = opt.batchSize    # b
        self.nNode = opt.nNode            # n
        self.dimFeature = opt.dimFeature  # d
        self.dimHidden = opt.dimHidden    # D
        # Register one cell per node so their parameters are tracked;
        # AttrProxy gives index-style access to "gruCell_<i>".
        for idx in range(self.nNode):
            self.add_module("gruCell_{}".format(idx), gruCell(opt))
        self.cells = AttrProxy(self, "gruCell_")

    def forward(self, x, hState, A):
        """Step every node's cell once.

        x:      [b, n, d] node features for one time step
        hState: [b, D, n] hidden states per node
        A:      transfer matrix; NOTE(review): torch.bmm requires a 3-D
                batched ``A`` — the [nNode, nNode] shape documented by the
                caller would fail here; confirm what is actually passed.
        Returns (outputs [b, n, d], new hidden states [b, D, n]).
        """
        # NOTE(review): buffers are hard-coded to float64 on CPU; the model
        # is presumably run entirely in double precision — confirm.
        O = torch.zeros(self.batchSize, self.nNode, self.dimFeature).double()
        H = torch.zeros(self.batchSize, self.dimHidden, self.nNode).double()
        # Mix hidden states across nodes: column n of S aggregates neighbors.
        S = torch.bmm(hState, A)
        for node in range(self.nNode):
            O[:, node, :], H[:, :, node] = self.cells[node](x[:, node, :], S[:, :, node])
        return O, H
class GRNN(nn.Module):
    """Graph RNN: unrolls a Propogator over ``opt.truncate`` time steps."""

    def __init__(self, opt):
        super(GRNN, self).__init__()
        self.batchSize = opt.batchSize    # b
        self.nNode = opt.nNode            # n
        self.dimFeature = opt.dimFeature  # d
        self.dimHidden = opt.dimHidden    # D
        self.interval = opt.truncate      # T: truncation (unroll) length
        self.propogator = Propogator(opt)

    def forward(self, x, hState, A):
        """
        x: input node features [batchSize, interval, nNode, dimFeature]
        hState: hidden state [batchSize, dimHidden, nNode]
        A: transfer matrix [nNode, nNode]
        """
        # NOTE(review): the output buffer is hard-coded to float64 — the
        # model presumably runs entirely in double precision; confirm.
        O = torch.zeros(
            self.batchSize, self.interval, self.nNode, self.dimFeature
        ).double()
        # Unroll over time, threading the hidden state through each step.
        for step in range(self.interval):
            O[:, step, :, :], hState = self.propogator(x[:, step, :, :], hState, A)
        return O, hState
# NOTE(review): removed Gitee content-moderation boilerplate that was
# accidentally captured with the source and broke Python syntax.