Skip to content

Commit 4174703

Browse files
authored
Update loss.py
1 parent 40e54e3 commit 4174703

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

efficientdet/loss.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,32 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5656
if bbox_annotation.shape[0] == 0:
5757
if torch.cuda.is_available()
5858

59-
targets = torch.zeros_like(classification)
60-
alpha_factor = torch.ones_like(targets) * alpha
59+
alpha_factor = torch.ones_like(classification) * alpha
6160
targets = targets.cuda()
6261
alpha_factor = alpha_factor.cuda()
6362
alphe_factot = 1. - alpha_factor
6463
focal_weight = classification
6564
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
6665

67-
bce = -((1.0 - targets) * torch.log(1.0 - classification))
66+
bce = -(torch.log(1.0 - classification))
6867

6968
cls_loss = focal_weight * bce
7069

7170
regression_losses.append(torch.tensor(0).to(dtype).cuda())
72-
classification_losses.append(cls_loss.sum(), min=1.0)
71+
classification_losses.append(cls_loss.sum())
7372
else:
7473

75-
targets = torch.zeros_like(classification)
76-
alpha_factor = torch.ones_like(targets) * alpha
74+
alpha_factor = torch.ones_like(classification) * alpha
7775
alphe_factot = 1. - alpha_factor
7876
focal_weight = classification
7977
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
8078

81-
bce = -((1.0 - targets) * torch.log(1.0 - classification))
79+
bce = -(torch.log(1.0 - classification))
8280

8381
cls_loss = focal_weight * bce
8482

8583
regression_losses.append(torch.tensor(0).to(dtype))
86-
classification_losses.append(cls_loss.sum(), min=1.0)
84+
classification_losses.append(cls_loss.sum())
8785

8886
continue
8987

0 commit comments

Comments
 (0)