Skip to content

Commit 0b5a328

Browse files
committed
Fix classification loss normalization and add support for configurable matched/unmatched anchor thresholds
1 parent 02ef960 commit 0b5a328

2 files changed

Lines changed: 30 additions & 15 deletions

File tree

efficientdet/loss.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ def calc_iou(a, b):
2525

2626

2727
class FocalLoss(nn.Module):
28-
def __init__(self):
28+
def __init__(self, matched_threshold=0.5, unmatched_threshold=0.4, negatives_lower_than_unmatched=False):
2929
super(FocalLoss, self).__init__()
30+
self.matched_threshold = matched_threshold
31+
self.unmatched_threshold = unmatched_threshold
32+
self.negatives_lower_than_unmatched = negatives_lower_than_unmatched
3033

3134
def forward(self, classifications, regressions, anchors, annotations, **kwargs):
3235
alpha = 0.25
@@ -69,9 +72,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
6972
else:
7073
regression_losses.append(torch.tensor(0).to(dtype))
7174

72-
# classification_losses.append(cls_loss.sum())
73-
classification_losses.append(cls_loss.mean())
74-
75+
classification_losses.append(cls_loss.sum())
7576
continue
7677

7778
IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
@@ -83,12 +84,19 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
8384
if torch.cuda.is_available():
8485
targets = targets.cuda()
8586

86-
targets[torch.lt(IoU_max, 0.4), :] = 0
87-
88-
positive_indices = torch.ge(IoU_max, 0.5)
87+
if self.negatives_lower_than_unmatched:
88+
# negative matches are the ones below the unmatched_threshold, whereas ignored matches are in between the matched and unmatched
89+
targets[torch.lt(IoU_max, self.matched_threshold) & torch.ge(IoU_max, self.unmatched_threshold), :] = 0
90+
else:
91+
# Ignore targets with overlap lower than unmatched_threshold
92+
targets[torch.lt(IoU_max, self.unmatched_threshold), :] = 0
8993

90-
num_positive_anchors = positive_indices.sum()
94+
# Find all positives in a batch for normalization
95+
positive_indices = torch.ge(IoU_max, self.matched_threshold)
9196

97+
# Avoid zero sum of num_positives, which would lead to inf loss during training
98+
num_positive_anchors = positive_indices.sum() + 1
99+
# print(num_positive_anchors)
92100
assigned_annotations = bbox_annotation[IoU_argmax, :]
93101

94102
targets[positive_indices, :] = 0
@@ -111,8 +119,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
111119
zeros = zeros.cuda()
112120
cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
113121

114-
# classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
115-
classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
122+
classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
116123

117124
if positive_indices.sum() > 0:
118125
assigned_annotations = assigned_annotations[positive_indices, :]

train.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_args():
3737
parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
3838
parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
3939
parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices')
40-
parser.add_argument('--head_only', type=bool, default=False,
40+
parser.add_argument('--head_only', type=boolean_string, default=False,
4141
help='whether finetunes only the regressor and the classifier, '
4242
'useful in early stage convergence or small/easy dataset')
4343
parser.add_argument('--lr', type=float, default=1e-4)
@@ -56,17 +56,25 @@ def get_args():
5656
parser.add_argument('-w', '--load_weights', type=str, default=None,
5757
help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
5858
parser.add_argument('--saved_path', type=str, default='logs/')
59-
parser.add_argument('--debug', type=bool, default=False, help='whether visualize the predicted boxes of trainging, '
59+
parser.add_argument('--debug', type=boolean_string, default=False, help='whether visualize the predicted boxes of trainging, '
6060
'the output images will be in test/')
61+
parser.add_argument('--matched_threshold', type=float, default=.5, help='Threshold for positive matches.')
62+
parser.add_argument('--unmatched_threshold', type=float, default=.4, help='Threshold for negative matches.')
6163

6264
args = parser.parse_args()
6365
return args
6466

6567

68+
def boolean_string(s):
69+
if s not in {'False', 'True'}:
70+
raise ValueError('Not a valid boolean string')
71+
return s == 'True'
72+
73+
6674
class ModelWithLoss(nn.Module):
67-
def __init__(self, model, debug=False):
75+
def __init__(self, model, matched_threshold=0.5, unmatched_threshold=0.4, debug=False):
6876
super().__init__()
69-
self.criterion = FocalLoss()
77+
self.criterion = FocalLoss(matched_threshold=matched_threshold, unmatched_threshold=unmatched_threshold)
7078
self.model = model
7179
self.debug = debug
7280

@@ -175,7 +183,7 @@ def freeze_backbone(m):
175183
writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
176184

177185
# warp the model with loss function, to reduce the memory usage on gpu0 and speedup
178-
model = ModelWithLoss(model, debug=opt.debug)
186+
model = ModelWithLoss(model, matched_threshold=opt.matched_threshold, unmatched_threshold=opt.unmatched_threshold, debug=opt.debug)
179187

180188
if params.num_gpus > 0:
181189
model = model.cuda()

0 commit comments

Comments (0)