Skip to content

Commit 09fe20a

Browse files
authored
Add BCE when no annotations
Correction of focal loss
1 parent 89cce50 commit 09fe20a

1 file changed

Lines changed: 27 additions & 3 deletions

File tree

efficientdet/loss.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,36 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5252
bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
5353

5454
if bbox_annotation.shape[0] == 0:
55-
if torch.cuda.is_available():
55+
if torch.cuda.is_available():
56+
57+
targets = torch.zeros_like(classification)
58+
alpha_factor = torch.ones_like(targets) * alpha
59+
targets = targets.cuda()
60+
alpha_factor = alpha_factor.cuda()
61+
alpha_factor = 1. - alpha_factor
62+
focal_weight = classification
63+
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
64+
65+
bce = -((1.0 - targets) * torch.log(1.0 - classification))
66+
67+
cls_loss = focal_weight * bce
68+
5669
regression_losses.append(torch.tensor(0).to(dtype).cuda())
57-
classification_losses.append(torch.tensor(0).to(dtype).cuda())
70+
classification_losses.append(torch.clamp(cls_loss.sum(), min=1.0))
5871
else:
72+
73+
targets = torch.zeros_like(classification)
74+
alpha_factor = torch.ones_like(targets) * alpha
75+
alpha_factor = 1. - alpha_factor
76+
focal_weight = classification
77+
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
78+
79+
bce = -((1.0 - targets) * torch.log(1.0 - classification))
80+
81+
cls_loss = focal_weight * bce
82+
5983
regression_losses.append(torch.tensor(0).to(dtype))
60-
classification_losses.append(torch.tensor(0).to(dtype))
84+
classification_losses.append(torch.clamp(cls_loss.sum(), min=1.0))
6185

6286
continue
6387

0 commit comments

Comments
 (0)