44from collections import defaultdict
55
66from C_MTEB import *
7+ import mteb
78from mteb import MTEB
89
910
10- def read_results (task_types , except_tasks , args ):
# Task names making up the Chinese MTEB (C-MTEB) benchmark, grouped by
# task type for readability. Only tasks in this list are aggregated.
_CLASSIFICATION_TASKS = [
    'TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
]
_CLUSTERING_TASKS = [
    'CLSClusteringS2S.v2', 'CLSClusteringP2P.v2', 'ThuNewsClusteringS2S.v2', 'ThuNewsClusteringP2P.v2',
]
_PAIR_CLASSIFICATION_TASKS = [
    'Ocnli', 'Cmnli',
]
_RERANKING_TASKS = [
    'T2Reranking', 'MMarcoReranking', 'CMedQAv1-reranking', 'CMedQAv2-reranking',
]
_RETRIEVAL_TASKS = [
    'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval',
    'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
]
_STS_TASKS = [
    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC',
]

CMTEB_tasks = (
    _CLASSIFICATION_TASKS
    + _CLUSTERING_TASKS
    + _PAIR_CLASSIFICATION_TASKS
    + _RERANKING_TASKS
    + _RETRIEVAL_TASKS
    + _STS_TASKS
)
19+
20+
def read_results(task_types, args):
    """Collect C-MTEB scores from per-task JSON result files.

    Args:
        task_types: Iterable of MTEB task-type names (e.g. "Retrieval", "STS").
        args: Parsed CLI namespace; only ``args.results_dir`` is read here.

    Returns:
        Dict mapping task type -> {task name -> score as a percentage rounded
        to 2 decimal places}. Tasks whose result JSON is missing (or has no
        recognised split) keep an empty placeholder value, which downstream
        table rendering treats as "no result".
    """
    tasks_results = {}
    for t_type in task_types:
        tasks_results[t_type] = {}
        for t in mteb.get_tasks(task_types=[t_type]):
            task_name = t.metadata.name
            # Restrict aggregation to the Chinese benchmark task list.
            if task_name not in CMTEB_tasks:
                continue

            metric = t.metadata.main_score
            # Placeholder for tasks without a result file; overwritten with
            # the numeric score below when the JSON exists.
            tasks_results[t_type][task_name] = defaultdict(None)

            result_path = os.path.join(args.results_dir, task_name + '.json')
            if os.path.exists(result_path):
                # Use a context manager so the file handle is not leaked.
                with open(result_path) as f:
                    data = json.load(f)

                # Pick the first available split for THIS file. Previously the
                # loop variable `split` leaked across iterations, so a file
                # missing all known splits silently reused the previous
                # task's split (or raised NameError on the first task).
                split = next(
                    (s for s in ['test', 'dev', 'validation'] if s in data['scores']),
                    None,
                )
                if split is None:
                    continue

                temp_data = data['scores'][split][0]
                tasks_results[t_type][task_name] = round(temp_data[metric] * 100, 2)

    return tasks_results
