Skip to content

Commit 4c2cc14

Browse files
committed
Fix negative samples in hn_mine
1 parent 6a504b9 commit 4c2cc14

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

  • FlagEmbedding/baai_general_embedding/finetune

FlagEmbedding/baai_general_embedding/finetune/hn_mine.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -100,10 +100,9 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
100100
with open(output_file, 'w') as f:
101101
for data in train_data:
102102
if len(data['neg']) < negative_number:
103-
candidates = list(set(corpus) - set(data['pos'] + data['neg']))
104-
if not candidates:
105-
candidates = list(set(corpus) - set(data['pos']))
106-
data['neg'].extend(random.sample(candidates, negative_number - len(data['neg'])))
103+
samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
104+
samples = [sent for sent in samples if sent not in data['pos']]
105+
data['neg'].extend(samples[: negative_number - len(data['neg'])])
107106
f.write(json.dumps(data, ensure_ascii=False) + '\n')
108107

109108

0 commit comments

Comments
 (0)