Skip to content

Commit 02ef960

Browse files
committed
Readability refactor + normalize classification loss
1 parent 3fe4552 commit 02ef960

1 file changed

Lines changed: 14 additions & 23 deletions

File tree

efficientdet/loss.py

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -53,34 +53,24 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5353

5454
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
5555

56-
if bbox_annotation.shape[0] == 0:
57-
if torch.cuda.is_available():
58-
59-
alpha_factor = torch.ones_like(classification) * alpha
56+
if len(bbox_annotation) == 0: # No annotations
57+
alpha_factor = torch.ones_like(classification) * alpha
58+
if torch.cuda.is_available():
6059
alpha_factor = alpha_factor.cuda()
61-
alpha_factor = 1. - alpha_factor
62-
focal_weight = classification
63-
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
60+
alpha_factor = 1. - alpha_factor
61+
focal_weight = classification
62+
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
6463

65-
bce = -(torch.log(1.0 - classification))
66-
67-
cls_loss = focal_weight * bce
64+
bce = -(torch.log(1.0 - classification))
6865

66+
cls_loss = focal_weight * bce
67+
if torch.cuda.is_available():
6968
regression_losses.append(torch.tensor(0).to(dtype).cuda())
70-
classification_losses.append(cls_loss.sum())
7169
else:
72-
73-
alpha_factor = torch.ones_like(classification) * alpha
74-
alpha_factor = 1. - alpha_factor
75-
focal_weight = classification
76-
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
77-
78-
bce = -(torch.log(1.0 - classification))
79-
80-
cls_loss = focal_weight * bce
81-
8270
regression_losses.append(torch.tensor(0).to(dtype))
83-
classification_losses.append(cls_loss.sum())
71+
72+
# classification_losses.append(cls_loss.sum())
73+
classification_losses.append(cls_loss.mean())
8474

8575
continue
8676

@@ -121,7 +111,8 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
121111
zeros = zeros.cuda()
122112
cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
123113

124-
classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
114+
# classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
115+
classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
125116

126117
if positive_indices.sum() > 0:
127118
assigned_annotations = assigned_annotations[positive_indices, :]

0 commit comments

Comments (0)