45+
46+
47+ def output_markdown (tasks_results , model , save_file ):
5948 task_type_res = {}
6049 with open (save_file , 'w' ) as f :
6150 for t_type , type_results in tasks_results .items ():
@@ -79,33 +68,29 @@ def output_markdown(tasks_results, model_names, save_file):
7968 f .write (first_line + ' Avg | \n ' )
8069 f .write (second_line + ':--------:| \n ' )
8170
82- for model in model_names :
83- write_line = f"| { model } |"
84- all_res = []
85- cqa_res = []
86- for task_name , results in type_results .items ():
87- if "CQADupstack" in task_name :
88- if model in results :
89- cqa_res .append (results [model ])
90- continue
91-
71+ write_line = f"| { model } |"
72+ all_res = []
73+ cqa_res = []
74+ for task_name , results in type_results .items ():
75+ if "CQADupstack" in task_name :
9276 if model in results :
93- write_line += f" { results [model ]} |"
94- all_res .append (results [model ])
95- else :
96- write_line += f" |"
97-
98- if len (cqa_res ) > 0 :
99- write_line += f" { round (sum (cqa_res ) / len (cqa_res ), 2 )} |"
100- all_res .append (round (sum (cqa_res ) / len (cqa_res ), 2 ))
101-
102- # if len(all_res) == len(type_results.keys()):
103- if len (all_res ) == task_cnt :
104- write_line += f" { round (sum (all_res ) / len (all_res ), 2 )} |"
105- task_type_res [t_type ][model ] = all_res
106- else :
107- write_line += f" |"
108- f .write (write_line + ' \n ' )
77+ cqa_res .append (results [model ])
78+ continue
79+
80+ write_line += f" { results } |"
81+ all_res .append (results )
82+
83+ if len (cqa_res ) > 0 :
84+ write_line += f" { round (sum (cqa_res ) / len (cqa_res ), 2 )} |"
85+ all_res .append (round (sum (cqa_res ) / len (cqa_res ), 2 ))
86+
87+ # if len(all_res) == len(type_results.keys()):
88+ if len (all_res ) == task_cnt :
89+ write_line += f" { round (sum (all_res ) / len (all_res ), 2 )} |"
90+ task_type_res [t_type ][model ] = all_res
91+ else :
92+ write_line += f" |"
93+ f .write (write_line + ' \n \n ' )
10994
11095 f .write (f'Overall \n ' )
11196 first_line = "| Model |"
@@ -116,47 +101,45 @@ def output_markdown(tasks_results, model_names, save_file):
116101 f .write (first_line + ' Avg | \n ' )
117102 f .write (second_line + ':--------:| \n ' )
118103
119- for model in model_names :
120- write_line = f"| { model } |"
121- all_res = []
122- for type_name , results in task_type_res .items ():
123- if model in results :
124- write_line += f" { round (sum (results [model ]) / len (results [model ]), 2 )} |"
125- all_res .extend (results [model ])
126- else :
127- write_line += f" |"
128-
129- if len (all_res ) > 0 :
130- write_line += f" { round (sum (all_res ) / len (all_res ), 2 )} |"
104+ write_line = f"| { model } |"
105+ all_res = []
106+ for type_name , results in task_type_res .items ():
107+ if model in results :
108+ write_line += f" { round (sum (results [model ]) / len (results [model ]), 2 )} |"
109+ all_res .extend (results [model ])
110+ else :
111+ write_line += f" |"
112+
113+ if len (all_res ) > 0 :
114+ write_line += f" { round (sum (all_res ) / len (all_res ), 2 )} |"
131115
132- f .write (write_line + ' \n ' )
116+ f .write (write_line + ' \n ' )
133117
134118
def get_args(argv=None):
    """Parse command-line options for the results-aggregation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse fall back to ``sys.argv[1:]``, so existing callers that
            invoke ``get_args()`` with no arguments are unaffected; passing a
            list makes the parser testable in isolation.

    Returns:
        argparse.Namespace with ``results_dir``, ``lang`` and ``model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', default="./zh_results", type=str)
    parser.add_argument('--lang', default="zh", type=str)
    parser.add_argument('--model', default="model", type=str)
    return parser.parse_args(argv)
140125
141126
if __name__ == '__main__':
    args = get_args()

    # Map the language flag to the task types evaluated for that language,
    # then wrap it in a list as expected downstream (the markdown file name
    # uses args.lang[0]).
    if args.lang == 'zho':
        task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
        args.lang = ['zho']
    elif args.lang == 'eng':
        task_types = ["Retrieval", "Clustering", "PairClassification", "Reranking", "STS", "Summarization",
                      "Classification"]
        args.lang = ['eng']
    else:
        # Fixed message: the accepted codes are 'zho'/'eng' (the old text
        # said "zh or en" and was missing a space before the value).
        raise NotImplementedError(f"args.lang must be 'zho' or 'eng', but got {args.lang}")

    task_results = read_results(task_types, args=args)

    output_markdown(task_results, args.model,
                    save_file=os.path.join(args.results_dir, f'{args.lang[0]}_results.md'))
161144
162145
0 commit comments