Skip to content

Commit a08f352

Browse files
committed
refactor __main__.py for finetune
1 parent 974ce9a commit a08f352

9 files changed

Lines changed: 139 additions & 97 deletions

File tree

FlagEmbedding/finetune/embedder/decoder_only/base/__main__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,24 @@
88
)
99

1010

11-
parser = HfArgumentParser((
12-
DecoderOnlyEmbedderModelArguments,
13-
DecoderOnlyEmbedderDataArguments,
14-
DecoderOnlyEmbedderTrainingArguments
15-
))
16-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
17-
model_args: DecoderOnlyEmbedderModelArguments
18-
data_args: DecoderOnlyEmbedderDataArguments
19-
training_args: DecoderOnlyEmbedderTrainingArguments
11+
def main():
12+
parser = HfArgumentParser((
13+
DecoderOnlyEmbedderModelArguments,
14+
DecoderOnlyEmbedderDataArguments,
15+
DecoderOnlyEmbedderTrainingArguments
16+
))
17+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
18+
model_args: DecoderOnlyEmbedderModelArguments
19+
data_args: DecoderOnlyEmbedderDataArguments
20+
training_args: DecoderOnlyEmbedderTrainingArguments
2021

21-
runner = DecoderOnlyEmbedderRunner(
22-
model_args=model_args,
23-
data_args=data_args,
24-
training_args=training_args
25-
)
26-
runner.run()
22+
runner = DecoderOnlyEmbedderRunner(
23+
model_args=model_args,
24+
data_args=data_args,
25+
training_args=training_args
26+
)
27+
runner.run()
28+
29+
30+
if __name__ == "__main__":
31+
main()

FlagEmbedding/finetune/embedder/decoder_only/icl/__main__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,24 @@
88
)
99

1010

11-
parser = HfArgumentParser((
12-
DecoderOnlyEmbedderICLModelArguments,
13-
DecoderOnlyEmbedderICLDataArguments,
14-
DecoderOnlyEmbedderICLTrainingArguments
15-
))
16-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
17-
model_args: DecoderOnlyEmbedderICLModelArguments
18-
data_args: DecoderOnlyEmbedderICLDataArguments
19-
training_args: DecoderOnlyEmbedderICLTrainingArguments
11+
def main():
12+
parser = HfArgumentParser((
13+
DecoderOnlyEmbedderICLModelArguments,
14+
DecoderOnlyEmbedderICLDataArguments,
15+
DecoderOnlyEmbedderICLTrainingArguments
16+
))
17+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
18+
model_args: DecoderOnlyEmbedderICLModelArguments
19+
data_args: DecoderOnlyEmbedderICLDataArguments
20+
training_args: DecoderOnlyEmbedderICLTrainingArguments
2021

21-
runner = DecoderOnlyEmbedderICLRunner(
22-
model_args=model_args,
23-
data_args=data_args,
24-
training_args=training_args
25-
)
26-
runner.run()
22+
runner = DecoderOnlyEmbedderICLRunner(
23+
model_args=model_args,
24+
data_args=data_args,
25+
training_args=training_args
26+
)
27+
runner.run()
28+
29+
30+
if __name__ == "__main__":
31+
main()

FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,24 @@
88
)
99

1010

11-
parser = HfArgumentParser((
12-
EncoderOnlyEmbedderModelArguments,
13-
EncoderOnlyEmbedderDataArguments,
14-
EncoderOnlyEmbedderTrainingArguments
15-
))
16-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
17-
model_args: EncoderOnlyEmbedderModelArguments
18-
data_args: EncoderOnlyEmbedderDataArguments
19-
training_args: EncoderOnlyEmbedderTrainingArguments
11+
def main():
12+
parser = HfArgumentParser((
13+
EncoderOnlyEmbedderModelArguments,
14+
EncoderOnlyEmbedderDataArguments,
15+
EncoderOnlyEmbedderTrainingArguments
16+
))
17+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
18+
model_args: EncoderOnlyEmbedderModelArguments
19+
data_args: EncoderOnlyEmbedderDataArguments
20+
training_args: EncoderOnlyEmbedderTrainingArguments
2021

21-
runner = EncoderOnlyEmbedderRunner(
22-
model_args=model_args,
23-
data_args=data_args,
24-
training_args=training_args
25-
)
26-
runner.run()
22+
runner = EncoderOnlyEmbedderRunner(
23+
model_args=model_args,
24+
data_args=data_args,
25+
training_args=training_args
26+
)
27+
runner.run()
28+
29+
30+
if __name__ == "__main__":
31+
main()

FlagEmbedding/finetune/embedder/encoder_only/m3/__main__.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,20 @@
88
)
99

1010

11-
parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments))
12-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
13-
model_args: EncoderOnlyEmbedderM3ModelArguments
14-
data_args: EncoderOnlyEmbedderM3DataArguments
15-
training_args: EncoderOnlyEmbedderM3TrainingArguments
11+
def main():
12+
parser = HfArgumentParser((EncoderOnlyEmbedderM3ModelArguments, EncoderOnlyEmbedderM3DataArguments, EncoderOnlyEmbedderM3TrainingArguments))
13+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
14+
model_args: EncoderOnlyEmbedderM3ModelArguments
15+
data_args: EncoderOnlyEmbedderM3DataArguments
16+
training_args: EncoderOnlyEmbedderM3TrainingArguments
1617

