-
Notifications
You must be signed in to change notification settings - Fork 869
Expand file tree
/
Copy pathhn_mine.py
More file actions
130 lines (105 loc) · 4.96 KB
/
hn_mine.py
File metadata and controls
130 lines (105 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import json
import random
import numpy as np
import faiss
from tqdm import tqdm
from FlagEmbedding import FlagModel
def get_args():
    """Build and parse the command-line options for hard-negative mining.

    Returns:
        argparse.Namespace with model path, input/output files, sampling
        ranges, negative count, and GPU-search flag.
    """
    arg_parser = argparse.ArgumentParser()
    # Model and data locations.
    arg_parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
    arg_parser.add_argument('--input_file', default=None, type=str)
    arg_parser.add_argument('--candidate_pool', default=None, type=str)
    arg_parser.add_argument('--output_file', default=None, type=str)
    # Negative-mining knobs: rank window, score window, and how many to keep.
    arg_parser.add_argument('--range_for_sampling', default="10-210", type=str,
                            help="range to sample negatives")
    arg_parser.add_argument('--similarity_range', default="0.0-1.0", type=str,
                            help="similarity range to sample negatives")
    arg_parser.add_argument('--use_gpu_for_searching', action='store_true',
                            help='use faiss-gpu')
    arg_parser.add_argument('--negative_number', default=15, type=int,
                            help='the number of negatives')
    arg_parser.add_argument('--query_instruction_for_retrieval', default="")
    return arg_parser.parse_args()
def create_index(embeddings, use_gpu):
    """Build a FAISS inner-product (IP) index over the given embeddings.

    Args:
        embeddings: 2-D array-like of vectors, all of the same dimension.
        use_gpu: when True, clone the index across all visible GPUs
            (sharded, float16 storage) before adding vectors.

    Returns:
        The populated FAISS index.
    """
    dim = len(embeddings[0])
    index = faiss.IndexFlatIP(dim)
    vectors = np.asarray(embeddings, dtype=np.float32)
    if use_gpu:
        clone_opts = faiss.GpuMultipleClonerOptions()
        clone_opts.shard = True
        clone_opts.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=clone_opts)
    index.add(vectors)
    return index
def batch_search(index,
                 query,
                 topk: int = 200,
                 batch_size: int = 64):
    """Search `index` for the top-k neighbors of each query vector, in batches.

    Args:
        index: a FAISS index supporting `.search(x, k)`.
        query: 2-D array-like of query vectors.
        topk: number of neighbors to return per query.
        batch_size: how many queries to search per FAISS call.

    Returns:
        (all_scores, all_inxs): parallel lists of per-query score lists and
        neighbor-index lists.
    """
    all_scores = []
    all_inxs = []
    # Hide the progress bar for small workloads.
    show_bar = len(query) >= 256
    for start in tqdm(range(0, len(query), batch_size), desc="Batches", disable=not show_bar):
        chunk = np.asarray(query[start:start + batch_size], dtype=np.float32)
        scores, inxs = index.search(chunk, k=topk)
        all_scores.extend(scores.tolist())
        all_inxs.extend(inxs.tolist())
    return all_scores, all_inxs
def get_corpus(candidate_pool):
    """Load candidate passages from a JSONL file.

    Args:
        candidate_pool: path to a JSONL file; each line is a JSON object
            with a 'text' field.

    Returns:
        list[str]: the 'text' value of every line, in file order.
    """
    corpus = []
    # FIX: the original iterated a bare open() and leaked the file handle;
    # 'with' closes it deterministically even if a line fails to parse.
    with open(candidate_pool) as f:
        for line in f:
            record = json.loads(line.strip())
            corpus.append(record['text'])
    return corpus
def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, similarity_range, negative_number, use_gpu):
    """Mine hard negatives for each training query via ANN search.

    For every record in `input_file` (JSONL with 'query', 'pos', optional
    'neg'), searches an embedding index of the corpus, keeps neighbors whose
    rank falls in `sample_range` and score in `similarity_range` (excluding
    positives and the query itself), samples up to `negative_number` of them,
    and writes the augmented records to `output_file`.

    Args:
        model: encoder exposing `.encode(texts, batch_size)` and
            `.encode_queries(texts, batch_size)`.
        input_file: path to the JSONL training file.
        candidate_pool: optional corpus — a path (loaded via get_corpus) or a
            list of passages; when None, the corpus is built from pos/neg.
        output_file: path for the output JSONL.
        sample_range: [start, end] rank window of neighbors to consider.
        similarity_range: [min, max] inclusive score window.
        negative_number: number of negatives to keep per query.
        use_gpu: passed through to create_index for faiss-gpu search.
    """
    corpus = []
    queries = []
    train_data = []
    # FIX: the original iterated a bare open() and leaked the file handle.
    with open(input_file) as f:
        for line in f:
            line = json.loads(line.strip())
            train_data.append(line)
            corpus.extend(line['pos'])
            if 'neg' in line:
                corpus.extend(line['neg'])
            queries.append(line['query'])

    # An explicit candidate pool replaces the corpus gathered above.
    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
        corpus = list(set(candidate_pool))
    else:
        corpus = list(set(corpus))

    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode_queries(queries, batch_size=256)

    print('create index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
    all_scores, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
    assert len(all_inxs) == len(train_data)

    min_sim, max_sim = similarity_range
    for i, data in enumerate(train_data):
        query = data['query']
        scores = all_scores[i]
        inxs = all_inxs[i]
        # Restrict to the requested rank window.
        inxs = inxs[sample_range[0]:sample_range[1]]
        scores = scores[sample_range[0]:sample_range[1]]
        filtered_inx = []
        for score, inx in zip(scores, inxs):
            # FIX: FAISS pads results with -1 when fewer than topk neighbors
            # exist; the original let -1 through, which silently selected the
            # *last* corpus entry as a negative. Skip padded slots.
            if inx < 0:
                continue
            if min_sim <= score <= max_sim and corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)

        if len(filtered_inx) > negative_number:
            filtered_inx = random.sample(filtered_inx, negative_number)
        data['neg'] = [corpus[inx] for inx in filtered_inx]

    with open(output_file, 'w') as f:
        for data in train_data:
            if len(data['neg']) < negative_number:
                # Top up with random corpus passages; over-sample by
                # len(pos) so removing positives still leaves enough.
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[: negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
if __name__ == '__main__':
    args = get_args()
    # Turn the "start-end" CLI strings into numeric [start, end] pairs.
    sample_range = [int(part) for part in args.range_for_sampling.split('-')]
    similarity_range = [float(part) for part in args.similarity_range.split('-')]
    model = FlagModel(args.model_name_or_path,
                      query_instruction_for_retrieval=args.query_instruction_for_retrieval)
    find_knn_neg(
        model,
        input_file=args.input_file,
        candidate_pool=args.candidate_pool,
        output_file=args.output_file,
        sample_range=sample_range,
        similarity_range=similarity_range,
        negative_number=args.negative_number,
        use_gpu=args.use_gpu_for_searching,
    )