Commit 96725ac

update readme
1 parent 1054f1e commit 96725ac

50 files changed

Lines changed: 788 additions & 21 deletions

Some content is hidden: large commits have some content hidden by default, so some file names and diffs below may be omitted.

research/LM_Cocktail/README.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -49,7 +49,7 @@ The merged model can be used to perform multiple tasks.
 Install the latest version from source (Recommended):
 ```bash
 git clone https://github.com/FlagOpen/FlagEmbedding.git
-cd FlagEmbedding/LM_Cocktail
+cd FlagEmbedding/research/LM_Cocktail
 pip install -e .
 ```
 Install by pip:
@@ -260,6 +260,7 @@ torchrun --nproc_per_node 8 -m evaluation.eval_mmlu \
 - Models: we fine-tune the [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on 9 tasks, and you can find the fine-tuned models at this [link](https://huggingface.co/Shitao).
 - Examples Data: [./embedder_examples.json]()
 
+
 Use [MTEB script](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB) to evaluate the mixed embedding model:
 ```bash
 python eval_MTEB.py --model_name_or_path mixed_model --task_type Retrieval
````

research/llm_embedder/README.md

Lines changed: 38 additions & 6 deletions
````diff
@@ -20,19 +20,51 @@ This is the codebase for LLM-Embedder, a unified embedding model to comprehensiv
 ### Using `FlagEmbedding`
 ```pip install -U FlagEmbedding```
 ```python
-from FlagEmbedding import LLMEmbedder
+from FlagEmbedding import FlagModel
+
+INSTRUCTIONS = {
+    "qa": {
+        "query": "Represent this query for retrieving relevant documents: ",
+        "key": "Represent this document for retrieval: ",
+    },
+    "icl": {
+        "query": "Convert this example into vector to look for useful examples: ",
+        "key": "Convert this example into vector for retrieval: ",
+    },
+    "chat": {
+        "query": "Embed this dialogue to find useful historical dialogues: ",
+        "key": "Embed this historical dialogue for retrieval: ",
+    },
+    "lrlm": {
+        "query": "Embed this text chunk for finding useful historical chunks: ",
+        "key": "Embed this historical text chunk for retrieval: ",
+    },
+    "tool": {
+        "query": "Transform this user request for fetching helpful tool descriptions: ",
+        "key": "Transform this tool description for retrieval: "
+    },
+    "convsearch": {
+        "query": "Encode this query and context for searching relevant passages: ",
+        "key": "Encode this passage for retrieval: ",
+    },
+}
 
 # Define queries and keys
 queries = ["test query 1", "test query 2"]
 keys = ["test key 1", "test key 2"]
 
-# Load model (automatically use GPUs)
-model = LLMEmbedder('BAAI/llm-embedder', use_fp16=False)
-
 # Encode for a specific task (qa, icl, chat, lrlm, tool, convsearch)
 task = "qa"
-query_embeddings = model.encode_queries(queries, task=task)
-key_embeddings = model.encode_keys(keys, task=task)
+
+# Load model (automatically use GPUs)
+model = FlagModel('BAAI/llm-embedder',
+                  use_fp16=False,
+                  query_instruction_for_retrieval=INSTRUCTIONS[task]['query'],
+                  passage_instruction_for_retrieval=INSTRUCTIONS[task]['key'],
+                  devices=['cuda:0'])
+
+query_embeddings = model.encode_queries(queries)
+key_embeddings = model.encode_corpus(keys)
 
 similarity = query_embeddings @ key_embeddings.T
 print(similarity)
````
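The updated snippet ends with a dense query-by-key similarity matrix. Turning that matrix into a per-query ranking is a one-liner; a minimal sketch, assuming the `encode_*` calls above return numpy arrays (as FlagEmbedding models typically do):

```python
import numpy as np

# similarity[i, j] scores key j against query i; sort descending per query
ranking = np.argsort(-similarity, axis=1)
for i, order in enumerate(ranking):
    print(f"query {i}: best key index = {order[0]}")
```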

research/old-examples/pretrain/README.md

Lines changed: 2 additions & 9 deletions
````diff
@@ -11,15 +11,9 @@ pip install -U FlagEmbedding
 * **from source**
 ```
 git clone https://github.com/FlagOpen/FlagEmbedding.git
-cd FlagEmbedding
-pip install .
-```
-For development, install as editable:
-```
-pip install -e .
+cd FlagEmbedding/research/old-examples/pretrain
 ```
 
-
 ## 2. Data format
 Train data should be a json file, where each line is a dict like this:
 ```
@@ -31,7 +25,7 @@ See [toy_pretrain_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/mas
 
 ```bash
 torchrun --nproc_per_node {number of gpus} \
--m FlagEmbedding.baai_general_embedding.retromae_pretrain.run \
+-m retromae_pretrain.run \
 --output_dir {path to save model} \
 --model_name_or_path BAAI/bge-large-en \
 --train_data toy_pretrain_data.jsonl \
@@ -47,4 +41,3 @@ torchrun --nproc_per_node {number of gpus} \
 More training arguments please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).
 After training, the encoder model will saved to `{output_dir}/encoder_model`
 
-
````
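The example dict in the data-format section is truncated in the hunk above. For reference, the `DatasetForPretraining` class added later in this commit reads only the `text` field of each JSON line, so a valid training line can be sketched as follows (the example sentence is illustrative, not from the repo):

```python
import json

# One pretraining example per line of toy_pretrain_data.jsonl;
# only the "text" field is consumed by DatasetForPretraining.
example = {"text": "This is an example sentence used for RetroMAE pretraining."}
print(json.dumps(example))
```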
(new file, name hidden in this view)

Lines changed: 2 additions & 0 deletions

```diff
@@ -0,0 +1,2 @@
+
+
```
(new file, name hidden in this view)

Lines changed: 43 additions & 0 deletions

```diff
@@ -0,0 +1,43 @@
+import os
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class DataTrainingArguments:
+    train_data: Optional[str] = field(
+        default=None, metadata={"help": "Path to pretrain data"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    max_seq_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+                    "than this will be truncated. Default to the max input length of the model."
+        },
+    )
+    encoder_mlm_probability: float = field(default=0.3, metadata={"help": "mask ratio for encoder"})
+    decoder_mlm_probability: float = field(default=0.5, metadata={"help": "mask ratio for decoder"})
+
+    def __post_init__(self):
+        if not os.path.exists(self.train_data):
+            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a true path")
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+    model_name_or_path: Optional[str] = field(
+        default='bert-base-uncased',
+        metadata={
+            "help": "The model checkpoint for weights initialization."
+                    "Don't set if you want to train a model from scratch."
+        },
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
```
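These are plain `transformers`-style argument dataclasses. A minimal sketch of how they would typically be consumed (assuming the accompanying `run.py`, which is not shown in this commit, uses `transformers.HfArgumentParser`, the usual pattern for such dataclasses):

```python
from transformers import HfArgumentParser, TrainingArguments

# Hypothetical entry point: CLI flags such as --train_data and
# --model_name_or_path are parsed into the dataclasses defined above.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.model_name_or_path)      # default: bert-base-uncased
print(data_args.encoder_mlm_probability)  # default: 0.3
```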
(new file, name hidden in this view)

Lines changed: 100 additions & 0 deletions

```diff
@@ -0,0 +1,100 @@
+import os
+import random
+from copy import deepcopy
+from dataclasses import dataclass
+
+import torch.utils.data.dataset
+from datasets import Dataset, load_dataset, concatenate_datasets
+from transformers import DataCollatorForWholeWordMask
+
+from .utils import tensorize_batch
+
+
+class DatasetForPretraining(torch.utils.data.Dataset):
+    def __init__(self, data_dir):
+        if os.path.isdir(data_dir):
+            datasets = []
+            for file in os.listdir(data_dir):
+                print(f"Loading {file}")
+                file = os.path.join(data_dir, file)
+                datasets.append(self.load_dataset(file))
+            self.dataset = concatenate_datasets(datasets)
+        else:
+            print(f"Loading {data_dir}")
+            self.dataset = self.load_dataset(data_dir)
+
+    def load_dataset(self, file):
+        if file.endswith('.jsonl') or file.endswith('.json'):
+            return load_dataset('json', data_files=file)['train']
+        elif os.path.isdir(file):
+            return Dataset.load_from_disk(file)
+        else:
+            raise NotImplementedError(f"Not support this file format:{file}")
+
+    def __getitem__(self, item):
+        return self.dataset[item]['text']
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+@dataclass
+class RetroMAECollator(DataCollatorForWholeWordMask):
+    max_seq_length: int = 512
+    encoder_mlm_probability: float = 0.15
+    decoder_mlm_probability: float = 0.15
+
+    def __call__(self, examples):
+        input_ids_batch = []
+        attention_mask_batch = []
+        encoder_mlm_mask_batch = []
+        decoder_labels_batch = []
+        decoder_matrix_attention_mask_batch = []
+
+        for e in examples:
+
+            e_trunc = self.tokenizer.encode(e, max_length=self.max_seq_length, truncation=True)
+            tokens = [self.tokenizer._convert_id_to_token(tid) for tid in e_trunc]
+
+            self.mlm_probability = self.encoder_mlm_probability
+            text_encoder_mlm_mask = self._whole_word_mask(tokens)
+
+            self.mlm_probability = self.decoder_mlm_probability
+            mask_set = []
+            for _ in range(min(len(tokens), 128)):
+                mask_set.append(self._whole_word_mask(tokens))
+
+            text_matrix_attention_mask = []
+            for i in range(len(tokens)):
+                idx = random.randint(0, min(len(tokens), 128) - 1)
+                text_decoder_mlm_mask = deepcopy(mask_set[idx])
+                text_decoder_mlm_mask[i] = 1
+                text_matrix_attention_mask.append(text_decoder_mlm_mask)
+
+            input_ids_batch.append(torch.tensor(e_trunc))
+            attention_mask_batch.append(torch.tensor([1] * len(e_trunc)))
+            e_trunc[0] = -100
+            e_trunc[-1] = -100
+            decoder_labels_batch.append(torch.tensor(e_trunc))
+
+            encoder_mlm_mask_batch.append(torch.tensor(text_encoder_mlm_mask))
+            decoder_matrix_attention_mask_batch.append(1 - torch.tensor(text_matrix_attention_mask))
+
+        input_ids_batch = tensorize_batch(input_ids_batch, self.tokenizer.pad_token_id)
+        attention_mask_batch = tensorize_batch(attention_mask_batch, 0)
+        origin_input_ids_batch = input_ids_batch.clone()
+        encoder_mlm_mask_batch = tensorize_batch(encoder_mlm_mask_batch, 0)
+        encoder_input_ids_batch, encoder_labels_batch = self.torch_mask_tokens(input_ids_batch, encoder_mlm_mask_batch)
+        decoder_labels_batch = tensorize_batch(decoder_labels_batch, -100)
+        matrix_attention_mask_batch = tensorize_batch(decoder_matrix_attention_mask_batch, 0)
+
+        batch = {
+            "encoder_input_ids": encoder_input_ids_batch,
+            "encoder_attention_mask": attention_mask_batch,
+            "encoder_labels": encoder_labels_batch,
+            "decoder_input_ids": origin_input_ids_batch,
+            "decoder_attention_mask": matrix_attention_mask_batch,  # [B,L,L]
+            "decoder_labels": decoder_labels_batch,
+        }
+
+        return batch
```
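For orientation: the collator builds one masked view of each example for the encoder and a per-position set of masked views for the decoder, which is why `decoder_attention_mask` is `[B, L, L]` rather than `[B, L]`. A minimal sketch of wiring the dataset and collator into a `DataLoader` (file path and batch size are illustrative):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = DatasetForPretraining("toy_pretrain_data.jsonl")
collator = RetroMAECollator(
    tokenizer=tokenizer,
    max_seq_length=512,
    encoder_mlm_probability=0.3,  # mask ratio for the encoder view
    decoder_mlm_probability=0.5,  # mask ratio for the decoder views
)

loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
batch = next(iter(loader))
print(batch["encoder_input_ids"].shape)       # [B, L]
print(batch["decoder_attention_mask"].shape)  # [B, L, L]
```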
