
Commit 61ce49e

Merge pull request #605 from hanhainebula/master
Fix a bug in BGE_M3/split_data_by_length.py, and update README of MKQA
2 parents f961f12 + 2748d51 commit 61ce49e

2 files changed

Lines changed: 34 additions & 4 deletions


C_MTEB/MKQA/README.md

Lines changed: 20 additions & 0 deletions
@@ -16,6 +16,26 @@ We use the well-processed NQ [corpus](https://huggingface.co/datasets/BeIR/nq) o
 
 If you only want to perform dense retrieval with embedding models, you can follow the following steps:
 
+1. Install Java, Pyserini and Faiss (CPU version or GPU version):
+
+```bash
+# install java (Linux)
+apt update
+apt install openjdk-11-jdk
+
+# install pyserini
+pip install pyserini
+
+# install faiss
+## CPU version
+conda install -c conda-forge faiss-cpu
+
+## GPU version
+conda install -c conda-forge faiss-gpu
+```
+
+2. Dense retrieval:
+
 ```bash
 cd dense_retrieval
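As a rough illustration of what step 2 boils down to (this is not the repository's `dense_retrieval` scripts): encode the corpus and the queries with an embedding model, then search a Faiss index. The sketch below assumes the `BGEM3FlagModel` interface from FlagEmbedding and the Faiss package installed above; the corpus and query strings are placeholders.

```python
# Minimal dense-retrieval sketch: embed, index with Faiss, search.
# Hypothetical data; only the BGEM3FlagModel and Faiss interfaces are assumed.
import faiss
import numpy as np
from FlagEmbedding import BGEM3FlagModel

corpus = ["Passage about topic A.", "Passage about topic B."]  # placeholder corpus
queries = ["question about topic A"]                           # placeholder queries

model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

# Encode to dense vectors; cast to float32 for Faiss.
doc_vecs = np.asarray(model.encode(corpus)["dense_vecs"], dtype="float32")
query_vecs = np.asarray(model.encode(queries)["dense_vecs"], dtype="float32")

# Inner-product index over the corpus embeddings, then top-k search per query.
index = faiss.IndexFlatIP(doc_vecs.shape[1])
index.add(doc_vecs)
scores, ids = index.search(query_vecs, 2)
print(scores, ids)
```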

FlagEmbedding/BGE_M3/split_data_by_length.py

Lines changed: 14 additions & 4 deletions
@@ -14,6 +14,7 @@
 import math
 import time
 import argparse
+import datasets
 from tqdm import tqdm
 from pprint import pprint
 from transformers import AutoTokenizer
@@ -54,8 +55,7 @@ def _map_func(examples):
     results['idx'] = []
     results['max_length'] = []
     for i in range(len(examples['query'])):
-        results['idx'].append(i)
-
+        idx = examples['idx'][i]
         query = examples['query'][i]
         pos, neg = examples['pos'][i], examples['neg'][i]
         all_texts = [query] + pos + neg
@@ -65,6 +65,8 @@ def _map_func(examples):
             tokenized_x = self.tokenizer(x)['input_ids']
             if len(tokenized_x) > max_len:
                 max_len = len(tokenized_x)
+
+        results['idx'].append(idx)
         results['max_length'].append(max_len)
     return results
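As context for the `_map_func` change above: with `batched=True`, the loop variable `i` only counts rows inside the current batch, so it cannot serve as a dataset-wide index; the patch instead reads the pre-assigned global index from `examples['idx']`. A toy sketch of the difference, on hypothetical data rather than the repository's:

```python
# Demonstrates why a batch-local counter differs from a pre-assigned global index
# when mapping with batched=True. Toy data; not the repository code.
import datasets

ds = datasets.Dataset.from_list([{"idx": i, "text": f"row {i}"} for i in range(4)])

def batch_local(examples):
    # This counter restarts at 0 for every batch.
    return {"row_id": list(range(len(examples["text"])))}

def global_index(examples):
    # Read the global row index attached to each example beforehand.
    return {"row_id": examples["idx"]}

print(ds.map(batch_local, batched=True, batch_size=2)["row_id"])   # [0, 1, 0, 1]
print(ds.map(global_index, batched=True, batch_size=2)["row_id"])  # [0, 1, 2, 3]
```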

@@ -120,8 +122,15 @@ def _process_file(self, file_path: str, output_path: str):
         dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=features)['train']
     except:
         dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=kd_features)['train']
-    mapped_dataset = dataset.map(self._map_func, batched=True, num_proc=self.num_proc)
 
+    dataset_with_idx_list = []
+    for i, data in enumerate(dataset):
+        data['idx'] = i
+        dataset_with_idx_list.append(data)
+    dataset_with_idx = datasets.Dataset.from_list(dataset_with_idx_list)
+
+    mapped_dataset = dataset_with_idx.map(self._map_func, batched=True, num_proc=self.num_proc)
+
     split_info_dict = {}
     for length_l, length_r in self.length_ranges_list:
         save_path = output_path + f'_len-{length_l}-{length_r}.jsonl'
@@ -130,7 +139,8 @@ def _process_file(self, file_path: str, output_path: str):
             continue
 
         idxs = mapped_dataset.filter(lambda x: length_l <= x['max_length'] < length_r, num_proc=self.num_proc)
-        split_dataset = dataset.select(list(idxs._indices.to_pandas()['indices'].values))
+        split_dataset = dataset_with_idx.select(idxs['idx'])
+        split_dataset = split_dataset.remove_columns('idx')
 
         split_info_dict[f'len-{length_l}-{length_r}'] = len(split_dataset)
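The `_process_file` change applies the same idea end to end: attach a public `idx` column before mapping, filter on `max_length`, and select the surviving rows by that column instead of reaching into the filtered dataset's private `_indices` table, as the deleted line did. A minimal sketch of the pattern, on hypothetical data:

```python
# Select rows from the original dataset via an explicit 'idx' column carried
# through map() and filter(). Toy data; not the repository code.
import datasets

rows = [{"query": q} for q in ["short", "a much longer query string", "mid one"]]
ds = datasets.Dataset.from_list([dict(r, idx=i) for i, r in enumerate(rows)])

def add_length(examples):
    # Batched map: keep 'idx' aligned with the computed 'max_length'.
    return {
        "idx": examples["idx"],
        "max_length": [len(q.split()) for q in examples["query"]],
    }

mapped = ds.map(add_length, batched=True)
kept = mapped.filter(lambda x: x["max_length"] >= 3)

# Public-column selection: no reliance on the library's internal _indices.
subset = ds.select(kept["idx"]).remove_columns("idx")
print(subset["query"])  # ['a much longer query string']
```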
