@@ -25,8 +25,11 @@ def calc_iou(a, b):
2525
2626
2727class FocalLoss (nn .Module ):
28- def __init__ (self ):
28+ def __init__ (self , matched_threshold = 0.5 , unmatched_threshold = 0.4 , negatives_lower_than_unmatched = True ):
2929 super (FocalLoss , self ).__init__ ()
30+ self .matched_threshold = matched_threshold
31+ self .unmatched_threshold = unmatched_threshold
32+ self .negatives_lower_than_unmatched = negatives_lower_than_unmatched
3033
3134 def forward (self , classifications , regressions , anchors , annotations , ** kwargs ):
3235 alpha = 0.25
@@ -69,26 +72,31 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
6972 else :
7073 regression_losses .append (torch .tensor (0 ).to (dtype ))
7174
72- # classification_losses.append(cls_loss.sum())
73- classification_losses .append (cls_loss .mean ())
74-
75+ classification_losses .append (cls_loss .sum ())
7576 continue
7677
7778 IoU = calc_iou (anchor [:, :], bbox_annotation [:, :4 ])
7879
7980 IoU_max , IoU_argmax = torch .max (IoU , dim = 1 )
8081
8182 # compute the loss for classification
82- targets = torch .ones_like (classification ) * - 1
83+ targets = torch .ones_like (classification ) * - 1 # init by ignoring all targets
8384 if torch .cuda .is_available ():
8485 targets = targets .cuda ()
8586
86- targets [torch .lt (IoU_max , 0.4 ), :] = 0
87-
88- positive_indices = torch .ge (IoU_max , 0.5 )
87+ if self .negatives_lower_than_unmatched :
88+ # negative matches are the ones below the unmatched_threshold
89+ targets [torch .lt (IoU_max , self .unmatched_threshold ), :] = 0
90+ else :
91+ # negative matches are in between the matched and unmatched
92+ targets [torch .lt (IoU_max , self .matched_threshold ) & torch .ge (IoU_max , self .unmatched_threshold ), :] = 0
8993
90- num_positive_anchors = positive_indices .sum ()
94+ # Find all positives in a batch for normalization
95+ positive_indices = torch .ge (IoU_max , self .matched_threshold )
9196
97+ # Avoid zero sum of num_positives, which would lead to inf loss during training
98+ num_positive_anchors = positive_indices .sum () + 1
99+ # print(num_positive_anchors)
92100 assigned_annotations = bbox_annotation [IoU_argmax , :]
93101
94102 targets [positive_indices , :] = 0
@@ -111,8 +119,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
111119 zeros = zeros .cuda ()
112120 cls_loss = torch .where (torch .ne (targets , - 1.0 ), cls_loss , zeros )
113121
114- # classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
115- classification_losses .append (cls_loss .mean () / torch .clamp (num_positive_anchors .to (dtype ), min = 1.0 ))
122+ classification_losses .append (cls_loss .sum () / torch .clamp (num_positive_anchors .to (dtype ), min = 1.0 ))
116123
117124 if positive_indices .sum () > 0 :
118125 assigned_annotations = assigned_annotations [positive_indices , :]
0 commit comments