
Commit 3f26d33

Merge pull request #1149 from ZiyiXia/master
Update C-MTEB
2 parents: 2bfc922 + 44738e5

7 files changed: 338 additions & 111 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -136,3 +136,8 @@ pic2.py
 
 # Pyre type checker
 .pyre/
+
+# results
+results/
+en_results/
+zh_results/

C_MTEB/C_MTEB/__init__.py

Lines changed: 9 additions & 6 deletions
@@ -1,10 +1,13 @@
+# from .tasks import *
 from .tasks import *
 
-ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
-                   'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
-                   'Ocnli', 'Cmnli',
-                   'T2Reranking', 'MMarcoReranking', 'CMedQAv1', 'CMedQAv2',
-                   'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
-                   'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']
+ChineseTaskList = [
+    'TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
+    'CLSClusteringS2S.v2', 'CLSClusteringP2P.v2', 'ThuNewsClusteringS2S.v2', 'ThuNewsClusteringP2P.v2',
+    'Ocnli', 'Cmnli',
+    'T2Reranking', 'MMarcoReranking', 'CMedQAv1-reranking', 'CMedQAv2-reranking',
+    'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
+    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC'
+]
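The renamed entries track current mteb task IDs: the Chinese clustering tasks gained a `.v2` suffix and the CMedQA reranking tasks a `-reranking` suffix. A minimal sketch to sanity-check that the renamed IDs resolve against the installed mteb version (assumes a recent mteb release with the `get_tasks`/`metadata` API that this PR uses elsewhere):

```python
# Sanity-check sketch: confirm the renamed C-MTEB task IDs resolve in the
# installed mteb version (assumes a recent release with get_tasks/metadata).
import mteb

renamed = [
    'CLSClusteringS2S.v2', 'CLSClusteringP2P.v2',
    'ThuNewsClusteringS2S.v2', 'ThuNewsClusteringP2P.v2',
    'CMedQAv1-reranking', 'CMedQAv2-reranking',
]

for task in mteb.get_tasks(tasks=renamed):
    # Each resolved task exposes its canonical name and main metric.
    print(task.metadata.name, task.metadata.main_score)
```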

C_MTEB/eval_C-MTEB.py

Lines changed: 6 additions & 13 deletions
@@ -1,6 +1,6 @@
 import argparse
 
-from C_MTEB.tasks import *
+from C_MTEB import ChineseTaskList
 from flag_dres_model import FlagDRESModel
 from mteb import MTEB
 

@@ -24,24 +24,20 @@ def get_args():
     return parser.parse_args()
 
 
-
 if __name__ == '__main__':
     args = get_args()
 
     model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
                           query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                           pooling_method=args.pooling_method)
+
+    print(ChineseTaskList)
 
-    task_names = [t.description["name"] for t in MTEB(task_types=args.task_type,
-                                                       task_langs=['zh', 'zh-CN']).tasks]
-
-    for task in task_names:
-        # if task not in ChineseTaskList:
-        #     continue
+    for task in ChineseTaskList:
         if task in ['T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval',
                     'CovidRetrieval', 'CmedqaRetrieval',
                     'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
-                    'T2Reranking', 'MMarcoReranking', 'CMedQAv1', 'CMedQAv2']:
+                    'T2Reranking', 'MMarcoReranking', 'CMedQAv1-reranking', 'CMedQAv2-reranking']:
             if args.model_name_or_path not in query_instruction_for_retrieval_dict:
                 if args.add_instruction:
                     instruction = "为这个句子生成表示以用于检索相关文章:"

@@ -55,8 +51,5 @@ def get_args():
 
         model.query_instruction_for_retrieval = instruction
 
-        evaluation = MTEB(tasks=[task], task_langs=['zh', 'zh-CN'])
+        evaluation = MTEB(tasks=[task])
         evaluation.run(model, output_folder=f"zh_results/{args.model_name_or_path.split('/')[-1]}")
-
-
-
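With this change the script no longer discovers tasks by language at runtime; it iterates `ChineseTaskList` directly and builds a fresh `MTEB(tasks=[task])` per task, writing one result file per task under `zh_results/<model>`. A minimal sketch of the same loop, with a SentenceTransformer as a stand-in for `FlagDRESModel` (the model name and output folder below are illustrative):

```python
# Minimal sketch of the per-task loop above; a SentenceTransformer stands in
# for FlagDRESModel, and the model name / output folder are illustrative.
from mteb import MTEB
from sentence_transformers import SentenceTransformer

from C_MTEB import ChineseTaskList

model = SentenceTransformer("BAAI/bge-small-zh-v1.5")

for task in ChineseTaskList:
    evaluation = MTEB(tasks=[task])  # one task per run, mirroring the script
    evaluation.run(model, output_folder="zh_results/bge-small-zh-v1.5")
```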

C_MTEB/eval_MTEB.py

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,10 @@ def get_args():
                           query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
                           pooling_method=args.pooling_method)
 
-    task_names = [t.description["name"] for t in MTEB(task_types=args.task_type,
+    # task_names = [t.description["name"] for t in MTEB(task_types=args.task_type,
+    #                                                    task_langs=['en']).tasks]
+
+    task_names = [t.metadata.name for t in MTEB(task_types=args.task_type,
                                                  task_langs=['en']).tasks]
 
     for task in task_names:
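The one functional change here tracks mteb's metadata refactor: per-task attributes moved from the `description` dict to a typed `metadata` object. A small before/after sketch (hedged on the mteb version actually installed):

```python
# Before/after sketch of the attribute access this hunk updates.
from mteb import MTEB

tasks = MTEB(task_types=["STS"], task_langs=["en"]).tasks

# Old API (pre-refactor):
#   name   = tasks[0].description["name"]
#   metric = tasks[0].description["main_score"]

# New API (typed metadata object):
name = tasks[0].metadata.name
metric = tasks[0].metadata.main_score  # e.g. "cosine_spearman" for STS tasks
print(name, metric)
```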

C_MTEB/summarize_results.py

Lines changed: 73 additions & 90 deletions
@@ -4,58 +4,47 @@
 from collections import defaultdict
 
 from C_MTEB import *
+import mteb
 from mteb import MTEB
 
 
-def read_results(task_types, except_tasks, args):
+CMTEB_tasks = [
+    'TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
+    'CLSClusteringS2S.v2', 'CLSClusteringP2P.v2', 'ThuNewsClusteringS2S.v2', 'ThuNewsClusteringP2P.v2',
+    'Ocnli', 'Cmnli',
+    'T2Reranking', 'MMarcoReranking', 'CMedQAv1-reranking', 'CMedQAv2-reranking',
+    'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
+    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC'
+]
+
+
+def read_results(task_types, args):
     tasks_results = {}
-    model_dirs = {}
+    # model_dirs = {}
    for t_type in task_types:
        tasks_results[t_type] = {}
-        for t in MTEB(task_types=[t_type], task_langs=args.lang).tasks:
-            task_name = t.description["name"]
-            if task_name in except_tasks: continue
+        for t in mteb.get_tasks(task_types=[t_type]):
+            task_name = t.metadata.name
+            if task_name not in CMTEB_tasks:
+                continue
 
-            metric = t.description["main_score"]
+            metric = t.metadata.main_score
             tasks_results[t_type][task_name] = defaultdict(None)
 
-            for model_name in os.listdir(args.results_dir):
-                model_dir = os.path.join(args.results_dir, model_name)
-                if not os.path.isdir(model_dir): continue
-                model_dirs[model_name] = model_dir
-                if os.path.exists(os.path.join(model_dir, task_name + '.json')):
-                    data = json.load(open(os.path.join(model_dir, task_name + '.json')))
-                    for s in ['test', 'dev', 'validation']:
-                        if s in data:
-                            split = s
-                            break
-
-                    if 'en' in args.lang:
-                        if 'en-en' in data[split]:
-                            temp_data = data[split]['en-en']
-                        elif 'en' in data[split]:
-                            temp_data = data[split]['en']
-                        else:
-                            temp_data = data[split]
-                    elif 'zh' in args.lang:
-                        if 'zh' in data[split]:
-                            temp_data = data[split]['zh']
-                        elif 'zh-CN' in data[split]:
-                            temp_data = data[split]['zh-CN']
-                        else:
-                            temp_data = data[split]
-
-                    if metric == 'ap':
-                        tasks_results[t_type][task_name][model_name] = round(temp_data['cos_sim']['ap'] * 100, 2)
-                    elif metric == 'cosine_spearman':
-                        tasks_results[t_type][task_name][model_name] = round(temp_data['cos_sim']['spearman'] * 100, 2)
-                    else:
-                        tasks_results[t_type][task_name][model_name] = round(temp_data[metric] * 100, 2)
-
-    return tasks_results, model_dirs
-
-
-def output_markdown(tasks_results, model_names, save_file):
+            if os.path.exists(os.path.join(args.results_dir, task_name + '.json')):
+                data = json.load(open(os.path.join(args.results_dir, task_name + '.json')))
+                for s in ['test', 'dev', 'validation']:
+                    if s in data['scores']:
+                        split = s
+                        break
+
+                temp_data = data['scores'][split][0]
+                tasks_results[t_type][task_name] = round(temp_data[metric] * 100, 2)
+
+    return tasks_results
+
+
+def output_markdown(tasks_results, model, save_file):
     task_type_res = {}
     with open(save_file, 'w') as f:
         for t_type, type_results in tasks_results.items():

@@ -79,33 +68,29 @@ def output_markdown(tasks_results, model_names, save_file):
             f.write(first_line + ' Avg | \n')
             f.write(second_line + ':--------:| \n')
 
-            for model in model_names:
-                write_line = f"| {model} |"
-                all_res = []
-                cqa_res = []
-                for task_name, results in type_results.items():
-                    if "CQADupstack" in task_name:
-                        if model in results:
-                            cqa_res.append(results[model])
-                        continue
-
+            write_line = f"| {model} |"
+            all_res = []
+            cqa_res = []
+            for task_name, results in type_results.items():
+                if "CQADupstack" in task_name:
                     if model in results:
-                        write_line += f" {results[model]} |"
-                        all_res.append(results[model])
-                    else:
-                        write_line += f" |"
-
-                if len(cqa_res) > 0:
-                    write_line += f" {round(sum(cqa_res) / len(cqa_res), 2)} |"
-                    all_res.append(round(sum(cqa_res) / len(cqa_res), 2))
-
-                # if len(all_res) == len(type_results.keys()):
-                if len(all_res) == task_cnt:
-                    write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
-                    task_type_res[t_type][model] = all_res
-                else:
-                    write_line += f" |"
-                f.write(write_line + ' \n')
+                        cqa_res.append(results[model])
+                    continue
+
+                write_line += f" {results} |"
+                all_res.append(results)
+
+            if len(cqa_res) > 0:
+                write_line += f" {round(sum(cqa_res) / len(cqa_res), 2)} |"
+                all_res.append(round(sum(cqa_res) / len(cqa_res), 2))
+
+            # if len(all_res) == len(type_results.keys()):
+            if len(all_res) == task_cnt:
+                write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
+                task_type_res[t_type][model] = all_res
+            else:
+                write_line += f" |"
+            f.write(write_line + ' \n\n')
 
         f.write(f'Overall \n')
         first_line = "| Model |"

@@ -116,47 +101,45 @@ def output_markdown(tasks_results, model_names, save_file):
         f.write(first_line + ' Avg | \n')
         f.write(second_line + ':--------:| \n')
 
-        for model in model_names:
-            write_line = f"| {model} |"
-            all_res = []
-            for type_name, results in task_type_res.items():
-                if model in results:
-                    write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
-                    all_res.extend(results[model])
-                else:
-                    write_line += f" |"
-
-            if len(all_res) > 0:
-                write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
+        write_line = f"| {model} |"
+        all_res = []
+        for type_name, results in task_type_res.items():
+            if model in results:
+                write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
+                all_res.extend(results[model])
+            else:
+                write_line += f" |"
+
+        if len(all_res) > 0:
+            write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
 
-            f.write(write_line + ' \n')
+        f.write(write_line + ' \n')
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--results_dir', default="./zh_results", type=str)
     parser.add_argument('--lang', default="zh", type=str)
+    parser.add_argument('--model', default="model", type=str)
     return parser.parse_args()
 
 
 if __name__ == '__main__':
     args = get_args()
 
-    if args.lang == 'zh':
+    if args.lang == 'zho':
         task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
-        except_tasks = []
-        args.lang = ['zh', 'zh-CN']
-    elif args.lang == 'en':
+        args.lang = ['zho']
+    elif args.lang == 'eng':
         task_types = ["Retrieval", "Clustering", "PairClassification", "Reranking", "STS", "Summarization",
                       "Classification"]
-        except_tasks = ['MSMARCOv2']
-        args.lang = ['en']
+        args.lang = ['eng']
     else:
         raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}")
 
-    task_results, model_dirs = read_results(task_types, except_tasks, args=args)
+    task_results = read_results(task_types, args=args)
 
-    output_markdown(task_results, model_dirs.keys(),
+    output_markdown(task_results, args.model,
                     save_file=os.path.join(args.results_dir, f'{args.lang[0]}_results.md'))
Tutorials/4_Evaluation/4.2.2_MTEB_Leaderboard.ipynb

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@
     "for task in MTEB_MAIN_EN.tasks:\n",
     "    # get the test set to evaluate on\n",
     "    eval_splits = [\"dev\"] if task == \"MSMARCO\" else [\"test\"]\n",
-    "    evaluation = MTEB(\n",
+    "    evaluation = mteb.MTEB(\n",
     "        tasks=[task], task_langs=[\"en\"]\n",
     "    ) # Remove \"en\" to run all available languages\n",
     "    evaluation.run(\n",
