Skip to content

[JAX] Improve JAX tutorial documentation#2976

Open
jberchtold-nvidia wants to merge 6 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/improve-jax-tutorial
Open

[JAX] Improve JAX tutorial documentation#2976
jberchtold-nvidia wants to merge 6 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/improve-jax-tutorial

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented May 11, 2026

Description

Reworks tutorial to focus on individual operations and their usage+performance. This will make it clearer to users the impact of each operation and they can focus on trying them out one-at-a-time depending on which are bottlenecks in their models.

Additionally, this switches from notebook .ipynb files to .rst and separate .py files for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Rework existing tutorial and replace with new Dense-specific tutorial
  • Placeholders for Attention and MoE
  • Refactor .ipynb notebooks to .rst and .py files for similar appearance in docs but better testability in CI by running .py files

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR replaces the JAX integration .ipynb notebook with a set of .rst + .py files to improve testability in CI and refocuses the tutorial around individual operations (Dense, Attention, MoE) so readers can experiment with one optimization at a time.

  • Hub page (te_jax_integration.rst) and Dense tutorial (dense.rst + dense.py) replace the deleted te_jax_integration.ipynb; Attention and MoE pages are added as placeholder stubs.
  • CI wiring: both L0_jax_unittest/test.sh and L1_jax_distributed_unittest/test.sh now exercise docs/examples/jax_examples/ via pytest, with multi-GPU tests auto-skipping on single-GPU runners.
  • conftest.py adds docs/examples/ to sys.path correctly via __file__-relative paths, but dense.py also contains a sys.path.append(\"..\") that resolves incorrectly when the script is run from the repo root (as the regeneration comment in dense.out instructs), causing import quickstart_jax_utils to fail in standalone mode.

Confidence Score: 4/5

Safe to merge once the sys.path issue in dense.py is addressed; CI pytest runs work correctly due to conftest.py, but the documented standalone regeneration command will fail.

The sys.path.append("..") in dense.py is resolved relative to the current working directory, so running python3 docs/examples/jax_examples/dense.py from the repo root (as dense.out instructs) will fail to locate quickstart_jax_utils. The pytest path in CI is unaffected because conftest.py inserts the correct absolute path first, masking the problem during automated testing.

docs/examples/jax_examples/dense.py — the sys.path manipulation needs to mirror the absolute-path approach used in conftest.py.

Important Files Changed

Filename Overview
docs/examples/jax_examples/dense.py New tutorial Python file; contains a fragile sys.path.append("..") that breaks standalone execution from the repo root even though conftest.py correctly uses absolute paths for pytest runs.
docs/examples/jax_examples/dense.rst New RST tutorial for Dense GEMMs; section numbering jumps from §3 to §6 with no §4/§5 stubs, which will confuse readers expecting sequential headings.
docs/examples/jax_examples/conftest.py New pytest conftest that correctly inserts docs/examples/ into sys.path using absolute paths derived from __file__; no issues.
docs/examples/te_jax_integration.rst New RST hub page replacing the deleted .ipynb; uses RST list-table directives correctly; placeholder rows for Attention and MoE have empty "Covers" cells which are valid in RST tables.
qa/L0_jax_unittest/test.sh Adds single-GPU pytest run of docs/examples/jax_examples/ to L0 CI; multi-GPU tests will auto-skip on single-GPU runners.
qa/L1_jax_distributed_unittest/test.sh Adds -k multi_gpu pytest run to L1 distributed CI; consistent with the skip condition in the test itself.
docs/examples/jax_examples/dense.out Captured benchmark output from a GB200; the documented regeneration command will fail unless the sys.path issue in dense.py is fixed first.
docs/examples/jax_examples/attention.rst Placeholder stub for the upcoming Attention tutorial; no functional content yet.
docs/examples/jax_examples/moe.rst Placeholder stub for the upcoming MoE tutorial; no functional content yet.
docs/index.rst Single-line change updating the toctree entry from .ipynb to .rst; correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[docs/index.rst] --> B[te_jax_integration.rst\nhub page]
    B --> C[jax_examples/dense.rst\nDense GEMM tutorial]
    B --> D[jax_examples/attention.rst\nComing soon]
    B --> E[jax_examples/moe.rst\nComing soon]
    C -->|literalinclude markers| F[dense.py\nrunnable source]
    C -->|literalinclude| G[dense.out\ncaptured output]
    F -->|pytest via conftest.py| H[conftest.py\nfixes sys.path for pytest]
    F -->|standalone run| I[sys.path.append dotdot\nbreaks from repo root]
    H --> J[L0 CI: single-GPU tests]
    H --> K[L1 CI: multi-GPU tests]
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/impr..." | Re-trigger Greptile

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
@@ -0,0 +1,446 @@
{
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing sections 4 and 5 — numbering jumps from § 3 to § 6

The notebook's section headings go ## 1## 2## 3. Single-GPU performance## 6. Multi-GPU## 7. Collective GEMM (placeholder), with no trace of sections 4 or 5. Unlike § 7, there are no "Coming soon" placeholders for them either. A reader following the numbered flow will assume content was accidentally deleted. If these sections are planned but not yet written, add stub cells similar to the ## 7 placeholder; if the numbering was simply mis-applied, renumber the existing headings to be consecutive (3 → 4 for Multi-GPU, 4 → 5 for Collective GEMM).

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
@@ -0,0 +1,446 @@
{
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unused warnings import

import warnings is imported in the setup cell but never referenced anywhere in the notebook. Remove it to keep the imports clean.

"\n",
"**TODO — Coming soon.**\n",
"\n",
"[← Back to the JAX integration overview](../te_jax_integration.ipynb)"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and how these sub-tutorials will be organized

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
"source": [
"## 7. Collective GEMM (placeholder)\n",
"\n",
"*Coming soon.*"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove this placeholder before merging, but let me know if you have feedback on where it fits w.r.t the rest of the tutorial

Comment thread docs/examples/jax_examples/moe.ipynb Outdated
"\n",
"**TODO — Coming soon.**\n",
"\n",
"This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n",
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and how these sub-tutorials will be organized

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
{
"cell_type": "markdown",
"id": "intro-md",
"metadata": {},
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reworking the existing getting started tutorials that are merged with PyTorch tutorials will be a follow-up PR

model_apply_fn=te_model.apply,
variables=te_vars,
input=x,
output_grad=dy,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 sys.path.append("..") breaks standalone execution from the repo root

dense.out documents that the canonical way to regenerate captured output is python3 docs/examples/jax_examples/dense.py > dense.out, run from the repo root. At that point ".." resolves to the parent of the repo root, not to docs/examples/ where quickstart_jax_utils lives, so import quickstart_jax_utils as utils raises ModuleNotFoundError. In pytest mode conftest.py inserts the correct absolute path before the module is imported, masking the problem in CI, but the standalone invocation (and the snippet shown to users in the tutorial) breaks. conftest.py already shows the right pattern — use os.path.dirname(os.path.abspath(__file__)) to construct an absolute path instead.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant