|
106 | 106 | "id": "8c8f59a6c426" |
107 | 107 | }, |
108 | 108 | "source": [ |
109 | | - "### Prerequisites\n", |
| 109 | + "## 0: Prerequisites\n", |
110 | 110 | "\n", |
111 | 111 | "- Install AlphaGenome Research and Google Cloud Platform packages.\n", |
112 | 112 | "- (*) Choose either H100 or A100 specific vm notebook runtime.\n", |
113 | 113 | "- (**) Save Huggingface credentials in Google Cloud Secret manager\n", |
| 114 | + "- (***) Install 0.9.0 jax libraries\n", |
114 | 115 | "\n", |
115 | 116 | "Notebook launch:\n", |
116 | 117 | "- Launch the Notebook in Google Cloud Enterprise Colab.\n", |
|
130 | 131 | "You will be downloading weights from Huggingface.\n", |
131 | 132 | "Ensure that:\n", |
132 | 133 | "- You create a token that has 'Read access to contents of all public gated repos you can access' (under Finegrained control)\n", |
133 | | - "- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0)." |
| 134 | + "- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0).\n", |
| 135 | + "\n", |
| 136 | + "\n", |
| 137 | + "(***):\n", |
| 138 | + "Upgrading to 0.9.0 will require multiple runtime restarts." |
134 | 139 | ] |
135 | 140 | }, |
136 | 141 | { |
|
151 | 156 | { |
152 | 157 | "cell_type": "code", |
153 | 158 | "execution_count": null, |
154 | | - "id": "ictFNXdeQ4Cf", |
155 | 159 | "metadata": { |
156 | | - "id": "2bdbb159823b" |
| 160 | + "id": "834dcef76adf" |
| 161 | + }, |
| 162 | + "outputs": [], |
| 163 | + "source": [ |
| 164 | + "import jax\n", |
| 165 | + "# We need >0.9.0 jax libs.\n", |
| 166 | + "# Check the jax version.\n", |
| 167 | + "import jaxlib\n", |
| 168 | + "\n", |
| 169 | + "print(f\"{jax.__version__=}\")\n", |
| 170 | + "print(f\"{jaxlib.__version__=}\")" |
| 171 | + ] |
| 172 | + }, |
| 173 | + { |
| 174 | + "cell_type": "code", |
| 175 | + "execution_count": null, |
| 176 | + "metadata": { |
| 177 | + "id": "9a3295c80eb9" |
157 | 178 | }, |
158 | 179 | "outputs": [], |
159 | 180 | "source": [ |
160 | | - "! pip install --upgrade google-cloud-secret-manager \\\n", |
161 | | - " google-cloud-storage" |
| 181 | + "# Uninstall the previous version.\n", |
| 182 | + "# Run only if version < 0.9.0.\n", |
| 183 | + "# Restart runtime/kernel after uninstalling.\n", |
| 184 | + "# Run from next cell after the kernel restart\n", |
| 185 | + "if jax.__version__ != \"0.9.0\":\n", |
| 186 | + " print(f\"Unistalling {jax.__version__}\")\n", |
| 187 | + " ! pip uninstall -y jax jaxlib jax_cuda12_plugin" |
| 188 | + ] |
| 189 | + }, |
| 190 | + { |
| 191 | + "cell_type": "code", |
| 192 | + "execution_count": null, |
| 193 | + "metadata": { |
| 194 | + "id": "3e0625b2f13d" |
| 195 | + }, |
| 196 | + "outputs": [], |
| 197 | + "source": [ |
| 198 | + "# Install specific vesion.\n", |
| 199 | + "# Run only once after unistalling the jax packages.\n", |
| 200 | + "# Restart the runtime/kernel.\n", |
| 201 | + "# Run from next cell after the kernel restart.\n", |
| 202 | + "!pip install --upgrade jax[cuda12_pip]==0.9.0 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" |
| 203 | + ] |
| 204 | + }, |
| 205 | + { |
| 206 | + "cell_type": "code", |
| 207 | + "execution_count": null, |
| 208 | + "metadata": { |
| 209 | + "id": "b666617f1088" |
| 210 | + }, |
| 211 | + "outputs": [], |
| 212 | + "source": [ |
| 213 | + "# Install specific vesion.\n", |
| 214 | + "# Run only once after the upgrade.\n", |
| 215 | + "# Restart the runtime/kernel.\n", |
| 216 | + "# Run from next cell after the kernel restart.\n", |
| 217 | + "!pip install jax_cuda12_plugin==0.9.0" |
162 | 218 | ] |
163 | 219 | }, |
164 | 220 | { |
|
195 | 251 | "\n", |
196 | 252 | "import haiku as hk\n", |
197 | 253 | "import huggingface_hub\n", |
198 | | - "import jax\n", |
199 | 254 | "import numpy as np\n", |
200 | 255 | "import optax\n", |
201 | 256 | "import orbax.checkpoint as ocp\n", |
|
247 | 302 | { |
248 | 303 | "cell_type": "code", |
249 | 304 | "execution_count": null, |
250 | | - "id": "TDvrVjiCzaYJ", |
251 | 305 | "metadata": { |
252 | | - "id": "534dc48681bd" |
| 306 | + "id": "8f34c4142dd0" |
253 | 307 | }, |
254 | 308 | "outputs": [], |
255 | 309 | "source": [ |
| 310 | + "import jax\n", |
| 311 | + "# we need >0.9.0 jax libs\n", |
256 | 312 | "# check the jax versions\n", |
257 | 313 | "import jaxlib\n", |
258 | 314 | "\n", |
|
809 | 865 | }, |
810 | 866 | "outputs": [], |
811 | 867 | "source": [ |
812 | | - "forward_fn = finetune.get_forward_fn(\n", |
813 | | - " output_metadata, jmp_policy=\"params=float32,compute=float32,output=float32\"\n", |
814 | | - ")\n", |
815 | | - "# forward_fn = finetune.get_forward_fn(output_metadata)\n", |
| 868 | + "forward_fn = finetune.get_forward_fn(output_metadata)\n", |
816 | 869 | "with jax.set_mesh(mesh):\n", |
817 | 870 | " batch = jax.device_put(batch, data_sharding)\n", |
818 | 871 | " params_ft, state_ft = jax.jit(\n", |
|
952 | 1005 | }, |
953 | 1006 | "outputs": [], |
954 | 1007 | "source": [ |
955 | | - "loss = []\n", |
956 | | - "step = 0\n", |
957 | | - "start_time = time.monotonic()\n", |
| 1008 | + "loss, times = [], []\n", |
958 | 1009 | "for step in range(NUM_TRAIN_STEPS):\n", |
| 1010 | + " start_time = time.time()\n", |
959 | 1011 | " try:\n", |
960 | 1012 | " batch = next(ds_iter)\n", |
961 | 1013 | " except StopIteration:\n", |
|
965 | 1017 | " batch = jax.device_put(batch, data_sharding)\n", |
966 | 1018 | " params, state, opt_state, scalars = train_step(params, state, opt_state, batch)\n", |
967 | 1019 | " loss.append(scalars[\"loss\"])\n", |
| 1020 | + " times.append(time.time() - start_time)\n", |
968 | 1021 | " if step % 10 == 1:\n", |
969 | | - " print(\"loss\", step, loss[-1])\n", |
970 | | - "end_time = time.monotonic()\n", |
971 | | - "duration = end_time - start_time\n", |
972 | | - "print(f\"Training took: {duration:.4f} seconds\")\n", |
| 1022 | + " print(\"loss\", step, loss[-1], f\"SPS: {1./np.mean(times[1:]):.4f}\")\n", |
| 1023 | + "\n", |
| 1024 | + "print(f\"Total Training time: {np.sum(times[1:]):.4f} seconds\")\n", |
| 1025 | + "print(f\"Average Training time per step: {np.mean(times[1:]):.4f} seconds\")\n", |
973 | 1026 | "ckpt_path = save((params, state), step + 1)" |
974 | 1027 | ] |
975 | 1028 | }, |
|
0 commit comments