Ai
1 Star 0 Fork 0

夏召强/ParameterFreeRCNs-MicroExpressionRec

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
MeNets_NAS.py 39.63 KB
一键复制 编辑 原始数据 按行查看 历史
夏召强 提交于 2021-12-31 11:11 +08:00 . The first shared version.
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981
import torch
import torch.nn as nn
from torch.nn import functional as F
# Names exported by a star-import of this module.
# NOTE(review): MeNet_CS, MeNet_CS2, MeNet_CS3 and MeNet_ES are defined in this
# file but are not listed here — confirm whether leaving them out is intended.
__all__ = ['MeNet_A', 'MeNet_D', 'MeNet_W', 'MeNet_H', 'MeNet_C', 'MeNet_E']
class ConvBlock(nn.Module):
    """Sequential stack of 3x3 convolution stages.

    Each stage is Conv2d(kernel 3, padding 1) -> BatchNorm2d -> ReLU, optionally
    followed by a 2x2 stride-2 max-pool.

    Args:
        in_features: number of channels of the tensor fed to the first stage.
        out_features: number of channels produced by every stage.
        num_conv: number of stages to chain.
        pool: if true, a max-pool is appended after each stage.
    """

    def __init__(self, in_features, out_features, num_conv, pool=False):
        super(ConvBlock, self).__init__()
        stages = []
        channels = in_features
        for _ in range(num_conv):
            stages += [
                nn.Conv2d(in_channels=channels, out_channels=out_features, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(num_features=out_features, affine=True, track_running_stats=True),
                nn.ReLU(),
            ]
            if pool:
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
            # Every stage after the first consumes the previous stage's output.
            channels = out_features
        self.op = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        return self.op(x)
class RclBlock(nn.Module):
    """Recurrent convolutional block.

    A 1x1 feed-forward convolution produces an initial state; the same 3x3
    convolution (shared weights) is then applied twice, each time to the sum
    of the block input and the current state. The result is max-pooled
    (2x2, stride 2) and passed through dropout.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions. Because the input is
            added to the state and the sum is fed to a convolution expecting
            ``inplanes`` channels, the block is meant for inplanes == planes.
    """

    def __init__(self, inplanes, planes):
        super(RclBlock, self).__init__()

        def conv_unit(kernel, pad):
            return nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=kernel, stride=1, padding=pad),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=True),
            )

        self.ffconv = conv_unit(1, 0)   # feed-forward 1x1 path
        self.rrconv = conv_unit(3, 1)   # recurrent 3x3 path, reused each step
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout(),
        )

    def forward(self, x):
        """Return the pooled state after two recurrent steps."""
        state = self.ffconv(x)
        for _ in range(2):
            state = self.rrconv(x + state)
        return self.downsample(state)
class DenseBlock(nn.Module):
    """Densely connected convolutional block.

    A 1x1 convolution is applied to the input; afterwards one shared 3x3
    convolution (``conv2``) is applied three times, each time to the running
    sum of the input and all previously produced feature maps. The last
    output is max-pooled (2x2, stride 2) and passed through dropout.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by each convolution. The running sums are
            fed to convolutions expecting ``inplanes`` channels, so the block
            is meant for inplanes == planes.
    """

    def __init__(self, inplanes, planes):
        super(DenseBlock, self).__init__()

        def conv_unit(kernel, pad):
            return nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=kernel, stride=1, padding=pad),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=True),
            )

        self.conv1 = conv_unit(1, 0)
        self.conv2 = conv_unit(3, 1)
        # conv3 is registered (and therefore part of the state dict) but is
        # not called by forward().
        self.conv3 = conv_unit(3, 1)
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout(),
        )

    def forward(self, x):
        """Return the pooled output of the third dense step."""
        dense_sum = x + self.conv1(x)
        # Two intermediate dense steps: each new map is added to the sum.
        for _ in range(2):
            dense_sum = dense_sum + self.conv2(dense_sum)
        return self.downsample(self.conv2(dense_sum))
