1 Star 2 Fork 0

zweien/graph-pde

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
nn_conv.py 11.15 KB
一键复制 编辑 原始数据 按行查看 历史
Zongyi Li 提交于 2020-01-05 13:59 +08:00 . Add files via upload
import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset, uniform
# Default compute device, chosen once at import time: the current CUDA device
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class NNConv(MessagePassing):
    r"""Diagonal-weight variant of the continuous kernel-based convolutional
    operator from the `"Neural Message Passing for Quantum Chemistry"
    <https://arxiv.org/abs/1704.01212>`_ paper, also known as the
    edge-conditioned convolution from the `"Dynamic Edge-Conditioned Filters
    in Convolutional Neural Networks on Graphs"
    <https://arxiv.org/abs/1704.02901>`_ paper.

    The paper's operator multiplies each neighbour feature by a full
    :obj:`[in_channels, out_channels]` matrix produced by an edge network.
    This layer restricts that matrix to be diagonal: the edge network
    produces one weight per channel and the message is an element-wise
    product:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),

    where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an
    MLP, and :math:`\odot` is the element-wise product. Because the per-edge
    weight is diagonal, :obj:`in_channels` must equal :obj:`out_channels`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample. Must equal
            :obj:`in_channels`.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}`
            that maps edge features :obj:`edge_attr` of shape
            :obj:`[-1, num_edge_features]` to shape :obj:`[-1, in_channels]`
            (one weight per channel), *e.g.*, defined by
            :class:`torch.nn.Sequential`.
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: From :meth:`message`, if :obj:`in_channels` differs from
            :obj:`out_channels` or the output of :obj:`nn` does not have the
            same shape as the gathered neighbour features.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nn,
                 aggr='add',
                 root_weight=True,
                 bias=True,
                 **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        # Kept for compatibility; the base class also stores `aggr`.
        self.aggr = aggr

        # Root weight Theta applied to each node's own features in `update`.
        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the edge network, root weight and bias."""
        reset(self.nn)
        uniform(self.in_channels, self.root)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        1-D node features / edge attributes are promoted to a single column.
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
        return self.propagate(edge_index, x=x, pseudo=pseudo)

    def message(self, x_j, pseudo):
        """Scale each neighbour feature channel-wise by the edge network output."""
        weight = self.nn(pseudo)
        # A diagonal [C, C] weight only makes sense when in == out == C; check
        # explicitly instead of letting a reshape silently scramble values.
        if self.in_channels != self.out_channels or weight.shape != x_j.shape:
            raise ValueError(
                '{} expects in_channels == out_channels and an nn output of '
                'shape {}, got in_channels={}, out_channels={}, nn output '
                'shape {}'.format(self.__class__.__name__, tuple(x_j.shape),
                                  self.in_channels, self.out_channels,
                                  tuple(weight.shape)))
        # x_j @ diag(weight) is just an element-wise product; computing it
        # directly avoids materialising a [num_edges, C, C] tensor.
        return x_j * weight

    def update(self, aggr_out, x):
        """Add the transformed root features and the bias to the aggregate."""
        if self.root is not None:
            aggr_out = aggr_out + torch.mm(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
class NNConv_Gaussian(MessagePassing):
    r"""Message-passing layer with a diagonal, Gaussian edge kernel.

    A variant of the continuous kernel-based (edge-conditioned) convolution
    from `"Neural Message Passing for Quantum Chemistry"
    <https://arxiv.org/abs/1704.01212>`_. Instead of producing the per-edge
    weight with an MLP applied to the edge features, the weight is a
    per-channel Gaussian of edge-feature column 0, scaled by a factor built
    from columns 1 and 2. For an edge :math:`j \to i` with edge features
    :math:`(r, p, q, \dots)`:

    .. math::
        \mathbf{w}_{j \to i} = \frac{1}{\sqrt{|p\,q|}}
        \exp\!\left(-\frac{r^2}{\boldsymbol{\sigma}^2}\right),
        \qquad \boldsymbol{\sigma} = h_{\mathbf{\Theta}}(1),

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot \mathbf{w}_{j \to i},

    where :math:`h_{\mathbf{\Theta}}` is evaluated on the constant input
    :obj:`torch.ones(1)` and yields one learnable width per channel. The
    division is element-wise, and :math:`\odot` is the element-wise product.
    Because the per-edge weight is diagonal, :obj:`in_channels` must equal
    :obj:`out_channels`, and both must equal the number of widths produced
    by :obj:`nn`.

    Note: if :math:`p\,q = 0` the scale is infinite, and a zero width gives
    NaN for :math:`r = 0`. No clamping is applied.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample. Must equal
            :obj:`in_channels`.
        nn (torch.nn.Module): A network :math:`h_{\mathbf{\Theta}}` that maps
            the tensor :obj:`torch.ones(1)` to :obj:`in_channels` Gaussian
            widths (any shape with that many elements).
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Edge attributes passed to :meth:`forward` must have at least three
    columns (columns 0, 1 and 2 are read in :meth:`message`).

    Raises:
        ValueError: From :meth:`message`, if :obj:`in_channels` differs from
            :obj:`out_channels` or the number of widths produced by
            :obj:`nn` does not match the node feature size.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nn,
                 aggr='add',
                 root_weight=True,
                 bias=True,
                 **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        # Kept for compatibility; the base class also stores `aggr`.
        self.aggr = aggr

        # Root weight Theta applied to each node's own features in `update`.
        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the width network, root weight and bias."""
        reset(self.nn)
        uniform(self.in_channels, self.root)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        1-D node features / edge attributes are promoted to a single column.
        """
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
        return self.propagate(edge_index, x=x, pseudo=pseudo)

    def message(self, x_j, pseudo):
        """Scale neighbour features by the Gaussian edge kernel."""
        # Create the constant input on the edge data's device rather than a
        # module-global device, so the layer works wherever it is placed.
        one = torch.ones(1, device=pseudo.device)
        # Per-edge scale 1 / sqrt(|p * q|) from edge-feature columns 1 and 2.
        scale = 1 / torch.sqrt(torch.abs(pseudo[:, 1] * pseudo[:, 2]))
        # Gaussian in edge-feature column 0 with per-channel widths:
        # [E, 1] / [1, C] -> [E, C].
        widths_sq = (self.nn(one) ** 2).view(1, -1)
        gauss = torch.exp(-1 * (pseudo[:, 0] ** 2).view(-1, 1) / widths_sq)
        # Broadcast the scale over channels (previously a hard-coded
        # repeat(1, 64), which only worked for 64 channels).
        weight = scale.view(-1, 1) * gauss
        if self.in_channels != self.out_channels or weight.shape != x_j.shape:
            raise ValueError(
                '{} expects in_channels == out_channels and a kernel of '
                'shape {}, got in_channels={}, out_channels={}, kernel '
                'shape {}'.format(self.__class__.__name__, tuple(x_j.shape),
                                  self.in_channels, self.out_channels,
                                  tuple(weight.shape)))
        # Multiplying by diag(weight) is an element-wise product; doing it
        # directly avoids materialising a [num_edges, C, C] tensor.
        return x_j * weight

    def update(self, aggr_out, x):
        """Add the transformed root features and the bias to the aggregate."""
        if self.root is not None:
            aggr_out = aggr_out + torch.mm(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
class NNConv_old(MessagePassing):
    r"""Edge-conditioned convolution with a full per-edge weight matrix.

    This is the continuous kernel-based convolutional operator from
    `"Neural Message Passing for Quantum Chemistry"
    <https://arxiv.org/abs/1704.01212>`_, also known as the edge-conditioned
    convolution from `"Dynamic Edge-Conditioned Filters in Convolutional
    Neural Networks on Graphs" <https://arxiv.org/abs/1704.02901>`_ (see
    :class:`torch_geometric.nn.conv.ECConv`):

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot
        h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}),

    where :math:`h_{\mathbf{\Theta}}` is a neural network, *i.e.* an MLP,
    whose output for each edge is reshaped into an
    :obj:`[in_channels, out_channels]` matrix.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        nn (torch.nn.Module): Network :math:`h_{\mathbf{\Theta}}` mapping
            :obj:`edge_attr` of shape :obj:`[-1, num_edge_features]` to
            :obj:`[-1, in_channels * out_channels]`, *e.g.*, a
            :class:`torch.nn.Sequential`.
        aggr (string, optional): Aggregation scheme, one of :obj:`"add"`,
            :obj:`"mean"`, :obj:`"max"`. (default: :obj:`"add"`)
        root_weight (bool, optional): When :obj:`False`, the transformed
            root node features are not added to the output.
            (default: :obj:`True`)
        bias (bool, optional): When :obj:`False`, no additive bias is
            learned. (default: :obj:`True`)
        **kwargs (optional): Forwarded to
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, nn, aggr='add',
                 root_weight=True, bias=True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.in_channels, self.out_channels = in_channels, out_channels
        self.nn = nn
        self.aggr = aggr
        # Register `root` then `bias`; a disabled one is registered as None.
        optional_params = (
            ('root', (in_channels, out_channels), root_weight),
            ('bias', (out_channels,), bias),
        )
        for param_name, shape, enabled in optional_params:
            value = Parameter(torch.Tensor(*shape)) if enabled else None
            self.register_parameter(param_name, value)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all learnable state (edge network, root weight, bias)."""
        reset(self.nn)
        fan_in = self.in_channels
        for tensor in (self.root, self.bias):
            uniform(fan_in, tensor)

    def forward(self, x, edge_index, edge_attr):
        """Propagate node features along edges weighted by ``edge_attr``."""
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        return self.propagate(edge_index, x=x, pseudo=edge_attr)

    def message(self, x_j, pseudo):
        """Multiply each neighbour feature by its edge's weight matrix."""
        kernel = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        row = x_j.unsqueeze(1)
        return (row @ kernel).squeeze(1)

    def update(self, aggr_out, x):
        """Combine the aggregate with the root transform and bias."""
        result = aggr_out
        if self.root is not None:
            result = result + torch.mm(x, self.root)
        if self.bias is not None:
            result = result + self.bias
        return result

    def __repr__(self):
        return '%s(%s, %s)' % (self.__class__.__name__, self.in_channels,
                               self.out_channels)
# Public alias mirroring torch_geometric's `ECConv` name. Note that in this
# file it points at `NNConv`, whose per-edge weight is built with
# `torch.diag_embed` (a diagonal weight), not at the full-matrix
# `NNConv_old`.
ECConv = NNConv
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zweien/graph-pde.git
git@gitee.com:zweien/graph-pde.git
zweien
graph-pde
graph-pde
master

搜索帮助