|
1 | 1 | { |
2 | 2 | "cells": [ |
| 3 | + { |
| 4 | + "metadata": {}, |
| 5 | + "cell_type": "code", |
| 6 | + "outputs": [], |
| 7 | + "execution_count": null, |
| 8 | + "source": [ |
| 9 | + "import torch\n", |
| 10 | + "from torch.optim import Adam\n", |
| 11 | + "from torch.utils.data import DataLoader, TensorDataset\n", |
| 12 | + "from torch.utils.data import random_split\n", |
| 13 | + "from transformers import AutoTokenizer\n", |
| 14 | + "\n", |
| 15 | + "from labml import tracker, experiment\n", |
| 16 | + "from labml_nn.lora.gpt2 import GPTModel" |
| 17 | + ], |
| 18 | + "id": "f072832ec9d346e1" |
| 19 | + }, |
3 | 20 | { |
4 | 21 | "cell_type": "code", |
5 | 22 | "id": "initial_id", |
|
29 | 46 | "id": "ac8e51ae5bbfcae7", |
30 | 47 | "metadata": {}, |
31 | 48 | "source": [ |
32 | | - "from transformers import AutoTokenizer\n", |
33 | | - "\n", |
34 | 49 | "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", |
35 | 50 | "\n", |
36 | 51 | "tokens = tokenizer.encode(text, add_special_tokens=False)" |
|
64 | 79 | "cell_type": "code", |
65 | 80 | "id": "5c4cc78ac1a02c1d", |
66 | 81 | "metadata": {}, |
67 | | - "source": [ |
68 | | - "import torch\n", |
69 | | - "\n", |
70 | | - "input_ids = torch.tensor(tokens).view(-1, context_length)" |
71 | | - ], |
| 82 | + "source": "input_ids = torch.tensor(tokens).view(-1, context_length)", |
72 | 83 | "outputs": [], |
73 | 84 | "execution_count": null |
74 | 85 | }, |
|
77 | 88 | "id": "7037fd75e2161382", |
78 | 89 | "metadata": {}, |
79 | 90 | "source": [ |
80 | | - "from torch.utils.data import DataLoader, TensorDataset\n", |
81 | | - "from torch.optim import Adam\n", |
82 | | - "from torch.utils.data import random_split\n", |
83 | | - "\n", |
84 | 91 | "dataset = TensorDataset(input_ids)\n", |
85 | 92 | "\n", |
86 | 93 | "train_ratio = 0.8\n", |
|
102 | 109 | "id": "a98b7baa064b8494", |
103 | 110 | "metadata": {}, |
104 | 111 | "source": [ |
105 | | - "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n", |
106 | | - "\n", |
107 | 112 | "model = GPTModel()\n", |
108 | 113 | "state_dict = torch.load('transformed.pth', weights_only=True)\n", |
109 | 114 | "\n", |
|
128 | 133 | "id": "e2f5076894770740", |
129 | 134 | "metadata": {}, |
130 | 135 | "source": [ |
131 | | - "from labml import tracker, experiment\n", |
132 | | - "\n", |
133 | 136 | "optimizer = Adam(model.parameters(), lr=5e-5)\n", |
134 | 137 | "criterion = torch.nn.CrossEntropyLoss()\n", |
135 | 138 | "\n", |
|
143 | 146 | " inputs = batch[0]\n", |
144 | 147 | " inputs = inputs.to(device)\n", |
145 | 148 | " labels = inputs.clone()\n", |
146 | | - " \n", |
| 149 | + "\n", |
147 | 150 | " outputs = model(inputs)\n", |
148 | | - " \n", |
| 151 | + "\n", |
149 | 152 | " shift_logits = outputs[..., :-1, :]\n", |
150 | 153 | " shift_labels = labels[..., 1:]\n", |
151 | | - " \n", |
| 154 | + "\n", |
152 | 155 | " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n", |
153 | | - " \n", |
| 156 | + "\n", |
154 | 157 | " optimizer.zero_grad()\n", |
155 | 158 | " loss.backward()\n", |
156 | 159 | " optimizer.step()\n", |
157 | | - " \n", |
| 160 | + "\n", |
158 | 161 | " tracker.save(step, {'loss': loss})\n", |
159 | 162 | " step += 1\n", |
160 | 163 | " print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n", |
161 | | - " \n", |
| 164 | + "\n", |
162 | 165 | " test_loss = 0\n", |
163 | 166 | " for batch in test_dataloader:\n", |
164 | 167 | " inputs = batch[0]\n", |
165 | 168 | " inputs = inputs.to(device)\n", |
166 | 169 | " labels = inputs.clone()\n", |
167 | | - " \n", |
| 170 | + "\n", |
168 | 171 | " outputs = model(inputs)\n", |
169 | | - " \n", |
| 172 | + "\n", |
170 | 173 | " shift_logits = outputs[..., :-1, :]\n", |
171 | 174 | " shift_labels = labels[..., 1:]\n", |
172 | | - " \n", |
| 175 | + "\n", |
173 | 176 | " loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n", |
174 | | - " \n", |
| 177 | + "\n", |
175 | 178 | " test_loss += loss.item()\n", |
176 | 179 | " test_loss /= len(test_dataloader)\n", |
177 | 180 | " tracker.save(step, {'test_loss': test_loss})\n", |
178 | | - " \n", |
179 | 181 | "\n", |
180 | 182 | "print(\"Training complete.\")" |
181 | 183 | ], |
|
0 commit comments