class EmbeddingBlock(nn.Module):
    """Densely connected block with an attention mask embedded after the 1x1 conv.

    The input is average-pooled to ``pool_size`` x ``pool_size`` and turned
    into a single-channel attention map by ``SpatialAttentionBlock_P`` using
    the supplied weight matrix. The map is resized to the 1x1-conv output and
    multiplied with it; the dense 3x3 steps then follow, and the result is
    max-pooled (kernel 2, stride 3) and passed through dropout.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by each convolution (the running sums are
            fed to convolutions expecting ``inplanes`` channels, so the block
            is meant for inplanes == planes).
    """

    def __init__(self, inplanes, planes):
        super(EmbeddingBlock, self).__init__()

        def conv_unit(kernel, pad):
            return nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=kernel, stride=1, padding=pad),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=True),
            )

        self.conv1 = conv_unit(1, 0)
        self.conv2 = conv_unit(3, 1)
        self.attenmap = SpatialAttentionBlock_P(normalize_attn=True)
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=3, padding=0),
            nn.Dropout(),
        )

    def forward(self, x, w, pool_size, classes):
        """Apply the attention-weighted dense block.

        Args:
            x: input feature tensor.
            w: weight matrix handed to the attention block (one row per class).
            pool_size: side length the input is pooled to before attention.
            classes: number of rows of ``w`` the attention block uses.
        """
        projected = self.conv1(x)
        pooled_input = F.adaptive_avg_pool2d(x, (pool_size, pool_size))
        mask = self.attenmap(pooled_input, w, classes)
        # Bring the mask to the spatial size of the projected features.
        height, width = projected.shape[2], projected.shape[3]
        y = torch.mul(F.interpolate(mask, (height, width)), projected)

        dense_sum = x + y
        for _ in range(2):
            dense_sum = dense_sum + self.conv2(dense_sum)
        return self.downsample(self.conv2(dense_sum))
class EmbeddingBlock2(nn.Module):
    """Dense block variant that projects the input first and skips attention.

    A 1x1 convolution maps ``inplanes`` to ``planes`` channels; one shared 3x3
    convolution is then applied three times to the running sum of the
    projected input and earlier outputs. The result is max-pooled
    (kernel 2, stride 3) and passed through dropout.

    An attention module is constructed for parity with ``EmbeddingBlock`` but
    ``forward`` does not call it.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by the convolutions.
    """

    def __init__(self, inplanes, planes):
        super(EmbeddingBlock2, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.attenmap = SpatialAttentionBlock_P(normalize_attn=True)
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=3, padding=0),
            nn.Dropout(),
        )

    def forward(self, x, w, pool_size, classes):
        """Apply the dense block; ``w``, ``pool_size`` and ``classes`` are unused.

        The extra parameters keep the call signature identical to
        ``EmbeddingBlock.forward``.
        """
        dense_sum = self.conv1(x)
        for _ in range(2):
            dense_sum = dense_sum + self.conv2(dense_sum)
        return self.downsample(self.conv2(dense_sum))
class SpatialAttentionBlock_A(nn.Module):
    """Learned spatial attention: a bias-free 1x1 convolution scores each position.

    Args:
        in_features: channels of the tensor the block is applied to.
        normalize_attn: if true, scores are softmax-normalised over all
            spatial positions of a sample; otherwise a sigmoid is applied
            to each score independently.
    """

    def __init__(self, in_features, normalize_attn=True):
        super(SpatialAttentionBlock_A, self).__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False)

    def forward(self, l):
        """Return ``l`` multiplied by its single-channel attention map."""
        batch, _, dim2, dim3 = l.size()
        scores = self.op(l)  # one score per spatial position: (batch, 1, dim2, dim3)
        if not self.normalize_attn:
            attention = torch.sigmoid(scores)
        else:
            flat = scores.view(batch, 1, -1)
            attention = F.softmax(flat, dim=2).view(batch, 1, dim2, dim3)
        # Broadcast the single-channel map across every channel of the input.
        return torch.mul(attention.expand_as(l), l)
