1 Star 1 Fork 1

许文祥/GANs-for-1D-Signal

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
wgan.py 2.10 KB
一键复制 编辑 原始数据 按行查看 历史
LixiangHan 提交于 2020-10-15 17:08 +08:00 . added wgan & wgan-gp
import torch.nn as nn
def weights_init(m):
    """Initialise convolution and batch-norm layers in place.

    Meant to be used as ``model.apply(weights_init)``. It applies the
    N(0, 0.02) scheme used in DCGAN:

    * any module whose class name contains ``'Conv'`` (this includes
      ``ConvTranspose1d``) gets weights ~ N(0, 0.02); its bias, if any, is
      left untouched;
    * any module whose class name contains ``'BatchNorm'`` gets weights
      ~ N(1, 0.02) and a bias of 0.

    Modules whose parameters are ``None`` (e.g. ``BatchNorm1d(affine=False)``
    or ``bias=False`` layers) are skipped instead of raising.

    Args:
        m: The module to initialise. All other module types are ignored.
    """
    classname = type(m).__name__
    # nn.init.* already runs under torch.no_grad(), so the parameters can be
    # passed directly; the legacy ``.data`` access is unnecessary.
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # affine=False batch-norm layers have weight/bias set to None.
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias, 0)
class Discriminator(nn.Module):
    """Convolutional WGAN critic for single-channel 1-D signals.

    Four ``Conv1d(kernel_size=4, stride=2, padding=1)`` stages each halve the
    length (L -> L/2 -> L/4 -> L/8 -> L/16). A final convolution whose kernel
    spans the whole remaining length then collapses it to one score per
    sample. No output activation is applied, so the score is unbounded.

    Args:
        signal_length: Length of the input signals. Must be a positive
            multiple of 16. Defaults to 1824, the length this architecture was
            originally hard-coded for (1824 / 16 = 114).

    Expected input shape is ``(batch, 1, signal_length)``. The output shape
    is ``(batch, 1, 1)``.

    NOTE(review): if this critic is reused with a gradient penalty (WGAN-GP),
    the WGAN-GP paper recommends avoiding BatchNorm in the critic (e.g. using
    LayerNorm instead). Confirm how it is used.
    """

    def __init__(self, signal_length=1824):
        super().__init__()
        if signal_length <= 0 or signal_length % 16 != 0:
            raise ValueError(
                f"signal_length must be a positive multiple of 16, "
                f"got {signal_length}")
        # Length left after the four stride-2 halvings.
        final_len = signal_length // 16
        # Layer order must stay unchanged so existing state_dicts still load.
        self.main = nn.Sequential(
            # input: signal_length
            nn.Conv1d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: signal_length / 2
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: signal_length / 4
            nn.Conv1d(128, 256, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: signal_length / 8
            nn.Conv1d(256, 512, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: signal_length / 16, collapsed to 1 by a full-width kernel
            nn.Conv1d(512, 1, kernel_size=final_len, stride=1, padding=0,
                      bias=False),
        )

    def forward(self, x, y=None):
        """Score a batch of signals.

        Args:
            x: Signals of shape ``(batch, 1, signal_length)``.
            y: Ignored. It is kept so callers that pass a second argument
                keep working (presumably a shared signature with conditional
                variants; confirm).

        Returns:
            Critic scores of shape ``(batch, 1, 1)``.
        """
        x = self.main(x)
        return x
class Generator(nn.Module):
    """Transposed-convolution generator for single-channel 1-D signals.

    A first ``ConvTranspose1d(nz, 512, kernel_size=signal_length // 16)``
    expands a length-1 latent input to ``signal_length / 16`` samples. Four
    ``ConvTranspose1d(kernel_size=4, stride=2, padding=1)`` stages then double
    the length each time, up to ``signal_length``. A final Tanh bounds the
    output to (-1, 1).

    Args:
        nz: Number of latent channels in the input noise.
        signal_length: Length of the generated signals when the latent input
            has length 1. Must be a positive multiple of 16. Defaults to 1824,
            the length this architecture was originally hard-coded for
            (1824 / 16 = 114).

    Expected input shape is ``(batch, nz, 1)``. The output shape is
    ``(batch, 1, signal_length)``.
    """

    def __init__(self, nz, signal_length=1824):
        super().__init__()
        if signal_length <= 0 or signal_length % 16 != 0:
            raise ValueError(
                f"signal_length must be a positive multiple of 16, "
                f"got {signal_length}")
        # Length produced by the first layer from a length-1 latent; four
        # doublings afterwards reach signal_length.
        base_len = signal_length // 16
        # Layer order must stay unchanged so existing state_dicts still load.
        self.main = nn.Sequential(
            # latent (nz, 1) -> (512, signal_length / 16)
            nn.ConvTranspose1d(nz, 512, base_len, 1, 0, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            # -> (256, signal_length / 8)
            nn.ConvTranspose1d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            # -> (128, signal_length / 4)
            nn.ConvTranspose1d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            # -> (64, signal_length / 2)
            nn.ConvTranspose1d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
            # -> (1, signal_length), squashed into (-1, 1)
            nn.ConvTranspose1d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Map latent noise of shape ``(batch, nz, 1)`` to signals."""
        x = self.main(x)
        return x
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xuwenxiang/GANs-for-1D-Signal.git
git@gitee.com:xuwenxiang/GANs-for-1D-Signal.git
xuwenxiang
GANs-for-1D-Signal
GANs-for-1D-Signal
main

搜索帮助