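# NOTE: the following three lines are web-page residue from the code-hosting
# site's "force sync fork" dialog; they are not part of this module.
# Code pull complete; the page will refresh automatically.
# The sync operation will force-sync from 越努力越幸运/Pytorch-Medical-Segmentation(2D 3D); this overwrites any changes made since the repository was forked and cannot be undone!
# After confirmation the sync runs in the background and the page refreshes when it finishes; please be patient.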
from torch import nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def cross_entropy_3D(input, target, weight=None, size_average=True):
    """Cross-entropy loss for 5-D (volumetric) logits.

    Args:
        input: Logits tensor with exactly five dimensions; dimension 1 is the
            class dimension (softmax is taken over it).
        target: Tensor of class indices holding one entry per voxel of
            ``input`` (it is flattened, so any shape with the matching number
            of elements in the same order works).
        weight: Optional per-class rescaling weights forwarded to
            ``F.nll_loss``.
        size_average: If True, divide the summed loss by the number of target
            elements; otherwise return the sum.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: If ``input`` does not have exactly five dimensions.
    """
    # Unpacking doubles as a rank check: anything but a 5-D input raises.
    _n, c, _h, _w, _s = input.size()
    log_p = F.log_softmax(input, dim=1)
    # Move the class axis last so each row of the flattened matrix holds the
    # class log-probabilities of one voxel.
    log_p = log_p.permute(0, 2, 3, 4, 1).reshape(-1, c)
    # reshape (not view) so non-contiguous targets are accepted as well.
    target = target.reshape(-1)
    # `size_average=False` is deprecated; reduction='sum' is its replacement.
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        # Plain mean over all voxels (not weighted by `weight`).
        loss = loss / float(target.numel())
    return loss
class Binary_Loss(nn.Module):
    """Binary cross-entropy on raw logits.

    Thin wrapper around ``nn.BCEWithLogitsLoss`` with default settings.
    """

    def __init__(self):
        super(Binary_Loss, self).__init__()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, model_output, targets):
        """Compute the loss.

        Args:
            model_output: Raw (pre-sigmoid) logits.
            targets: Tensor of the same shape as ``model_output``; integer or
                bool masks are accepted and cast to the logits' dtype.

        Returns:
            A scalar loss tensor (mean over all elements).
        """
        # BCEWithLogitsLoss rejects non-floating targets, so align the dtype
        # with the logits instead of forcing every caller to cast.
        targets = targets.to(dtype=model_output.dtype)
        return self.criterion(model_output, targets)
def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: A tensor of class indices of shape [N, 1, *].
        num_classes: Number of classes (size of dimension 1 in the result).

    Returns:
        A float CPU tensor of shape [N, num_classes, *] with a 1 at each
        voxel's class index and 0 elsewhere.
    """
    # Build the output shape with plain Python; the original relied on `np`,
    # which this module never imports.
    shape = list(input.shape)
    shape[1] = num_classes
    # scatter_ requires an int64 index living on the same device as `result`.
    index = input.cpu().long()
    result = torch.zeros(shape)
    return result.scatter_(1, index, 1)
class BinaryDiceLoss(nn.Module):
    r"""Dice loss for a single (binary) class.

    Computes ``1 - (2 * sum(x * y) + smooth) / (sum(x^p) + sum(y^p) + smooth)``
    per sample, then reduces over the batch.

    Args:
        smooth: A float number to smooth the loss and avoid NaN, default: 1
        p: Exponent used in the denominator, \sum{x^p} + \sum{y^p}, default: 2
        reduction: Reduction method to apply: 'mean' returns the mean over the
            batch, 'sum' returns the sum, 'none' returns a tensor of shape [N,]

    Forward args:
        predict: A tensor of shape [N, *]
        target: A tensor of the same shape as predict

    Returns:
        Loss tensor according to arg reduction

    Raises:
        Exception: if reduction is not one of 'mean', 'sum', 'none'
    """

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten everything but the batch axis so sums run per sample.
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        # The Dice coefficient carries a factor 2 on the overlap term; without
        # it a perfect prediction would score ~0.5 instead of 0 loss.
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise Exception('Unexpected reduction {}'.format(self.reduction))
class DiceLoss(nn.Module):
    """Multi-class Dice loss; the target must be one-hot encoded.

    Softmax is applied to ``predict`` over dimension 1, a ``BinaryDiceLoss``
    is evaluated per class channel, and the (optionally weighted) per-class
    losses are summed and divided by the number of classes.

    Args:
        weight: An array/tensor of shape [num_classes,] with per-class weights
        ignore_index: class index to skip when accumulating the loss (the
            final division still uses the full number of classes)
        other args pass to BinaryDiceLoss

    Forward args:
        predict: A tensor of logits of shape [N, C, *]
        target: A tensor of the same shape as predict

    Return:
        same as BinaryDiceLoss
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'
        num_classes = target.shape[1]
        # Loop-invariant check, so validate the weight vector once up front.
        if self.weight is not None:
            assert self.weight.shape[0] == num_classes, \
                'Expect weight shape [{}], get[{}]'.format(num_classes, self.weight.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if self.weight is not None:
                # The attribute is `weight`; the original read the
                # non-existent `self.weights` and raised AttributeError.
                dice_loss = dice_loss * self.weight[i]
            total_loss = total_loss + dice_loss

        return total_loss / num_classes
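# NOTE: the following lines are web-page residue (a content-moderation notice
# from the code-hosting site); they are not part of this module.
# This section may contain content unsuitable for display and is not shown. You can review and modify it using the relevant editing functions.
# If you confirm the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / falsehood / worthless content, or content that violates national laws and regulations, you may click submit to appeal and it will be processed as soon as possible.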