Skip to content

Commit 7750e83

Browse files
authored
fix: inference key change, finetuning jaxlib update (#4447)
1 parent 649800e commit 7750e83

2 files changed

Lines changed: 77 additions & 34 deletions

File tree

notebooks/community/alphagenome/cloudai_alphagenome_finetune.ipynb

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@
106106
"id": "8c8f59a6c426"
107107
},
108108
"source": [
109-
"### Prerequisites\n",
109+
"## 0: Prerequisites\n",
110110
"\n",
111111
"- Install AlphaGenome Research and Google Cloud Platform packages.\n",
112112
"- (*) Choose either H100 or A100 specific vm notebook runtime.\n",
113113
"- (**) Save Huggingface credentials in Google Cloud Secret manager\n",
114+
"- (***) Install 0.9.0 jax libraries\n",
114115
"\n",
115116
"Notebook launch:\n",
116117
"- Launch the Notebook in Google Cloud Enterprise Colab.\n",
@@ -130,7 +131,11 @@
130131
"You will be downloading weights from Huggingface.\n",
131132
"Ensure that:\n",
132133
"- You create a token that has 'Read access to contents of all public gated repos you can access' (under Finegrained control)\n",
133-
"- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0)."
134+
"- You accept the T&C of the [model](https://huggingface.co/google/alphagenome-fold-0).\n",
135+
"\n",
136+
"\n",
137+
"(***):\n",
138+
"Upgrading to 0.9.0 will require multiple runtime restarts."
134139
]
135140
},
136141
{
@@ -151,14 +156,65 @@
151156
{
152157
"cell_type": "code",
153158
"execution_count": null,
154-
"id": "ictFNXdeQ4Cf",
155159
"metadata": {
156-
"id": "2bdbb159823b"
160+
"id": "834dcef76adf"
161+
},
162+
"outputs": [],
163+
"source": [
164+
"import jax\n",
165+
"# We need >0.9.0 jax libs.\n",
166+
"# Check the jax version.\n",
167+
"import jaxlib\n",
168+
"\n",
169+
"print(f\"{jax.__version__=}\")\n",
170+
"print(f\"{jaxlib.__version__=}\")"
171+
]
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": null,
176+
"metadata": {
177+
"id": "9a3295c80eb9"
157178
},
158179
"outputs": [],
159180
"source": [
160-
"! pip install --upgrade google-cloud-secret-manager \\\n",
161-
" google-cloud-storage"
181+
"# Uninstall the previous version.\n",
182+
"# Run only if version < 0.9.0.\n",
183+
"# Restart runtime/kernel after uninstalling.\n",
184+
"# Run from next cell after the kernel restart\n",
185+
"if jax.__version__ != \"0.9.0\":\n",
186+
" print(f\"Unistalling {jax.__version__}\")\n",
187+
" ! pip uninstall -y jax jaxlib jax_cuda12_plugin"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"metadata": {
194+
"id": "3e0625b2f13d"
195+
},
196+
"outputs": [],
197+
"source": [
198+
"# Install specific vesion.\n",
199+
"# Run only once after unistalling the jax packages.\n",
200+
"# Restart the runtime/kernel.\n",
201+
"# Run from next cell after the kernel restart.\n",
202+
"!pip install --upgrade jax[cuda12_pip]==0.9.0 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
203+
]
204+
},
205+
{
206+
"cell_type": "code",
207+
"execution_count": null,
208+
"metadata": {
209+
"id": "b666617f1088"
210+
},
211+
"outputs": [],
212+
"source": [
213+
"# Install specific vesion.\n",
214+
"# Run only once after the upgrade.\n",
215+
"# Restart the runtime/kernel.\n",
216+
"# Run from next cell after the kernel restart.\n",
217+
"!pip install jax_cuda12_plugin==0.9.0"
162218
]
163219
},
164220
{
@@ -195,7 +251,6 @@
195251
"\n",
196252
"import haiku as hk\n",
197253
"import huggingface_hub\n",
198-
"import jax\n",
199254
"import numpy as np\n",
200255
"import optax\n",
201256
"import orbax.checkpoint as ocp\n",
@@ -247,12 +302,13 @@
247302
{
248303
"cell_type": "code",
249304
"execution_count": null,
250-
"id": "TDvrVjiCzaYJ",
251305
"metadata": {
252-
"id": "534dc48681bd"
306+
"id": "8f34c4142dd0"
253307
},
254308
"outputs": [],
255309
"source": [
310+
"import jax\n",
311+
"# we need >0.9.0 jax libs\n",
256312
"# check the jax versions\n",
257313
"import jaxlib\n",
258314
"\n",
@@ -809,10 +865,7 @@
809865
},
810866
"outputs": [],
811867
"source": [
812-
"forward_fn = finetune.get_forward_fn(\n",
813-
" output_metadata, jmp_policy=\"params=float32,compute=float32,output=float32\"\n",
814-
")\n",
815-
"# forward_fn = finetune.get_forward_fn(output_metadata)\n",
868+
"forward_fn = finetune.get_forward_fn(output_metadata)\n",
816869
"with jax.set_mesh(mesh):\n",
817870
" batch = jax.device_put(batch, data_sharding)\n",
818871
" params_ft, state_ft = jax.jit(\n",
@@ -952,10 +1005,9 @@
9521005
},
9531006
"outputs": [],
9541007
"source": [
955-
"loss = []\n",
956-
"step = 0\n",
957-
"start_time = time.monotonic()\n",
1008+
"loss, times = [], []\n",
9581009
"for step in range(NUM_TRAIN_STEPS):\n",
1010+
" start_time = time.time()\n",
9591011
" try:\n",
9601012
" batch = next(ds_iter)\n",
9611013
" except StopIteration:\n",
@@ -965,11 +1017,12 @@
9651017
" batch = jax.device_put(batch, data_sharding)\n",
9661018
" params, state, opt_state, scalars = train_step(params, state, opt_state, batch)\n",
9671019
" loss.append(scalars[\"loss\"])\n",
1020+
" times.append(time.time() - start_time)\n",
9681021
" if step % 10 == 1:\n",
969-
" print(\"loss\", step, loss[-1])\n",
970-
"end_time = time.monotonic()\n",
971-
"duration = end_time - start_time\n",
972-
"print(f\"Training took: {duration:.4f} seconds\")\n",
1022+
" print(\"loss\", step, loss[-1], f\"SPS: {1./np.mean(times[1:]):.4f}\")\n",
1023+
"\n",
1024+
"print(f\"Total Training time: {np.sum(times[1:]):.4f} seconds\")\n",
1025+
"print(f\"Average Training time per step: {np.mean(times[1:]):.4f} seconds\")\n",
9731026
"ckpt_path = save((params, state), step + 1)"
9741027
]
9751028
},

notebooks/community/alphagenome/cloudai_alphagenome_vai_quickstart.ipynb

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,24 +4458,14 @@
44584458
"outputs": [],
44594459
"source": [
44604460
"def extract_k562(adata):\n",
4461-
" if \"ontologyTerm\" in adata.var.columns and not adata.var[\"ontologyTerm\"].empty:\n",
4462-
" mask = adata.var[\"ontologyTerm\"].apply(\n",
4463-
" lambda x: isinstance(x, dict)\n",
4464-
" and x.get(\"ontologyType\") == \"ONTOLOGY_TYPE_EFO\"\n",
4465-
" and x.get(\"id\") == \"2067\"\n",
4466-
" )\n",
4467-
" values = adata.X[:, mask]\n",
4468-
" else:\n",
4469-
" raise ValueError(\n",
4470-
" \"Expected 'ontologyTerm' column with dictionary values not found in adata.var\"\n",
4471-
" )\n",
4472-
" assert values.size == 1\n",
4473-
" return values.flatten()[0]\n",
4461+
" values = adata.X[:, adata.var['ontology_curie'] == 'EFO:0002067']\n",
4462+
" assert values.size == 1\n",
4463+
" return values.flatten()[0]\n",
44744464
"\n",
44754465
"\n",
44764466
"ism_result = ism.ism_matrix(\n",
44774467
" [extract_k562(x[0]) for x in variant_scores],\n",
4478-
" variants=[v[0].uns[\"variant\"] for v in variant_scores],\n",
4468+
" variants=[v[0].uns['variant'] for v in variant_scores],\n",
44794469
")"
44804470
]
44814471
},

0 commit comments

Comments
 (0)