class SpatialAttentionBlock_P(nn.Module):
    """Parameter-free spatial attention computed from externally supplied weights.

    For each of the first ``classes`` rows of ``w``, the row is reshaped to the
    per-sample shape of ``l`` (C, H, W), multiplied element-wise with ``l`` and
    averaged over channels, giving one activation map per class. Every map is
    min-max rescaled over its spatial positions and the maps are finally
    averaged over classes.

    Args:
        normalize_attn: stored on the instance; ``forward`` does not read it.
    """

    def __init__(self, normalize_attn=True):
        super(SpatialAttentionBlock_P, self).__init__()
        self.normalize_attn = normalize_attn

    @staticmethod
    def _minmax_rescale(cam):
        """Rescale every (sample, map) plane of ``cam`` to [0, 1] over dims 2 and 3.

        A plane whose values are all equal has zero range. The divisor is
        clamped to the smallest positive normal value of the dtype so such a
        plane becomes all zeros instead of NaN (0 / 0); planes with a
        non-degenerate range are divided by exactly the same value as before.
        """
        cam = cam - torch.min(torch.min(cam, 3, True)[0], 2, True)[0]
        peak = torch.max(torch.max(cam, 3, True)[0], 2, True)[0]
        return cam / peak.clamp(min=torch.finfo(cam.dtype).tiny)

    def forward(self, l, w, classes):
        """Build the class-averaged attention map.

        Args:
            l: 4-D feature tensor (N, C, H, W).
            w: 2-D tensor whose rows each hold C*H*W values.
            classes: number of leading rows of ``w`` to use.

        Returns:
            Tensor of shape (N, 1, H, W).
        """
        per_sample_shape = (l.shape[1], l.shape[2], l.shape[3])
        output_cam = []
        for idx in range(classes):
            weights = w[idx, :].reshape(per_sample_shape)
            cam = (weights * l).mean(dim=1, keepdim=True)
            output_cam.append(self._minmax_rescale(cam))
        output = torch.cat(output_cam, dim=1)
        return output.mean(dim=1, keepdim=True)
class SpatialAttentionBlock_F(nn.Module):
    """Parameter-free spatial attention usable on inputs of any channel count.

    For each of the first ``classes`` rows of ``w``, the row is reshaped to
    (-1, H, W) using the spatial size of ``l`` and averaged over its first
    dimension, giving a single (1, H, W) weight plane. The plane is multiplied
    with ``l`` and the product is averaged over channels. Every resulting map
    is min-max rescaled over its spatial positions and the maps are finally
    averaged over classes.

    Args:
        normalize_attn: stored on the instance; ``forward`` does not read it.
    """

    def __init__(self, normalize_attn=True):
        super(SpatialAttentionBlock_F, self).__init__()
        self.normalize_attn = normalize_attn

    @staticmethod
    def _minmax_rescale(cam):
        """Rescale every (sample, map) plane of ``cam`` to [0, 1] over dims 2 and 3.

        A plane whose values are all equal has zero range. The divisor is
        clamped to the smallest positive normal value of the dtype so such a
        plane becomes all zeros instead of NaN (0 / 0); planes with a
        non-degenerate range are divided by exactly the same value as before.
        """
        cam = cam - torch.min(torch.min(cam, 3, True)[0], 2, True)[0]
        peak = torch.max(torch.max(cam, 3, True)[0], 2, True)[0]
        return cam / peak.clamp(min=torch.finfo(cam.dtype).tiny)

    def forward(self, l, w, classes):
        """Build the class-averaged attention map.

        Args:
            l: 4-D feature tensor (N, C, H, W).
            w: 2-D tensor whose row length is a multiple of H*W.
            classes: number of leading rows of ``w`` to use.

        Returns:
            Tensor of shape (N, 1, H, W).
        """
        height, width = l.shape[2], l.shape[3]
        output_cam = []
        for idx in range(classes):
            weights = w[idx, :].reshape((-1, height, width))
            # Collapse the weight planes so the result broadcasts over any C.
            weights = weights.mean(dim=0, keepdim=True)
            cam = (weights * l).mean(dim=1, keepdim=True)
            output_cam.append(self._minmax_rescale(cam))
        output = torch.cat(output_cam, dim=1)
        return output.mean(dim=1, keepdim=True)
def MakeLayer(block, planes, blocks):
    """Chain ``blocks`` freshly built ``block(planes, planes)`` modules.

    Args:
        block: callable taking (in_channels, out_channels) and returning a module.
        planes: channel count passed as both arguments of ``block``.
        blocks: number of modules to create.

    Returns:
        An ``nn.Sequential`` holding the created modules in creation order.
    """
    return nn.Sequential(*(block(planes, planes) for _ in range(blocks)))
