@@ -53,34 +53,24 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5353
5454 classification = torch .clamp (classification , 1e-4 , 1.0 - 1e-4 )
5555
56- if bbox_annotation .shape [0 ] == 0 :
57- if torch .cuda .is_available ():
58-
59- alpha_factor = torch .ones_like (classification ) * alpha
56+ if len (bbox_annotation ) == 0 : # No annotations
57+ alpha_factor = torch .ones_like (classification ) * alpha
58+ if torch .cuda .is_available ():
6059 alpha_factor = alpha_factor .cuda ()
61- alpha_factor = 1. - alpha_factor
62- focal_weight = classification
63- focal_weight = alpha_factor * torch .pow (focal_weight , gamma )
60+ alpha_factor = 1. - alpha_factor
61+ focal_weight = classification
62+ focal_weight = alpha_factor * torch .pow (focal_weight , gamma )
6463
65- bce = - (torch .log (1.0 - classification ))
66-
67- cls_loss = focal_weight * bce
64+ bce = - (torch .log (1.0 - classification ))
6865
66+ cls_loss = focal_weight * bce
67+ if torch .cuda .is_available ():
6968 regression_losses .append (torch .tensor (0 ).to (dtype ).cuda ())
70- classification_losses .append (cls_loss .sum ())
7169 else :
72-
73- alpha_factor = torch .ones_like (classification ) * alpha
74- alpha_factor = 1. - alpha_factor
75- focal_weight = classification
76- focal_weight = alpha_factor * torch .pow (focal_weight , gamma )
77-
78- bce = - (torch .log (1.0 - classification ))
79-
80- cls_loss = focal_weight * bce
81-
8270 regression_losses .append (torch .tensor (0 ).to (dtype ))
83- classification_losses .append (cls_loss .sum ())
71+
72+ # classification_losses.append(cls_loss.sum())
73+ classification_losses .append (cls_loss .mean ())
8474
8575 continue
8676
@@ -121,7 +111,8 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
121111 zeros = zeros .cuda ()
122112 cls_loss = torch .where (torch .ne (targets , - 1.0 ), cls_loss , zeros )
123113
124- classification_losses .append (cls_loss .sum () / torch .clamp (num_positive_anchors .to (dtype ), min = 1.0 ))
114+ # classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
115+ classification_losses .append (cls_loss .mean () / torch .clamp (num_positive_anchors .to (dtype ), min = 1.0 ))
125116
126117 if positive_indices .sum () > 0 :
127118 assigned_annotations = assigned_annotations [positive_indices , :]
0 commit comments