
Commit 56e2c8e

update beacon code
1 parent 888f34b commit 56e2c8e

37 files changed

Lines changed: 4987 additions & 2652 deletions

Long_LLM/activation_beacon/README.md

Lines changed: 2 additions & 3 deletions
@@ -2,9 +2,8 @@
 <h1>Soaring from 4K to 400K: Extending LLM's Context with Activation Beacon [<a href="https://arxiv.org/abs/2401.03462">paper</a>]</h1>
 </div>
 
-This is the codebase for Activation Beacon, an effective, efficient, compatible, and low-cost (training) method to extend the context length of LLM by **x100** times. Currently we only apply activation beacon to [Llama-2-chat-7b](https://huggingface.co/namespace-Pt/activation-beacon-llama2-7b-chat) and [Mistral-7B-Instruct-v0.2](https://huggingface.co/namespace-Pt/activation-beacon-mistral-7b). More LLMs will be supported in the future.
-
+This is the codebase for Activation Beacon, an effective, efficient, compatible, and low-cost (training) method to extend the context length of LLMs by compressing the KV cache.
 
 ## File structure:
 - The [old](./old/) folder contains our initial implementation of Activation Beacon for Llama-2. You can use the code in it to reproduce the training/evaluation of the Llama-2 based model shown in our paper.
-- The [new](./new/) folder contains **newer** implementation of Activation Beacon for both Llama-2 and Mistral. It also supports more features, including **Deepspeed Zero3 training**, adding **chat template** in training and inference, and **evaluating on more tasks**. However, code in this folder are under development and subject to change in the future.
+- The [new](./new/) folder contains a **newer** implementation of Activation Beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2, and more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding a **chat template** in training and inference, and **evaluating on more tasks**. However, the code in this folder is under development and subject to change in the future.

Long_LLM/activation_beacon/new/README.md

Lines changed: 25 additions & 9 deletions
@@ -1,15 +1,26 @@
 # Activation-Beacon
 
-This folder contains the newer code for activation beacon with the support of **Mistral models**, **Deepspeed Zero3 training**, **chat templates**, and **more evaluation tasks**. The code here are under development and subject to change in the future.
+[Activation Beacon](https://arxiv.org/abs/2401.03462) compresses the original KV cache into fewer yet more compact states (a.k.a. beacons), enabling the LLM to perceive a longer context within its fixed context window. It is known for the following features:
+- **Effective**
+  - there is little information loss at compression ratios of 2, 4, and 8;
+- **Efficient**
+  - it drastically reduces the GPU memory consumption of the KV cache;
+- **Compatible**
+  - it can work together with position extrapolation (e.g. YaRN) to further extend the context length, and with grouped query attention to further reduce the KV cache size;
+- **Low-Cost**
+  - it is lightweight and can be efficiently trained with roughly 1B tokens.
+
+This folder contains the newer code for Activation Beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2, and more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding a **chat template** in training and inference, and **evaluating on more tasks**. However, the code in this folder is under development and subject to change in the future.
 
 ## Environment
 ```bash
 conda create -n beacon python=3.10.14
 
 conda activate beacon
 
+# You may need to adjust the CUDA version
 conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
-pip install transformers==4.39.3 deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba
+pip install transformers==4.39.3 deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba python-Levenshtein
 pip install flash-attn --no-build-isolation
 ```
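The "Efficient" claim above is simple cache arithmetic: compressing the KV cache by a ratio r shrinks its memory footprint by the same factor. A minimal sketch, assuming an illustrative Mistral-7B-like configuration (32 layers, 8 KV heads, head dim 128, bf16); these numbers are assumptions for the example, not values taken from this repository:

```python
def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, dtype_bytes=2):
    # Keys and values each occupy n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes,
    # hence the leading factor of 2.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes

seq_len = 100_000
for ratio in (1, 2, 4, 8):
    # a compression ratio of r keeps roughly seq_len / r beacon states
    gib = kv_cache_bytes(seq_len // ratio) / 2**30
    print(f"compression ratio {ratio}: {gib:.2f} GiB of KV cache")
```

At ratio 1 (no compression) this configuration holds about 12.2 GiB of KV cache for a 100K-token context; ratio 8 brings it down to roughly 1.5 GiB.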

@@ -19,10 +30,15 @@ import json
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_id = "namespace-Pt/activation-beacon-mistral-7b"
+model_id = "namespace-Pt/beacon-qwen-2-7b-instruct"
 
 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2"
+)
 
 model = model.cuda().eval()
@@ -32,7 +48,7 @@ with torch.no_grad():
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
     outputs = model.generate(**inputs, max_new_tokens=50)
     print(f"Input Length: {inputs['input_ids'].shape[1]}")
-    print(f"Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
+    print(f"Output: {repr(tokenizer.decode(outputs[0], skip_special_tokens=True))}")
 
 # reset memory before new generation task
 model.memory.reset()
@@ -55,16 +71,16 @@ with torch.no_grad():
 You should download the data for fine-tuning & evaluation, then untar the file anywhere you prefer, e.g. `/data`:
 ```bash
 # feel free to change /data to your preferred location
-wget https://huggingface.co/datasets/namespace-Pt/projects/resolve/main/activation-beacon-new.tar.gz?download=true -O /data/activation-beacon-new.tar.gz
+wget https://huggingface.co/datasets/namespace-Pt/projects/resolve/main/long-llm.tar.gz?download=true -O /data/long-llm.tar.gz
 
 cd /data
-tar -xzvf activation-beacon-new.tar.gz
+tar -xzvf long-llm.tar.gz
 ```
 
 **IMPORTANT NOTE**
 
-For any path specified for `train_data` and `eval_data`: if it is prefixed with `activation-beacon:`, it will be solved to the relative path against [`data_root`](./src/args.py).
-- e.g. `activation-beacon:lm/pg19.json` becomes `${data_root}/lm/pg19.json`
+For any path specified for `train_data` and `eval_data`: if it is prefixed with `long-llm:`, it will be resolved relative to [`data_root`](./src/args.py).
+- e.g. `long-llm:lm/pg19.json` becomes `${data_root}/lm/pg19.json`
 - you can modify the default value of [`data_root`](./src/args.py), so that you don't need to type it for each command.
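The `long-llm:` prefix rule above amounts to a small path rewrite. A minimal sketch of that behavior; the function name and the default `data_root` value here are hypothetical (the real logic and default live in `src/args.py`):

```python
import os

def resolve_data_path(path, data_root="/data/long-llm"):
    """Expand a 'long-llm:'-prefixed path against data_root; leave other paths untouched."""
    prefix = "long-llm:"
    if path.startswith(prefix):
        # 'long-llm:lm/pg19.json' -> '<data_root>/lm/pg19.json'
        return os.path.join(data_root, path[len(prefix):])
    return path

print(resolve_data_path("long-llm:lm/pg19.json"))  # /data/long-llm/lm/pg19.json
print(resolve_data_path("/abs/other.json"))        # unchanged
```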

Lines changed: 6 additions & 45 deletions
@@ -1,69 +1,30 @@
 # Evaluation
 
-## Prerequisite
-
 Make sure you have created the environment and downloaded the data according to [README](../README.md).
 
 
-## Evaluating Beacon Models
 ```bash
 conda activate beacon
 
-model=namespace-Pt/activation-beacon-mistral-7b
+model=namespace-Pt/beacon-qwen-2-7b-instruct
 
 # language modeling perplexity
 torchrun --nproc_per_node 8 -m main.eval_lm --max_length 100000 --stride 32768 --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024
 
 # passkey retrieval accuracy
-torchrun --nproc_per_node 8 -m main.eval_passkey --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --chat_template mistral
+torchrun --nproc_per_node 8 -m main.eval_passkey --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024
 
 # needle-in-a-haystack accuracy
-OPENAI_API_KEY="<you_api_key>" torchrun --nproc_per_node 8 -m main.eval_needle --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --chat_template mistral --gpt_eval
+OPENAI_API_KEY="<your_api_key>" torchrun --nproc_per_node 8 -m main.eval_needle --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --gpt_eval
 
 # topic retrieval accuracy
-torchrun --nproc_per_node 8 -m main.eval_topic --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --chat_template mistral
+torchrun --nproc_per_node 8 -m main.eval_topic --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024
 
 # longbench
-torchrun --nproc_per_node 8 -m main.eval_longbench --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --chat_template mistral
+torchrun --nproc_per_node 8 -m main.eval_longbench --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024
 
 # infinitebench
-torchrun --nproc_per_node 8 -m main.eval_infbench --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024 --chat_template mistral
+torchrun --nproc_per_node 8 -m main.eval_infbench --model_name_or_path $model --enable_beacon --beacon_ratio_mix adapt-1024
 ```
 
 All evaluation results will be saved at `data/results`.
-
-
-
-## Evaluating Full-Attention Models
-
-Full-attention models cannot run with more than 32K context length on a single A800 GPU. Parallel strategies are required. We use [`tensor_parallel`](https://github.com/BlackSamorez/tensor_parallel). You should create another environment, downgrade to `transformers==4.35.1`, and install `tensor_parallel`:
-```bash
-conda create full --clone beacon
-pip install transformers==4.35.1 tensor_parallel
-```
-
-Then, run the following commands (feel free to switch `mistralai/Mistral-7B-Instruct-v0.2` to any model on huggingface):
-
-```bash
-conda activate full
-
-model=mistralai/Mistral-7B-Instruct-v0.2
-
-# language modeling perplexity
-python -m main.eval_lm --max_length 100000 --stride 32768 --model_name_or_path $model --attn_impl flash_attention_2 --enable_tp
-
-# passkey retrieval accuracy
-python -m main.eval_passkey --model_name_or_path $model --attn_impl flash_attention_2 --enable_tp --chat_template mistral
-
-# needle-in-a-haystack accuracy
-OPENAI_API_KEY="<your_api_key>" python -m main.eval_needle --model_name_or_path $model --attn_impl flash_attention_2 --enable_tp --chat_template mistral --gpt_eval
-
-# topic retrieval accuracy
-torchrun --nproc_per_node 8 -m main.eval_topic --model_name_or_path $model --attn_impl flash_attention_2 --chat_template mistral
-
-# longbench
-torchrun --nproc_per_node 8 -m main.eval_longbench --model_name_or_path $model --attn_impl flash_attention_2 --chat_template mistral
-
-# infbench
-python -m main.eval_infbench --model_name_or_path $model --attn_impl flash_attention_2 --chat_template mistral --enable_tp
-```
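The `--max_length 100000 --stride 32768` pair in `main.eval_lm` suggests strided language-modeling evaluation: the text is scored window by window, advancing by `stride` tokens so each token contributes to perplexity exactly once. A minimal sketch of such a window layout; the function name and the exact scoring policy are assumptions for illustration, not this repository's implementation:

```python
def stride_windows(n_tokens, max_length=100_000, stride=32_768):
    """Yield (begin, end, n_new) windows: a window may overlap the previous one
    for context, but only its n_new trailing tokens are newly scored."""
    windows = []
    prev_end = 0
    for begin in range(0, n_tokens, stride):
        end = min(begin + max_length, n_tokens)
        windows.append((begin, end, end - prev_end))
        prev_end = end
        if end == n_tokens:
            break
    return windows

ws = stride_windows(150_000)
print(ws)
# every token is scored exactly once across the windows
print(sum(n_new for _, _, n_new in ws))  # 150000
```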