class MeNet_A(nn.Module):
    """Recurrent-convolution network with an attention unit.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channels produced by the stem and kept by the recurrent blocks.
        num_classes: size of the classifier output.
        num_layers: number of ``RclBlock`` instances chained after the stem.
        pool_size: side length of the adaptive average-pooled maps.
        model_version: selects where attention is applied.
            * 1 or 2 - a learned single-input attention block is applied to
              the stem output before the recurrent blocks.
            * 3 - a parameter-free map derived from the classifier weights is
              multiplied with the pooled features before classification.

    NOTE(review): versions 1 and 2 previously called the three-argument
    parameter-free attention block with a single argument, which raised a
    TypeError. They now use ``SpatialAttentionBlock_A``, the only attention
    block in this file with a matching single-input ``forward`` - confirm
    this is the intended design. Version 3 is unchanged (same modules, same
    state-dict keys).
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5, model_version=3):
        super(MeNet_A, self).__init__()
        self.version = model_version
        self.classes = num_classes
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout(),
        )
        self.rcls = MakeLayer(RclBlock, featuremaps, num_layers)
        if model_version in (1, 2):
            # Single-input attention applied directly to the stem features.
            self.attenmap = SpatialAttentionBlock_A(featuremaps, normalize_attn=True)
        else:
            self.attenmap = SpatialAttentionBlock_P(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``.

        Raises:
            ValueError: if ``self.version`` is not 1, 2 or 3.
        """
        if self.version in (1, 2):
            x = self.conv1(x)
            x = self.attenmap(x)
            x = self.rcls(x)
            x = self.avgpool(x)
        elif self.version == 3:
            x = self.conv1(x)
            # Attention map built from the pooled stem output and the
            # classifier weights; applied after the recurrent blocks.
            y = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes)
            x = self.rcls(x)
            x = self.avgpool(x)
            x = x * y
        else:
            raise ValueError(
                "MeNet_A: unsupported model_version {!r}; expected 1, 2 or 3".format(self.version)
            )
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class MeNet_D(nn.Module):
    """Network made of a convolutional stem followed by densely connected blocks.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channels produced by the stem and kept through the blocks.
        num_classes: size of the classifier output.
        num_layers: number of ``DenseBlock`` instances chained after the stem.
        pool_size: side length of the adaptive average-pooled map fed to the
            classifier.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_D, self).__init__()
        stem = [
            nn.Conv2d(num_input, featuremaps, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout(),
        ]
        self.conv1 = nn.Sequential(*stem)
        self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers)
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size * pool_size * featuremaps, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, norm and linear parameters of every submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        features = self.avgpool(self.dbl(self.conv1(x)))
        return self.classifier(torch.flatten(features, 1))
class MeNet_W(nn.Module):
    """Network with three parallel stride-3 input streams followed by recurrent blocks.

    The streams use 3x3 convolutions with dilation 1, 2 and 3. The first
    produces ``int(featuremaps/2)`` channels and the other two half of that
    each; their outputs are concatenated along the channel axis. The
    recurrent blocks expect ``featuremaps`` channels, which the concatenation
    yields when ``featuremaps`` is divisible by 4.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width of the recurrent blocks and classifier input.
        num_classes: size of the classifier output.
        num_layers: number of ``RclBlock`` instances.
        pool_size: side length of the adaptive average-pooled map.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_W, self).__init__()
        half = int(featuremaps / 2)
        quarter = int(half / 2)

        def stream(out_channels, **conv_kwargs):
            # Conv -> ReLU -> BatchNorm -> Dropout, shared layout of all streams.
            return nn.Sequential(
                nn.Conv2d(num_input, out_channels, kernel_size=3, stride=3, **conv_kwargs),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(),
            )

        self.stream1 = stream(half, padding=1)
        self.stream2 = stream(quarter, padding=2, dilation=2)
        self.stream3 = stream(quarter, padding=3, dilation=3)
        self.rcls = MakeLayer(RclBlock, featuremaps, num_layers)
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size * pool_size * featuremaps, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, norm and linear parameters of every submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        branches = (self.stream1(x), self.stream2(x), self.stream3(x))
        merged = torch.cat(branches, 1)
        pooled = self.avgpool(self.rcls(merged))
        return self.classifier(torch.flatten(pooled, 1))
