[JAX] Improve JAX tutorial documentation #2976
Greptile Summary
This PR replaces the JAX integration
Confidence Score: 4/5
Safe to merge once the sys.path issue in dense.py is addressed; CI pytest runs work correctly due to conftest.py, but the documented standalone regeneration command will fail.
docs/examples/jax_examples/dense.py: the sys.path manipulation needs to mirror the absolute-path approach used in conftest.py.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[docs/index.rst] --> B[te_jax_integration.rst\nhub page]
B --> C[jax_examples/dense.rst\nDense GEMM tutorial]
B --> D[jax_examples/attention.rst\nComing soon]
B --> E[jax_examples/moe.rst\nComing soon]
C -->|literalinclude markers| F[dense.py\nrunnable source]
C -->|literalinclude| G[dense.out\ncaptured output]
F -->|pytest via conftest.py| H[conftest.py\nfixes sys.path for pytest]
F -->|standalone run| I[sys.path.append dotdot\nbreaks from repo root]
H --> J[L0 CI: single-GPU tests]
H --> K[L1 CI: multi-GPU tests]
Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/impr..."
| @@ -0,0 +1,446 @@ | |||
| { | |||
Missing sections 4 and 5 — numbering jumps from § 3 to § 6
The notebook's section headings go ## 1 → ## 2 → ## 3. Single-GPU performance → ## 6. Multi-GPU → ## 7. Collective GEMM (placeholder), with no trace of sections 4 or 5. Unlike § 7, they have no "Coming soon" placeholders either, so a reader following the numbered flow will assume content was accidentally deleted. If these sections are planned but not yet written, add stub cells similar to the ## 7 placeholder; if the numbering was simply misapplied, renumber the existing headings to be consecutive (6 → 4 for Multi-GPU, 7 → 5 for Collective GEMM).
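A numbering gap like this one could be caught automatically. The following is a minimal sketch, not part of the PR: it assumes top-level sections are written as `## N. Title` in markdown cells of the standard `.ipynb` JSON layout, and the helper names and the example notebook are hypothetical.

```python
import re

# Assumption: top-level sections look like "## 3. Single-GPU performance".
HEADING = re.compile(r"^##\s+(\d+)\b")


def top_level_section_numbers(notebook: dict) -> list:
    """Collect the leading numbers of '## N' headings from markdown cells."""
    numbers = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "markdown":
            continue
        source = cell.get("source", [])
        # In .ipynb JSON, "source" may be a single string or a list of lines.
        lines = source.splitlines() if isinstance(source, str) else source
        for line in lines:
            match = HEADING.match(line)
            if match:
                numbers.append(int(match.group(1)))
    return numbers


def missing_section_numbers(numbers: list) -> list:
    """Return every number between 1 and the largest heading that is absent."""
    if not numbers:
        return []
    present = set(numbers)
    return [n for n in range(1, max(numbers) + 1) if n not in present]


# Hypothetical notebook mirroring the heading sequence described in the review.
example_notebook = {
    "cells": [
        {"cell_type": "markdown", "source": ["## 1. Setup\n"]},
        {"cell_type": "markdown", "source": ["## 2. Baseline\n"]},
        {"cell_type": "markdown", "source": ["## 3. Single-GPU performance\n"]},
        {"cell_type": "code", "source": ["print('## 4. not a heading')\n"]},
        {"cell_type": "markdown", "source": ["## 6. Multi-GPU\n"]},
        {"cell_type": "markdown", "source": ["## 7. Collective GEMM (placeholder)\n"]},
    ]
}

print(missing_section_numbers(top_level_section_numbers(example_notebook)))  # [4, 5]
```

To check a real notebook, load it with `json.load` and pass the resulting dict to `top_level_section_numbers`.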
| @@ -0,0 +1,446 @@ | |||
| { | |||
| "\n", | ||
| "**TODO — Coming soon.**\n", | ||
| "\n", | ||
| "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" |
I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and for how these sub-tutorials will be organized.
| "source": [ | ||
| "## 7. Collective GEMM (placeholder)\n", | ||
| "\n", | ||
| "*Coming soon.*" |
I will remove this placeholder before merging, but let me know if you have feedback on where it fits w.r.t. the rest of the tutorial.
| "\n", | ||
| "**TODO — Coming soon.**\n", | ||
| "\n", | ||
| "This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n", |
I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and for how these sub-tutorials will be organized.
| { | ||
| "cell_type": "markdown", | ||
| "id": "intro-md", | ||
| "metadata": {}, |
Reworking the existing getting-started tutorials that are merged with the PyTorch tutorials will be a follow-up PR.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
| model_apply_fn=te_model.apply, | ||
| variables=te_vars, | ||
| input=x, | ||
| output_grad=dy, |
`sys.path.append("..")` breaks standalone execution from the repo root
`dense.out` documents that the canonical way to regenerate captured output is `python3 docs/examples/jax_examples/dense.py > dense.out`, run from the repo root. At that point `".."` resolves to the parent of the repo root, not to `docs/examples/` where `quickstart_jax_utils` lives, so `import quickstart_jax_utils as utils` raises `ModuleNotFoundError`. In pytest mode, `conftest.py` inserts the correct absolute path before the module is imported, masking the problem in CI, but the standalone invocation (and the snippet shown to users in the tutorial) breaks. `conftest.py` already shows the right pattern: use `os.path.dirname(os.path.abspath(__file__))` to construct an absolute path instead.
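The absolute-path approach suggested above might look roughly like the sketch below. It is an illustration, not the contents of `conftest.py` (which are not shown in this thread); the helper names `parent_of_script_dir` and `ensure_on_sys_path` are hypothetical, and the directory layout is taken from the comment (`dense.py` in `docs/examples/jax_examples/`, `quickstart_jax_utils` in `docs/examples/`).

```python
import os
import sys


def parent_of_script_dir(script_path: str) -> str:
    """Return the absolute path of the directory one level above the script's own directory."""
    script_dir = os.path.dirname(os.path.abspath(script_path))
    return os.path.dirname(script_dir)


def ensure_on_sys_path(directory: str) -> None:
    """Prepend directory to sys.path unless it is already listed."""
    if directory not in sys.path:
        sys.path.insert(0, directory)


# Intended use at the top of dense.py (assumed layout, see lead-in):
#
#     ensure_on_sys_path(parent_of_script_dir(__file__))
#     import quickstart_jax_utils as utils
#
# Because the path is derived from __file__ rather than from the current
# working directory, the import resolves the same way whether the script is
# launched from the repo root, from its own directory, or by pytest.
```

Taking the script path as an argument, rather than reading `__file__` inside the helper, keeps the helper usable from any module.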
Description
Reworks the tutorial to focus on individual operations and their usage and performance. This makes the impact of each operation clearer to users, who can try the operations one at a time depending on which are bottlenecks in their models.
Additionally, this switches from notebook `.ipynb` files to `.rst` and separate `.py` files for easier testing in CI, to ensure our docs do not become stale and always work with the latest TE version.
Type of change
Changes
Checklist: