Skip to content

Commit 54c5345

Browse files
authored
Merge pull request #234 from rvandeghen/patch-2
Add BCE when no annotations
2 parents 89cce50 + 71a7975 commit 54c5345

1 file changed

Lines changed: 27 additions & 6 deletions

File tree

efficientdet/loss.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,39 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5151
bbox_annotation = annotations[j]
5252
bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
5353

54+
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
55+
5456
if bbox_annotation.shape[0] == 0:
55-
if torch.cuda.is_available():
57+
if torch.cuda.is_available()
58+
59+
alpha_factor = torch.ones_like(classification) * alpha
60+
alpha_factor = alpha_factor.cuda()
61+
alpha_factor = 1. - alpha_factor
62+
focal_weight = classification
63+
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
64+
65+
bce = -(torch.log(1.0 - classification))
66+
67+
cls_loss = focal_weight * bce
68+
5669
regression_losses.append(torch.tensor(0).to(dtype).cuda())
57-
classification_losses.append(torch.tensor(0).to(dtype).cuda())
70+
classification_losses.append(cls_loss.sum())
5871
else:
72+
73+
alpha_factor = torch.ones_like(classification) * alpha
74+
alpha_factor = 1. - alpha_factor
75+
focal_weight = classification
76+
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
77+
78+
bce = -(torch.log(1.0 - classification))
79+
80+
cls_loss = focal_weight * bce
81+
5982
regression_losses.append(torch.tensor(0).to(dtype))
60-
classification_losses.append(torch.tensor(0).to(dtype))
83+
classification_losses.append(cls_loss.sum())
6184

6285
continue
63-
64-
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
65-
86+
6687
IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
6788

6889
IoU_max, IoU_argmax = torch.max(IoU, dim=1)

0 commit comments

Comments
 (0)