-
Notifications
You must be signed in to change notification settings - Fork 15
/
criteria.py
59 lines (42 loc) · 1.46 KB
/
criteria.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-
# @Time : 2018/10/23 20:04
# @Author : Wang Xin
# @Email : wangxin_buaa@163.com
import torch
import torch.nn as nn
class MaskedMSELoss(nn.Module):
    """Mean squared error restricted to elements whose target is > 0.

    Elements where ``target <= 0`` are treated as invalid and excluded from
    the average. The most recent loss value is also kept on ``self.loss``.
    """

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        """Return the masked MSE between ``pred`` and ``target``.

        Args:
            pred: predicted tensor.
            target: reference tensor; only entries strictly greater than
                zero contribute to the loss.

        Returns:
            A scalar tensor. If no target entry is > 0 the result is a
            zero that is still attached to the autograd graph.
        """
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # The mask only selects elements; it must not take part in autograd.
        valid_mask = (target > 0).detach()
        diff = (target - pred)[valid_mask]
        if diff.numel() == 0:
            # mean() over an empty selection is NaN; the sum of an empty
            # tensor is 0 and keeps backward() usable with zero gradients.
            self.loss = diff.sum()
        else:
            self.loss = (diff ** 2).mean()
        return self.loss
class MaskedL1Loss(nn.Module):
    """Mean absolute error restricted to elements whose target is > 0.

    Elements where ``target <= 0`` are treated as invalid and excluded from
    the average. The most recent loss value is also kept on ``self.loss``.
    """

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target):
        """Return the masked L1 loss between ``pred`` and ``target``.

        Args:
            pred: predicted tensor.
            target: reference tensor; only entries strictly greater than
                zero contribute to the loss.

        Returns:
            A scalar tensor. If no target entry is > 0 the result is a
            zero that is still attached to the autograd graph.
        """
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # The mask only selects elements; it must not take part in autograd.
        valid_mask = (target > 0).detach()
        diff = (target - pred)[valid_mask]
        if diff.numel() == 0:
            # mean() over an empty selection is NaN; the sum of an empty
            # tensor is 0 and keeps backward() usable with zero gradients.
            self.loss = diff.sum()
        else:
            self.loss = diff.abs().mean()
        return self.loss
class berHuLoss(nn.Module):
    """Reverse Huber (berHu) loss restricted to elements whose target is > 0.

    For an absolute error ``x`` and threshold ``c``::

        B(x) = x                      if x <= c
        B(x) = (x**2 + c**2) / (2*c)  otherwise

    where ``c`` is 20% of the largest absolute error among the valid
    elements of the current call. The most recent loss value is also kept
    on ``self.loss``.
    """

    def __init__(self):
        super(berHuLoss, self).__init__()

    def forward(self, pred, target):
        """Return the mean berHu loss between ``pred`` and ``target``.

        Args:
            pred: predicted tensor.
            target: reference tensor; only entries strictly greater than
                zero contribute to the loss.

        Returns:
            A scalar tensor. If no target entry is > 0 the result is a
            zero that is still attached to the autograd graph.
        """
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # The mask only selects elements; it must not take part in autograd.
        valid_mask = (target > 0).detach()
        abs_err = (target - pred)[valid_mask].abs()

        if abs_err.numel() == 0:
            # max()/mean() are undefined on an empty selection; the sum of
            # an empty tensor is 0 and keeps backward() usable.
            self.loss = abs_err.sum()
            return self.loss

        # The threshold comes from the absolute error of valid elements
        # only, so masked-out entries or the sign of the error cannot skew
        # it. It is detached (treated as a constant) and clamped away from
        # zero so the quadratic branch never divides by zero.
        tiny = torch.finfo(abs_err.dtype).eps
        huber_c = (0.2 * abs_err.max().detach()).clamp(min=tiny)

        quadratic = (abs_err ** 2 + huber_c ** 2) / (2 * huber_c)
        per_element = torch.where(abs_err <= huber_c, abs_err, quadratic)

        self.loss = per_element.mean()
        return self.loss