class MeNet_H(nn.Module):
    """Hybrid network: two input streams, dense blocks and a late attention mask.

    Two parallel 3x3 convolution streams (dilation 1 and 3), each followed by
    a 3x3 stride-3 max-pool, are concatenated. A parameter-free attention map
    is computed from the pooled concatenation and the classifier weights, and
    is multiplied with the pooled output of the dense blocks.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after concatenation (each stream produces
            ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: number of blocks in each of ``dbl`` and ``rcls``.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_H, self).__init__()
        self.classes = num_classes
        num_channels = int(featuremaps / 2)

        def stream(pad, **conv_kwargs):
            # Conv -> ReLU -> BatchNorm -> MaxPool -> Dropout.
            return nn.Sequential(
                nn.Conv2d(num_input, num_channels, kernel_size=3, stride=1, padding=pad, **conv_kwargs),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(num_channels),
                nn.MaxPool2d(kernel_size=3, stride=3, padding=1),
                nn.Dropout(),
            )

        self.stream1 = stream(1)
        self.stream2 = stream(3, dilation=3)
        self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers)
        # rcls is registered (and therefore in the state dict) but forward()
        # only uses the dense blocks.
        self.rcls = MakeLayer(RclBlock, featuremaps, num_layers)
        self.attenmap = SpatialAttentionBlock_P(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size * pool_size * featuremaps, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, norm and linear parameters of every submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        merged = torch.cat((self.stream1(x), self.stream2(x)), 1)
        mask = self.attenmap(self.downsampling(merged), self.classifier.weight, self.classes)
        pooled = self.avgpool(self.dbl(merged))
        attended = pooled * mask
        return self.classifier(torch.flatten(attended, 1))
class MeNet_CS(nn.Module):
    """Cascaded network whose attention branches are scaled by learnable weights.

    ``archi`` is a learnable 2x2 parameter initialised to 0.5. Its softmax
    (taken over dim 0) supplies the coefficients ``W[0][1]`` and ``W[1][1]``
    that scale the attention-masked copies of the concatenated stream output
    and of the ``conv1`` output before they are added back to the unmasked
    features. The remaining dense steps run without attention.

    NOTE(review): the softmax here runs over dim 0, whereas MeNet_CS2 and
    MeNet_CS3 in this file normalise over the last dim - confirm which is
    intended.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after the two streams are concatenated
            (each stream produces ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: accepted for signature parity with the other networks;
            not used by this class.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_CS, self).__init__()
        self.classes = num_classes
        num_channels = int(featuremaps/2)
        # Created with randn (keeps the RNG stream unchanged) and then
        # overwritten with a constant below.
        self.archi = nn.Parameter(torch.randn(2, 2))
        self.stream1 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream2 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout()
        )
        self.softmax = nn.Softmax(0)
        self.attenmap = SpatialAttentionBlock_F(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)

        # constant_ is the supported in-place initialiser; the non-underscore
        # alias is deprecated.
        nn.init.constant_(self.archi, 0.5)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        W = self.softmax(self.archi)

        # First attention: mask from the raw input, applied to the stream output.
        M1 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes)
        x1 = self.stream1(x)
        x2 = self.stream2(x)
        x = torch.cat((x1, x2), 1)
        x1 = torch.mul(F.interpolate(M1, (x.shape[2], x.shape[3])), x)
        x = x + W[0][1]*x1

        # Second attention: mask from the blended features, applied to conv1 output.
        M2 = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes)
        y = self.conv1(x)
        y1 = torch.mul(F.interpolate(M2, (y.shape[2], y.shape[3])), y)
        y = y + W[1][1]*y1

        # Remaining dense steps use no attention. The masks and masked copies
        # that used to be computed here were never added to the features, so
        # they have been dropped without changing the output.
        z = self.conv2(x+y)
        e = self.conv2(x+y+z)
        out = self.conv2(x+y+z+e)

        out = self.downsample(out)
        x = self.avgpool(out)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class MeNet_CS2(nn.Module):
    """Cascaded network blending plain and attention-masked features at four stages.

    ``archi`` is a learnable 4x2 parameter initialised to 0.5. Row ``k`` of its
    softmax (over the last dim) holds the two coefficients used at stage ``k``
    to mix a feature map with its attention-masked copy. The mask used at
    each stage is the one computed from the *previous* stage's features
    (the first mask comes from the raw input).

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after the two streams are concatenated
            (each stream produces ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: accepted for signature parity with the other networks;
            not used by this class.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_CS2, self).__init__()
        self.classes = num_classes
        num_channels = int(featuremaps/2)
        # Created with randn (keeps the RNG stream unchanged) and then
        # overwritten with a constant below.
        self.archi = nn.Parameter(torch.randn(4, 2))
        self.stream1 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream2 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout()
        )
        self.softmax = nn.Softmax(-1)
        self.attenmap = SpatialAttentionBlock_F(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)

        # constant_ is the supported in-place initialiser; the non-underscore
        # alias is deprecated.
        nn.init.constant_(self.archi, 0.5)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _attention(self, features):
        """Attention map built from pooled ``features`` and the classifier weights."""
        return self.attenmap(self.downsampling(features), self.classifier.weight, self.classes)

    @staticmethod
    def _blend(features, mask, coeffs):
        """Mix ``features`` with its masked copy: coeffs[0]*f + coeffs[1]*(mask*f)."""
        masked = torch.mul(F.interpolate(mask, (features.shape[2], features.shape[3])), features)
        return coeffs[0]*features + coeffs[1]*masked

    def forward(self, x):
        """Return class scores for ``x``."""
        W = self.softmax(self.archi)

        M1 = self._attention(x)               # mask from the raw input
        x1 = self.stream1(x)
        x2 = self.stream2(x)
        x = torch.cat((x1, x2), 1)

        M2 = self._attention(x)
        y = self._blend(self.conv1(x), M1, W[0])

        M3 = self._attention(y)
        z = self._blend(self.conv2(x+y), M2, W[1])

        M4 = self._attention(z)
        e = self._blend(self.conv2(x+y+z), M3, W[2])

        out = self._blend(self.conv2(x+y+z+e), M4, W[3])

        out = self.downsample(out)
        x = self.avgpool(out)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class MeNet_CS3(nn.Module):
    """Cascaded network blending plain and attention-masked features at three stages.

    ``archi`` is a learnable 3x2 parameter initialised to 0.5. Row ``k`` of its
    softmax (over the last dim) holds the two coefficients used to mix the
    output of the ``k``-th 3x3 dense step with its attention-masked copy. The
    ``conv1`` output is used without attention. The three masks are computed
    from the raw input, the concatenated streams and the ``conv1`` output.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after the two streams are concatenated
            (each stream produces ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: accepted for signature parity with the other networks;
            not used by this class.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_CS3, self).__init__()
        self.classes = num_classes
        num_channels = int(featuremaps/2)
        # Created with randn (keeps the RNG stream unchanged) and then
        # overwritten with a constant below.
        self.archi = nn.Parameter(torch.randn(3, 2))
        self.stream1 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream2 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(featuremaps, featuremaps, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(featuremaps),
            nn.ReLU(inplace=True)
        )
        self.downsample = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Dropout()
        )
        self.softmax = nn.Softmax(-1)
        self.attenmap = SpatialAttentionBlock_F(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)

        # constant_ is the supported in-place initialiser; the non-underscore
        # alias is deprecated.
        nn.init.constant_(self.archi, 0.5)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _attention(self, features):
        """Attention map built from pooled ``features`` and the classifier weights."""
        return self.attenmap(self.downsampling(features), self.classifier.weight, self.classes)

    @staticmethod
    def _blend(features, mask, coeffs):
        """Mix ``features`` with its masked copy: coeffs[0]*f + coeffs[1]*(mask*f)."""
        masked = torch.mul(F.interpolate(mask, (features.shape[2], features.shape[3])), features)
        return coeffs[0]*features + coeffs[1]*masked

    def forward(self, x):
        """Return class scores for ``x``."""
        W = self.softmax(self.archi)

        M1 = self._attention(x)               # mask from the raw input
        x1 = self.stream1(x)
        x2 = self.stream2(x)
        x = torch.cat((x1, x2), 1)

        M2 = self._attention(x)
        y = self.conv1(x)                     # no attention at this stage

        M3 = self._attention(y)
        z = self._blend(self.conv2(x+y), M1, W[0])
        e = self._blend(self.conv2(x+y+z), M2, W[1])
        out = self._blend(self.conv2(x+y+z+e), M3, W[2])

        out = self.downsample(out)
        x = self.avgpool(out)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
