-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_factory.py
83 lines (73 loc) · 3.25 KB
/
loss_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pywick import losses as ls
import torch
import torch.nn as nn
from configs import *
class DiceFocalLoss(nn.Module):
    """Weighted sum of pywick's ``SoftDiceLoss`` and ``BinaryFocalLoss``.

    :param focal_param: passed as ``gamma`` to ``ls.BinaryFocalLoss``
        (per the original author: gamma > 0 reduces the relative loss for
        well-classified examples, focusing on hard misclassified ones).
    :param weights: two multipliers applied, in order, to the dice term and
        the focal term. Defaults to ``(1.0, 1.0)``. The values are copied
        into a list on the instance, so later changes to the caller's
        sequence do not affect this loss.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, focal_param, weights=(1.0, 1.0), **kwargs):
        super().__init__()
        self.dice = ls.SoftDiceLoss()
        self.focal = ls.BinaryFocalLoss(gamma=focal_param)
        # Copy: the old list default was a single object shared by every
        # instance, so mutating one loss's weights changed all of them.
        self.weights = list(weights)

    def forward(self, logits, targets):
        """Return ``weights[0] * dice + weights[1] * focal`` for the inputs.

        Both terms are always computed, even when a weight is 0.
        """
        dice_term = self.dice(logits, targets)
        focal_term = self.focal(logits, targets)
        return self.weights[0] * dice_term + self.weights[1] * focal_term
class BCE_Tversky(nn.Module):
    """Weighted sum of pywick's ``BCELoss2d`` and ``TverskyLoss``.

    The Tversky term is built with ``alpha=0.5, beta=0.7``.

    :param weights: two multipliers applied, in order, to the BCE term and
        the Tversky term. Defaults to ``(1.0, 1.0)``. The values are copied
        into a list on the instance.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, weights=(1.0, 1.0), **kwargs):
        super().__init__()
        self.bce = ls.BCELoss2d()
        self.tversky = ls.TverskyLoss(alpha=0.5, beta=0.7)
        # Copy so instances never share (or alias) one weights list.
        self.weights = list(weights)

    def forward(self, logits, targets):
        """Return ``weights[0] * bce + weights[1] * tversky`` for the inputs."""
        bce_term = self.bce(logits, targets)
        tversky_term = self.tversky(logits, targets)
        return self.weights[0] * bce_term + self.weights[1] * tversky_term
class BCEdicepenalizeborder_Tversky(nn.Module):
    """Weighted sum of pywick's ``BCEDicePenalizeBorderLoss`` and ``TverskyLoss``.

    The Tversky term is built with ``alpha=0.5, beta=0.7``.

    :param weights: two multipliers applied, in order, to the
        BCE/dice/border term and the Tversky term. Defaults to
        ``(1.0, 1.0)``. The values are copied into a list on the instance.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, weights=(1.0, 1.0), **kwargs):
        super().__init__()
        self.bce_dice = ls.BCEDicePenalizeBorderLoss()
        self.tversky = ls.TverskyLoss(alpha=0.5, beta=0.7)
        # Copy so instances never share (or alias) one weights list.
        self.weights = list(weights)

    def forward(self, logits, targets):
        """Return ``weights[0] * bce_dice + weights[1] * tversky`` for the inputs."""
        border_term = self.bce_dice(logits, targets)
        tversky_term = self.tversky(logits, targets)
        return self.weights[0] * border_term + self.weights[1] * tversky_term
class Loss_Factory:
    """Construct loss modules by name."""

    @staticmethod
    def get_loss(loss='bce'):
        """Return a freshly constructed loss module for the given name.

        :param loss: one of the keys in ``builders`` below (case-sensitive).
            Defaults to ``'bce'``.
        :returns: a new loss instance on every call.
        :raises ValueError: if ``loss`` is not a recognised name. (Previously
            an unknown name silently returned ``None``.)
        """
        # Zero-argument builders keep construction lazy: only the requested
        # loss is built, and names needed by a single entry (e.g.
        # ``focal_param``) are only looked up when that entry is requested.
        builders = {
            'bce': lambda: ls.BCELoss2d(),
            'dice': lambda: ls.SoftDiceLoss(),
            'focal': lambda: ls.BinaryFocalLoss(gamma=0),
            'bce_dice': lambda: ls.BCEDiceLoss(),
            'bce_dice_focal': lambda: ls.BCEDiceFocalLoss(focal_param=0.5),
            # NOTE(review): `focal_param` is not defined in this class;
            # presumably it comes from `from configs import *` -- confirm.
            # With weights [1, 0] the focal term is computed but multiplied
            # by 0, so this currently behaves as dice-only.
            'dice_focal': lambda: DiceFocalLoss(focal_param=focal_param,
                                                weights=[1, 0]),
            'bcedicepenalizeborderloss': lambda: ls.BCEDicePenalizeBorderLoss(),
            'lovaszsoftmax': lambda: ls.LovaszSoftmax(),
            'activecontourloss': lambda: ls.ActiveContourLoss(),
            'tverskyloss': lambda: ls.TverskyLoss(alpha=0.5, beta=0.7),
            'focalbinarytverskyloss': lambda: ls.FocalBinaryTverskyLoss(),
            'poissonloss': lambda: ls.PoissonLoss(),
            'combobcediceloss': lambda: ls.ComboBCEDiceLoss(),
            'bce_tversky': lambda: BCE_Tversky(),
            'bcedicepenalizeborder_tversky': lambda: BCEdicepenalizeborder_Tversky(),
        }
        try:
            build = builders[loss]
        except KeyError:
            raise ValueError(
                f"Unknown loss {loss!r}; expected one of: "
                f"{', '.join(sorted(builders))}"
            ) from None
        return build()