Skip to content

Commit 6d7acef

Browse files
committed
optimize beacon speed (new)
1 parent 399d12f commit 6d7acef

11 files changed

Lines changed: 787 additions & 1065 deletions

File tree

Long_LLM/activation_beacon/new/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ conda activate beacon
2020

2121
# You may need to adjust the cuda version
2222
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
23-
pip install transformers==4.39.3 deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba python-Levenshtein
23+
pip install transformers deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba python-Levenshtein
2424
pip install flash-attn --no-build-isolation
2525
```
2626

Long_LLM/activation_beacon/new/src/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .chat import apply_chat_template
33
from .args import ModelArgs
44
from .data import Data
5-
from .modeling_utils import evaluate_perplexity, evaluate_generation, evaluate_nll, move_to_device
5+
from .modeling_utils import evaluate_perplexity, evaluate_generation, evaluate_nll, move_to_device, get_shifted_labels
66

77
import logging
88
logging.basicConfig(
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
def get_model_and_tokenizer(model_args, device="cpu", evaluation_mode=True, return_tokenizer_only=False, **kwargs):
15+
def get_model_and_tokenizer(model_args, device="cpu", evaluation_mode=True, return_tokenizer_only=False, **kwargs):
1616
import torch
1717
import transformers
1818
from dataclasses import asdict
@@ -97,8 +97,6 @@ def get_model_and_tokenizer(model_args, device="cpu", evaluation_mode=True, retu
9797
for k, v in model_args_dict.items():
9898
if k.startswith("beacon") and v is not None:
9999
beacon_kwargs[k] = v
100-
elif k.startswith("retrieval") and v is not None:
101-
beacon_kwargs[k] = v
102100

103101
# use architecture attribute to distinguish different models
104102
probe_config = AutoConfig.from_pretrained(

Long_LLM/activation_beacon/new/src/args.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class ModelArgs:
2929
)
3030

3131
model_name_or_path: str = field(
32-
default='meta-llama/Llama-2-7b-chat-hf',
32+
default='Qwen/Qwen2-7B-Instruct',
3333
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
3434
)
3535
padding_side: str = field(
@@ -161,7 +161,7 @@ class ModelArgs:
161161
)
162162
beacon_param: Optional[List[str]] = field(
163163
default=None,
164-
metadata={'help': 'The introduced parameters for beacon.'}
164+
metadata={'help': 'The introduced parameters for beacon. {q, k, v, o}'}
165165
)
166166
beacon_embed_init: str = field(
167167
default="eos",
@@ -183,18 +183,6 @@ class ModelArgs:
183183
default=None,
184184
metadata={'help': 'How many windows to run in parallel?'}
185185
)
186-
retrieval_method: Optional[str] = field(
187-
default=None,
188-
metadata={'help': 'How to retrieve? {bm25}'}
189-
)
190-
retrieval_topk: Optional[int] = field(
191-
default=None,
192-
metadata={'help': 'How many windows to retrieve?'}
193-
)
194-
retrieval_key_length: Optional[int] = field(
195-
default=None,
196-
metadata={'help': 'The key sequence length in retrieval.'}
197-
)
198186

199187
max_new_tokens: Optional[int] = field(
200188
default=None,

Long_LLM/activation_beacon/new/src/data.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@
1414
logger = logging.get_logger(__name__)
1515

1616

17-
# RETRIEVAL_CAND = [(1024,1), (512,2), (256,4), (128,8), (512,1), (256,2), (128,4)]
18-
RETRIEVAL_CAND = [(1024,1)]
19-
2017

2118
class Data:
19+
def _process_pretrain_data(data, indices):
20+
outputs = {"labels": [], "index": [], "length": []}
21+
for input_ids, index in zip(data['input_ids'], indices):
22+
outputs["index"].append(index)
23+
outputs["length"].append(len(input_ids))
24+
# NOTE: the labels will be automatically generated in Trainer._prepare_inputs
25+
outputs["labels"].append(None)
26+
return outputs
27+
2228
def _process_language_modeling(data, indices, tokenizer, min_length, max_length):
23-
outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}
29+
outputs = {'input_ids': [], "labels": [], "length": [], "index": []}
2430

2531
for i, text in enumerate(data['text']):
2632
# truncate text for faster processing
@@ -33,18 +39,20 @@ def _process_language_modeling(data, indices, tokenizer, min_length, max_length)
3339
for k, v in encoded.items():
3440
encoded[k] = v[:max_length]
3541

36-
encoded["labels"] = encoded["input_ids"].copy()
42+
# NOTE: the labels will be automatically generated in Trainer._prepare_inputs
43+
encoded["labels"] = None
3744

3845
for k, v in encoded.items():
39-
outputs[k].append(v)
46+
if k in outputs:
47+
outputs[k].append(v)
4048
# length is required for grouping
4149
outputs["length"].append(len(encoded['input_ids']))
4250
outputs["index"].append(indices[i])
4351

4452
return outputs
4553

4654
def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_length, max_length, eval_mode=False):
47-
outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}
55+
outputs = {'input_ids': [], "labels": [], "length": [], "index": []}
4856

4957
for i, source in enumerate(data['conversations']):
5058
if source[0]["role"] != 'user':
@@ -69,6 +77,11 @@ def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_len
6977
add_generation_prompt=eval_mode,
7078
).encoded
7179

80+
# NOTE: shift the labels in advance
81+
# labels = encoded["labels"][1:]
82+
# labels.append(-100)
83+
# encoded["labels"] = labels
84+
7285
# skip data that not fall in between min_length and max_length
7386
if min_length is not None and len(encoded["input_ids"]) < min_length:
7487
continue
@@ -79,13 +92,14 @@ def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_len
7992
encoded["labels"] = labels
8093

8194
for k, v in encoded.items():
82-
outputs[k].append(v)
95+
if k in outputs:
96+
outputs[k].append(v)
8397
outputs['length'].append(len(encoded['input_ids']))
8498
outputs['index'].append(indices[i])
8599

86100
return outputs
87101

88-
def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", seed=42, cache_dir=None, load_from_cache_file=None):
102+
def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", seed=42, cache_dir=None, load_from_cache_file=None, ignore_index=False, ignore_length=False):
89103
if data_files is None:
90104
return None
91105

@@ -115,6 +129,7 @@ def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_len
115129
if os.path.isdir(data_file) and os.path.exists(os.path.join(data_file, "dataset_info.json")):
116130
# the dataset may be save_to_disk in advance
117131
dataset = datasets.load_from_disk(data_file)
132+
dataset = dataset.map(Data._process_pretrain_data, batched=True, num_proc=32, batch_size=32, with_indices=True)
118133

119134
else:
120135
# the dataset is a json file
@@ -145,16 +160,18 @@ def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_len
145160
dataset = dataset.train_test_split(max_sample_num, seed=seed)["test"]
146161

147162
# index column is useless in training
148-
if "index" in dataset.column_names:
163+
if "index" in dataset.column_names and ignore_index:
149164
dataset = dataset.remove_columns(["index"])
165+
if "length" in dataset.column_names and ignore_length:
166+
dataset = dataset.remove_columns(["length"])
150167

151168
train_datasets.append(dataset)
152169

153170
dataset = datasets.concatenate_datasets(train_datasets)
154171

155172
return dataset
156173

157-
def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_eval_num=None, cache_dir=None, seed=42, load_from_cache_file=None):
174+
def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_eval_num=None, cache_dir=None, seed=42, load_from_cache_file=None, ignore_index=False, ignore_length=False):
158175
if data_files is None:
159176
return None
160177

@@ -186,4 +203,9 @@ def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_leng
186203
raise ValueError(f"Found neither 'text' nor 'conversations' in the training data!")
187204

188205
dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True, load_from_cache_file=load_from_cache_file)
206+
if "index" in dataset.column_names and ignore_index:
207+
dataset = dataset.remove_columns(["index"])
208+
if "length" in dataset.column_names and ignore_length:
209+
dataset = dataset.remove_columns(["length"])
210+
189211
return dataset

0 commit comments

Comments (0)