class MeNet_C(nn.Module):
    """Cascaded network: input-level attention mask followed by dense blocks.

    An attention map is computed from the pooled raw input and the classifier
    weights. Two parallel streams (3x3 and 5x5 convolutions, each followed by
    a 3x3 stride-3 max-pool) are concatenated, multiplied by the resized
    attention map and passed through the dense blocks.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after concatenation (each stream produces
            ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: number of ``DenseBlock`` instances.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_C, self).__init__()
        self.classes = num_classes
        num_channels = int(featuremaps / 2)

        def stream(kernel, pad):
            # Conv -> ReLU -> BatchNorm -> MaxPool -> Dropout.
            return nn.Sequential(
                nn.Conv2d(num_input, num_channels, kernel_size=kernel, stride=1, padding=pad),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(num_channels),
                nn.MaxPool2d(kernel_size=3, stride=3, padding=1),
                nn.Dropout(),
            )

        self.stream1 = stream(3, 1)
        self.stream2 = stream(5, 2)
        self.dbl = MakeLayer(DenseBlock, featuremaps, num_layers)
        self.attenmap = SpatialAttentionBlock_F(normalize_attn=True)
        self.downsampling = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size * pool_size * featuremaps, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, norm and linear parameters of every submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        # The mask is derived from the raw input, before any convolution.
        mask = self.attenmap(self.downsampling(x), self.classifier.weight, self.classes)
        merged = torch.cat((self.stream1(x), self.stream2(x)), 1)
        height, width = merged.shape[2], merged.shape[3]
        attended = torch.mul(F.interpolate(mask, (height, width)), merged)
        pooled = self.avgpool(self.dbl(attended))
        return self.classifier(torch.flatten(pooled, 1))
class MeNet_E(nn.Module):
    """Network with two stride-3 input streams feeding an attention-embedding block.

    The streams are 3x3 convolutions with dilation 1 and 3; their outputs are
    concatenated and processed by ``EmbeddingBlock``, which receives the
    classifier weights to build its internal attention map.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width after concatenation (each stream produces
            ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: accepted for signature parity with the other networks;
            not used by this class.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_E, self).__init__()
        self.classes = num_classes
        self.poolsize = pool_size
        num_channels = int(featuremaps / 2)

        def stream(pad, **conv_kwargs):
            # Conv -> ReLU -> BatchNorm -> Dropout.
            return nn.Sequential(
                nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=pad, **conv_kwargs),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(num_channels),
                nn.Dropout(),
            )

        self.stream1 = stream(1)
        self.stream2 = stream(3, dilation=3)
        self.ebl = EmbeddingBlock(featuremaps, featuremaps)
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size * pool_size * featuremaps, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, norm and linear parameters of every submodule."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        merged = torch.cat((self.stream1(x), self.stream2(x)), 1)
        embedded = self.ebl(merged, self.classifier.weight, self.poolsize, self.classes)
        pooled = self.avgpool(embedded)
        return self.classifier(torch.flatten(pooled, 1))
