Skip to content

Commit 502c2f2

Browse files
authored
Merge pull request #470 from shtdbb/master
Update negative sampling in hn_mine to fix issue #464
2 parents 34d24e8 + 4c2cc14 commit 502c2f2

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

  • FlagEmbedding/baai_general_embedding/finetune

FlagEmbedding/baai_general_embedding/finetune/hn_mine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
100100
with open(output_file, 'w') as f:
101101
for data in train_data:
102102
if len(data['neg']) < negative_number:
103-
data['neg'].extend(random.sample(corpus, negative_number - len(data['neg'])))
103+
samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
104+
samples = [sent for sent in samples if sent not in data['pos']]
105+
data['neg'].extend(samples[: negative_number - len(data['neg'])])
104106
f.write(json.dumps(data, ensure_ascii=False) + '\n')
105107

106108

0 commit comments

Comments
 (0)