17-
runner = EncoderOnlyEmbedderM3Runner(
18-
model_args=model_args,
19-
data_args=data_args,
20-
training_args=training_args
21-
)
22-
runner.run()
18+
runner = EncoderOnlyEmbedderM3Runner(
19+
model_args=model_args,
20+
data_args=data_args,
21+
training_args=training_args
22+
)
23+
runner.run()
24+
25+
26+
if __name__ == "__main__":
27+
main()

FlagEmbedding/finetune/reranker/decoder_only/base/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
__all__ = [
77
"CrossDecoderModel",
88
"DecoderOnlyRerankerRunner",
9-
"DecoderOnlyRerankerTrainer"
9+
"DecoderOnlyRerankerTrainer",
10+
"RerankerModelArguments",
1011
]

FlagEmbedding/finetune/reranker/decoder_only/base/__main__.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,26 @@
55
AbsRerankerTrainingArguments
66
)
77

8-
from FlagEmbedding.finetune.reranker.decoder_only.base.runner import DecoderOnlyRerankerRunner
9-
from FlagEmbedding.finetune.reranker.decoder_only.base.arguments import RerankerModelArguments
8+
from FlagEmbedding.finetune.reranker.decoder_only.base import (
9+
DecoderOnlyRerankerRunner,
10+
RerankerModelArguments
11+
)
1012

11-
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
12-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
13-
model_args: RerankerModelArguments
14-
data_args: AbsRerankerDataArguments
15-
training_args: AbsRerankerTrainingArguments
1613

17-
runner = DecoderOnlyRerankerRunner(
18-
model_args=model_args,
19-
data_args=data_args,
20-
training_args=training_args
21-
)
22-
runner.run()
14+
def main():
15+
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
16+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
17+
model_args: RerankerModelArguments
18+
data_args: AbsRerankerDataArguments
19+
training_args: AbsRerankerTrainingArguments
20+
21+
runner = DecoderOnlyRerankerRunner(
22+
model_args=model_args,
23+
data_args=data_args,
24+
training_args=training_args
25+
)
26+
runner.run()
27+
28+
29+
if __name__ == "__main__":
30+
main()

FlagEmbedding/finetune/reranker/decoder_only/layerwise/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
__all__ = [
77
"CrossDecoderModel",
88
"DecoderOnlyRerankerRunner",
9-
"DecoderOnlyRerankerTrainer"
9+
"DecoderOnlyRerankerTrainer",
10+
"RerankerModelArguments",
1011
]
Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,30 @@
11
from transformers import HfArgumentParser
22

33
from FlagEmbedding.abc.finetune.reranker import (
4-
AbsRerankerModelArguments,
54
AbsRerankerDataArguments,
65
AbsRerankerTrainingArguments
76
)
87

9-
from FlagEmbedding.finetune.reranker.decoder_only.layerwise.runner import DecoderOnlyRerankerRunner
10-
from FlagEmbedding.finetune.reranker.decoder_only.layerwise.arguments import RerankerModelArguments
8+
from FlagEmbedding.finetune.reranker.decoder_only.layerwise import (
9+
DecoderOnlyRerankerRunner,
10+
RerankerModelArguments
11+
)
1112

12-
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
13-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
14-
model_args: RerankerModelArguments
15-
data_args: AbsRerankerDataArguments
16-
training_args: AbsRerankerTrainingArguments
1713

18-
runner = DecoderOnlyRerankerRunner(
19-
model_args=model_args,
20-
data_args=data_args,
21-
training_args=training_args
22-
)
23-
runner.run()
14+
def main():
15+
parser = HfArgumentParser((RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
16+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
17+
model_args: RerankerModelArguments
18+
data_args: AbsRerankerDataArguments
19+
training_args: AbsRerankerTrainingArguments
20+
21+
runner = DecoderOnlyRerankerRunner(
22+
model_args=model_args,
23+
data_args=data_args,
24+
training_args=training_args
25+
)
26+
runner.run()
27+
28+
29+
if __name__ == "__main__":
30+
main()

FlagEmbedding/finetune/reranker/encoder_only/base/__main__.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@
55
AbsRerankerDataArguments,
66
AbsRerankerTrainingArguments
77
)
8-
from FlagEmbedding.finetune.reranker.encoder_only.base.runner import EncoderOnlyRerankerRunner
8+
from FlagEmbedding.finetune.reranker.encoder_only.base import EncoderOnlyRerankerRunner
99

1010

11-
parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
12-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
13-
model_args: AbsRerankerModelArguments
14-
data_args: AbsRerankerDataArguments
15-
training_args: AbsRerankerTrainingArguments
11+
def main():
12+
parser = HfArgumentParser((AbsRerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments))
13+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
14+
model_args: AbsRerankerModelArguments
15+
data_args: AbsRerankerDataArguments
16+
training_args: AbsRerankerTrainingArguments
1617

17-
runner = EncoderOnlyRerankerRunner(
18-
model_args=model_args,
19-
data_args=data_args,
20-
training_args=training_args
21-
)
22-
runner.run()
18+
runner = EncoderOnlyRerankerRunner(
19+
model_args=model_args,
20+
data_args=data_args,
21+
training_args=training_args
22+
)
23+
runner.run()
24+
25+
26+
if __name__ == "__main__":
27+
main()

0 commit comments

Comments
 (0)