# class MeNet_ES(nn.Module):
# """menet networks with embedded modules by searching
# """
# def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
# super(MeNet_ES, self).__init__()
# self.classes = num_classes
# self.poolsize = pool_size
# num_channels = featuremaps
# self.archi = nn.Parameter(torch.randn(3))
# self.stream1 = nn.Sequential(
# nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1), # 1->3
# nn.ReLU(inplace=True),
# nn.BatchNorm2d(num_channels),
# # nn.MaxPool2d(kernel_size=3, stride=3, padding=1),
# nn.Dropout(),
# )
# self.stream2 = nn.Sequential(
# nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1, dilation=2), # 5,2/ 1,0
# nn.ReLU(inplace=True),
# nn.BatchNorm2d(num_channels),
# # nn.MaxPool2d(kernel_size=3, stride=3, padding=1),
# nn.Dropout(),
# )
# self.stream3 = nn.Sequential(
# nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3), # 5,2/ 1,0
# nn.ReLU(inplace=True),
# nn.BatchNorm2d(num_channels),
# # nn.MaxPool2d(kernel_size=3, stride=3, padding=1),
# nn.Dropout(),
# )
# self.bn = nn.BatchNorm2d(num_channels)
# self.softmax = nn.Softmax(0)
# self.ebl = EmbeddingBlock(featuremaps, featuremaps)
# self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
# self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)
# nn.init.constant(self.archi, 0.333)
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, nn.Linear):
# nn.init.normal_(m.weight, 0, 0.01)
# nn.init.constant_(m.bias, 0)
# def forward(self, x):
# x1 = self.stream1(x)
# x2 = self.stream2(x)
# x3 = self.stream3(x)
# #x = torch.cat((x1,x2,x3),1)
# W = self.softmax(self.archi)
# #print(W)
# x = W[0]*x1+ W[1]*x2+ W[2]*x3
# x = self.bn(x)
# x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.classifier(x)
# return x
class MeNet_ES(nn.Module):
    """Two-stage network mixing an embedding block with dual-stream convolutions.

    ``archi`` is a learnable length-2 parameter initialised to 0.5; its softmax
    ``W`` weights the two parallel paths of each stage:

    * stage 1: ``W[0] * ebl(x) + W[1] * cat(stream1(x), stream2(x))``
    * stage 2: ``W[1] * ebl2(x) + W[0] * cat(stream21(x), stream22(x))``

    The current mixing weights can be inspected from outside with
    ``torch.softmax(model.archi, 0)``; ``forward`` no longer prints them on
    every call.

    Args:
        num_input: channels of the input tensor.
        featuremaps: channel width of each stage's output (each stream
            produces ``int(featuremaps/2)`` channels).
        num_classes: size of the classifier output.
        num_layers: accepted for signature parity with the other networks;
            not used by this class.
        pool_size: side length of the adaptive average-pooled maps.
    """

    def __init__(self, num_input, featuremaps, num_classes, num_layers=1, pool_size=5):
        super(MeNet_ES, self).__init__()
        self.classes = num_classes
        self.poolsize = pool_size
        num_channels = int(featuremaps/2)
        # Created with randn (keeps the RNG stream unchanged) and then
        # overwritten with a constant below.
        self.archi = nn.Parameter(torch.randn(2))
        self.stream1 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream2 = nn.Sequential(
            nn.Conv2d(num_input, num_channels, kernel_size=3, stride=3, padding=3, dilation=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream21 = nn.Sequential(
            nn.Conv2d(featuremaps, num_channels, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.stream22 = nn.Sequential(
            nn.Conv2d(featuremaps, num_channels, kernel_size=3, stride=3, padding=3, dilation=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_channels),
            nn.Dropout(),
        )
        self.softmax = nn.Softmax(0)
        self.ebl = EmbeddingBlock2(num_input, featuremaps)
        self.ebl2 = EmbeddingBlock(featuremaps, featuremaps)
        self.avgpool = nn.AdaptiveAvgPool2d((pool_size, pool_size))
        self.classifier = nn.Linear(pool_size*pool_size*featuremaps, num_classes)

        # constant_ is the supported in-place initialiser; the non-underscore
        # alias is deprecated.
        nn.init.constant_(self.archi, 0.5)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class scores for ``x``."""
        # Stage 1: embedding block on the raw input vs. the two input streams.
        x1 = self.stream1(x)
        x2 = self.stream2(x)
        x = self.ebl(x, self.classifier.weight, self.poolsize, self.classes)
        x1 = torch.cat((x1, x2), 1)
        W = self.softmax(self.archi)
        x = W[0]*x + W[1]*x1

        # Stage 2: attention-embedding block vs. the second pair of streams.
        y = self.ebl2(x, self.classifier.weight, self.poolsize, self.classes)
        y1 = self.stream21(x)
        y2 = self.stream22(x)
        y1 = torch.cat((y1, y2), 1)
        # NOTE(review): the coefficients are swapped relative to stage 1
        # (embedding path gets W[1] here, W[0] above) - confirm this is intended.
        x = W[1]*y + W[0]*y1

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xia_zhaoqiang/ParameterFreeRCNs-MicroExpressionRec.git
git@gitee.com:xia_zhaoqiang/ParameterFreeRCNs-MicroExpressionRec.git
xia_zhaoqiang
ParameterFreeRCNs-MicroExpressionRec
ParameterFreeRCNs-MicroExpressionRec
master

搜索帮助