@@ -52,12 +52,36 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
5252 bbox_annotation = bbox_annotation [bbox_annotation [:, 4 ] != - 1 ]
5353
5454 if bbox_annotation .shape [0 ] == 0 :
55- if torch .cuda .is_available ():
55+ if torch .cuda .is_available ()
56+
57+ targets = torch .zeros_like (classification )
58+ alpha_factor = torch .ones_like (targets ) * alpha
59+ targets = targets .cuda ()
60+ alpha_factor = alpha_factor .cuda ()
61+ alphe_factot = 1. - alpha_factor
62+ focal_weight = classification
63+ focal_weight = alpha_factor * torch .pow (focal_weight , gamma )
64+
65+ bce = - ((1.0 - targets ) * torch .log (1.0 - classification ))
66+
67+ cls_loss = focal_weight * bce
68+
5669 regression_losses .append (torch .tensor (0 ).to (dtype ).cuda ())
57- classification_losses .append (torch . tensor ( 0 ). to ( dtype ). cuda ( ))
70+ classification_losses .append (cls_loss . sum (), min = 1.0 ))
5871 else :
72+
73+ targets = torch .zeros_like (classification )
74+ alpha_factor = torch .ones_like (targets ) * alpha
75+ alphe_factot = 1. - alpha_factor
76+ focal_weight = classification
77+ focal_weight = alpha_factor * torch .pow (focal_weight , gamma )
78+
79+ bce = - ((1.0 - targets ) * torch .log (1.0 - classification ))
80+
81+ cls_loss = focal_weight * bce
82+
5983 regression_losses .append (torch .tensor (0 ).to (dtype ))
60- classification_losses .append (torch . tensor ( 0 ). to ( dtype ))
84+ classification_losses .append (cls_loss . sum (), min = 1.0 ))
6185
6286 continue
6387
0 commit comments