@@ -51,18 +51,39 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
             bbox_annotation = annotations[j]
             bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
 
+            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
+
             if bbox_annotation.shape[0] == 0:
-                if torch.cuda.is_available():
+                if torch.cuda.is_available():
+
+                    alpha_factor = torch.ones_like(classification) * alpha
+                    alpha_factor = alpha_factor.cuda()
+                    alpha_factor = 1. - alpha_factor
+                    focal_weight = classification
+                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
+
+                    bce = -(torch.log(1.0 - classification))
+
+                    cls_loss = focal_weight * bce
+
                     regression_losses.append(torch.tensor(0).to(dtype).cuda())
-                    classification_losses.append(torch.tensor(0).to(dtype).cuda())
+                    classification_losses.append(cls_loss.sum())
                 else:
+
+                    alpha_factor = torch.ones_like(classification) * alpha
+                    alpha_factor = 1. - alpha_factor
+                    focal_weight = classification
+                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
+
+                    bce = -(torch.log(1.0 - classification))
+
+                    cls_loss = focal_weight * bce
+
                     regression_losses.append(torch.tensor(0).to(dtype))
-                    classification_losses.append(torch.tensor(0).to(dtype))
+                    classification_losses.append(cls_loss.sum())
 
                 continue
-
-            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
-
+
             IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
 
             IoU_max, IoU_argmax = torch.max(IoU, dim=1)
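
In effect, the patch makes images that carry no ground-truth boxes contribute a real focal background loss, treating every anchor as a negative, instead of appending a constant zero classification loss. A minimal standalone sketch of that computation, assuming the usual focal-loss defaults alpha = 0.25 and gamma = 2.0 and a hypothetical anchor/class shape (none of these values is fixed by the diff):

import torch

alpha, gamma = 0.25, 2.0  # common focal-loss defaults; assumed, not taken from the diff
num_anchors, num_classes = 100, 90  # hypothetical shape, for illustration only

# Per-anchor class probabilities for an image with no annotations.
classification = torch.rand(num_anchors, num_classes)
classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)  # keep log() finite

# With no ground truth, every anchor is a negative: the alpha weight is
# (1 - alpha) and the focal modulating factor is the predicted probability.
alpha_factor = (1.0 - alpha) * torch.ones_like(classification)
focal_weight = alpha_factor * torch.pow(classification, gamma)

# Binary cross-entropy against an all-zero target reduces to -log(1 - p).
bce = -torch.log(1.0 - classification)

cls_loss = focal_weight * bce
print(cls_loss.sum())  # the scalar the patch appends to classification_losses

Hoisting torch.clamp above the empty-annotation branch, the other half of the diff, is what keeps this safe: without it the log term could see a probability of exactly 1.0 and return -inf.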