|
1 | | -<div align="center"> |
2 | | -<h1>Soaring from 4K to 400K: Extending LLM's Context with Activation Beacon [<a href="https://arxiv.org/abs/2401.03462">paper</a>]</h1> |
3 | | -</div> |
| 1 | +# Activation-Beacon |
4 | 2 |
|
5 | | -This is the codebase for Activation Beacon, an effective, efficient, compatible, and low-cost (training) method to extend the context length of LLM through compressing KV cache. |
| 3 | +[Activation Beacon](https://arxiv.org/abs/2401.03462) is a plug-in module to transformer-based LLMs that enables effective, efficient, and flexible compression of long contexts. |
6 | 4 |
|
7 | | -## File structure: |
8 | | -- The [old](./old/) folder contains our initial implementation of Activation Beacon for Llama-2. You can use the code in it to reproduce the training/evaluation of the Llama-2 based model shown in our paper. |
9 | | -- The [new](./new/) folder contains **newer** implementation of Activation Beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2. It also supports more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding **chat template** in training and inference, and **evaluating on more tasks**. However, code in this folder are under development and subject to change in the future. |
| 5 | +This folder contains the newer code for activation beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2. It also supports more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding **chat template** in training and inference, and **evaluating on more tasks**. However, code in this folder are under development and subject to change in the future. |
| 6 | + |
| 7 | +## Environment |
| 8 | +```bash |
| 9 | +conda create beacon python=3.10.14 |
| 10 | + |
| 11 | +conda activate beacon |
| 12 | + |
| 13 | +# You may need to adjust the cuda version |
| 14 | +conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia |
| 15 | +pip install transformers deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba python-Levenshtein |
| 16 | +pip install flash-attn --no-build-isolation |
| 17 | +``` |
| 18 | + |
| 19 | +## Usage |
| 20 | +```python |
| 21 | +import json |
| 22 | +import torch |
| 23 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 24 | + |
| 25 | +model_id = "namespace-Pt/beacon-qwen-2-7b-instruct" |
| 26 | + |
| 27 | +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) |
| 28 | +model = AutoModelForCausalLM.from_pretrained( |
| 29 | + model_id, |
| 30 | + trust_remote_code=True, |
| 31 | + torch_dtype=torch.bfloat16, |
| 32 | + attn_implementation="flash_attention_2" |
| 33 | +) |
| 34 | + |
| 35 | +model = model.cuda().eval() |
| 36 | + |
| 37 | +with torch.no_grad(): |
| 38 | + # short context |
| 39 | + messages = [{"role": "user", "content": "Tell me about yourself."}] |
| 40 | + inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda") |
| 41 | + outputs = model.generate(**inputs, max_new_tokens=50) |
| 42 | + print(f"Input Length: {inputs['input_ids'].shape[1]}") |
| 43 | + print(f"Output: {repr(tokenizer.decode(outputs[0], skip_special_tokens=True))}") |
| 44 | + |
| 45 | + # reset memory before new generation task |
| 46 | + model.memory.reset() |
| 47 | + |
| 48 | + # long context |
| 49 | + with open("data/toy/infbench.json", encoding="utf-8") as f: |
| 50 | + example = json.load(f) |
| 51 | + messages = [{"role": "user", "content": example["context"]}] |
| 52 | + inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda") |
| 53 | + outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:] |
| 54 | + print("*"*20) |
| 55 | + print(f"Input Length: {inputs['input_ids'].shape[1]}") |
| 56 | + print(f"Answers: {example['answer']}") |
| 57 | + print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") |
| 58 | +``` |
| 59 | +**NOTE**: It's okay to see warnings like `This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (32768). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` Just ignore it. |
| 60 | + |
| 61 | + |
| 62 | +## Data |
| 63 | +You should download the data for fine-tuning & evaluation then untar the file at anywhere you prefer, e.g. `/data`: |
| 64 | +```bash |
| 65 | +# feel free to alternate /data to your prefered location |
| 66 | +wget https://huggingface.co/datasets/namespace-Pt/projects/resolve/main/long-llm.tar.gz?download=true -O /data/long-llm.tar.gz |
| 67 | + |
| 68 | +cd /data |
| 69 | +tar -xzvf long-llm.tar.gz |
| 70 | +``` |
| 71 | + |
| 72 | +**IMPORTANT NOTE** |
| 73 | + |
| 74 | +For any path specified for `train_data` and `eval_data`: if it is prefixed with `long-llm:`, it will be solved to the relative path against [`data_root`](./src/args.py). |
| 75 | + - e.g. `long-llm:lm/pg19.json` becomes `${data_root}/lm/pg19.json` |
| 76 | + - you can modify the default value of [`data_root`](./src/args.py), so that you don't need to type it for each command. |
| 77 | + |
| 78 | + |
| 79 | +## Training |
| 80 | +See [training section](./docs/training.md). |
| 81 | + |
| 82 | +## Evaluation |
| 83 | +See [evaluation section](./docs/evaluation.md). |
| 84 | + |
| 85 | + |
| 86 | +## Citation |
| 87 | +If you find this repository useful, please give us a star ⭐. |
| 88 | + |
| 89 | +To cite our work: |
| 90 | +``` |
| 91 | +@misc{zhang2024soaring, |
| 92 | + title={Soaring from 4K to 400K: Extending LLM's Context with Activation Beacon}, |
| 93 | + author={Peitian Zhang and Zheng Liu and Shitao Xiao and Ninglu Shao and Qiwei Ye and Zhicheng Dou}, |
| 94 | + year={2024}, |
| 95 | + eprint={2401.03462}, |
| 96 | + archivePrefix={arXiv}, |
| 97 | + primaryClass={cs.CL} |
| 98 | +} |
| 99 | +``` |
0 commit comments