Commit 6856d64

Fix negative samples in hn_mine

shtdbb committed
1 parent db672e1, commit 6856d64

1 file changed

Lines changed: 4 additions & 1 deletion

FlagEmbedding/baai_general_embedding/finetune/hn_mine.py

@@ -100,7 +100,10 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
     with open(output_file, 'w') as f:
         for data in train_data:
             if len(data['neg']) < negative_number:
-                data['neg'].extend(random.sample(corpus, negative_number - len(data['neg'])))
+                candidates = list(set(corpus) - set(data['pos'] + data['neg']))
+                if not candidates:
+                    candidates = list(set(corpus) - set(data['pos']))
+                data['neg'].extend(random.sample(candidates, negative_number - len(data['neg'])))
             f.write(json.dumps(data, ensure_ascii=False) + '\n')
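
The padding logic in this change can also be read as a small standalone helper. The following is a minimal sketch, not the code in the commit: `fill_negatives` and its `rng` parameter are hypothetical names, and the `min(...)` cap on the sample size is an added assumption, since `random.sample` raises `ValueError` when asked for more items than the pool holds. The sketch also filters the corpus in order rather than through a `set` difference, so that a seeded generator gives repeatable results.

```python
import random


def fill_negatives(data, corpus, negative_number, rng=random):
    """Pad data['neg'] with random passages that are not positives.

    Hypothetical helper mirroring the commit's logic; not part of hn_mine.py.
    """
    missing = negative_number - len(data['neg'])
    if missing <= 0:
        return data

    # As in the commit: exclude positives and already-mined negatives.
    excluded = set(data['pos']) | set(data['neg'])
    # dict.fromkeys de-duplicates the corpus while keeping its order.
    candidates = [p for p in dict.fromkeys(corpus) if p not in excluded]

    if not candidates:
        # Fallback from the commit: exclude only the positives.
        positives = set(data['pos'])
        candidates = [p for p in dict.fromkeys(corpus) if p not in positives]

    # Assumption (not in the commit): never request more than the pool holds.
    data['neg'].extend(rng.sample(candidates, min(missing, len(candidates))))
    return data


if __name__ == "__main__":
    example = {'query': 'q', 'pos': ['a'], 'neg': ['b']}
    padded = fill_negatives(example, ['a', 'b', 'c', 'd', 'e'], 3, random.Random(0))
    print(padded['neg'])
```

With the pre-fix behaviour, sampling directly from `corpus` could return a passage already listed in `data['pos']`, which would label a positive as a negative. Filtering the pool first avoids that, at the cost of building the exclusion set once per training example.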
