diff --git a/LICENSE b/LICENSE index dd8968a3..e46d5080 100644 --- a/LICENSE +++ b/LICENSE @@ -1,247 +1,21 @@ -June 9 Researcher Reciprocity License -Version 1.0 -dated June 9, 2026 - -This is a license (the "License") between you ("You") and GPU Mode and the -reference-kernels contributors ("Licensor"). This License adapts the Open -Responsible AI License Source ("Open RAIL-S") pattern for source code and -project materials, and adds the Researcher Reciprocity use restriction in -Attachment A. It is intended to have an open and permissive character while -preserving reciprocal research access when the Project Materials are used to -train or improve AI systems. - -If you train on it, you let us generate. - -Section I: Preamble - -reference-kernels contains reference implementations, problem specifications, -tests, examples, and related materials for KernelBot competitions and GPU kernel -research. The Project Materials include source code, documentation, -configuration, examples, tests, scripts, and related materials distributed with -this License. - -Licensor wishes to promote collaboration, open research, education, -benchmarking, and broad reuse of the Project Materials. Licensor also wishes to -avoid a one-way bargain in which researchers and contributors publish ideas and -code that are used to improve AI systems, while the providers of those AI -systems then prohibit those same researchers from generating outputs, -evaluating the systems, benchmarking them, publishing research, or exploring -their own ideas. - -This License therefore grants broad rights to use the Project Materials, -subject to attribution and the use-based restriction in Attachment A. - -Section II: Definitions - -1. "License" means these terms and conditions for use, reproduction, and -Distribution. - -2. "Project Materials" means the source code, documentation, configuration, -examples, tests, scripts, data, metadata, and other materials distributed with -this License. - -3. "Output" means the results of operating a model, service, application, or -other system. - -4. "Model" means any machine-learning or artificial-intelligence based -assemblies, including model weights, checkpoints, parameters, optimizer states, -adapters, embedding systems, agents, APIs, hosted services, or other systems -that are trained, tuned, evaluated, benchmarked, or otherwise used in connection -with the Project Materials. - -5. "Derivatives of the Project Materials" means all modifications, -transformations, annotations, translations, extracts, subsets, compilations, -arrangements, or other works based on the Project Materials. - -6. "Derivatives of a Model" means all modifications to a Model, works based on -a Model, or any other model that is created or initialized by transfer of -patterns of weights, parameters, activations, embeddings, outputs, or other -representations of the Model, including distillation methods and methods based -on synthetic data generated by the Model. - -7. "Training Use" means using the Project Materials, in whole or in part, to -train, pretrain, fine-tune, post-train, align, distill, evaluate for training, -benchmark for training, generate synthetic data for training, construct -embeddings for training, rank or filter examples for training, or otherwise -improve the weights, behavior, capabilities, or performance of a Model or -Derivatives of a Model. - -8. "Covered Model" means any Model or Derivatives of a Model that is trained, -fine-tuned, distilled, aligned, evaluated for training, benchmarked for -training, or otherwise improved through Training Use of the Project Materials. - -9. "Distribution" means any transmission, reproduction, publication, hosting, -or other sharing of the Project Materials, Derivatives of the Project -Materials, a Covered Model, or Derivatives of a Covered Model to a third party, -including making any of them available by electronic or remote means, such as -API-based or web access. - -10. "Licensor" means GPU Mode, the project maintainers, and any contributor who -has authority to license their contribution under these terms. - -11. "You" or "Your" means an individual or legal entity exercising permissions -granted by this License or making use of the Project Materials for any purpose. - -12. "Third Parties" means individuals or legal entities that are not under -common control with Licensor or You. - -13. "Authorized Researchers" means GPU Mode, the project maintainers, project -contributors, and any researchers or organizations that GPU Mode designates in -writing for purposes of generating outputs from, evaluating, benchmarking, -auditing, criticizing, or publishing research about a Covered Model. - -14. "Ordinary Users" means the general class of users to whom You make a -Covered Model available, including through a public product, commercial product, -research release, API, hosted service, preview, beta, or gated access program. - -Section III: Intellectual Property Rights - -2. Grant of Copyright License. Subject to the terms and conditions of this -License, each Licensor grants You a worldwide, non-exclusive, no-charge, -royalty-free copyright license to reproduce, prepare derivative works of, -publicly display, publicly perform, sublicense, and distribute the Project -Materials and Derivatives of the Project Materials. - -3. No Patent License. This License does not grant any patent license. - -Section IV: Conditions of Usage, Distribution, and Redistribution - -4. Distribution and Redistribution. You may reproduce and distribute copies of -the Project Materials or Derivatives of the Project Materials in any medium, -with or without modifications, provided that You meet the following conditions: - -4.1. You must give Third Party recipients of the Project Materials or -Derivatives of the Project Materials a copy of this License or a clear link to -it. - -4.2. You must retain reasonable copyright, license, and attribution notices, -excluding notices that do not pertain to any part of the Project Materials or -Derivatives of the Project Materials. - -4.3. You must give reasonable attribution to GPU Mode and reference-kernels. -Reasonable attribution includes, where practical, the project name, a link to -the project source, and any citation requested in the project documentation. - -4.4. You must cause any modified files or documentation that You Distribute to -carry prominent notices stating that You changed them. - -4.5. You may add Your own copyright statement to Your modifications and may -provide additional or different license terms for Your independent additions, -annotations, analyses, software, models, outputs, or other works, provided that -Your use, reproduction, and Distribution of the Project Materials otherwise -complies with this License. - -5. Use-Based Restrictions. The restriction set forth in Attachment A is a -use-based restriction. You may not use the Project Materials, Derivatives of the -Project Materials, Covered Models, or Derivatives of Covered Models for the -restricted use specified in Attachment A. - -For Training Use, the use-based restriction in Attachment A must be included as -an enforceable provision in any legal agreement, terms of use, acceptable use -policy, license, or other terms governing the use or Distribution of a Covered -Model or Derivatives of a Covered Model. You must give notice to subsequent -users that the Covered Model or Derivatives of the Covered Model are subject to -Attachment A. - -6. Outputs. Except as stated in this License, Licensor claims no rights in the -Output You generate using a Covered Model. You are accountable for the Output -You generate and its subsequent uses. No use of the Output may contravene this -License. - -Section V: Other Provisions - -7. No Endorsement. Nothing in this License permits You to use Licensor's names, -logos, trademarks, or service marks to imply endorsement, sponsorship, or -approval. - -8. Third-Party Rights. The Project Materials may include material submitted by -third parties. This License applies only to rights that Licensor has authority -to license. You are responsible for complying with any third-party rights, -privacy obligations, laws, or regulations that apply to Your use. - -9. Disclaimer of Warranty. Unless required by applicable law or agreed to in -writing, Licensor provides the Project Materials on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including -warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE, ACCURACY, AVAILABILITY, OR ABSENCE OF DEFECTS. You -are solely responsible for determining the appropriateness of using or -redistributing the Project Materials and assume any risks associated with Your -exercise of permissions under this License. - -10. Limitation of Liability. To the maximum extent permitted by law, in no -event and under no legal theory, whether in tort, contract, or otherwise, -unless required by applicable law or agreed to in writing, shall any Licensor or -contributor be liable to You for damages, including direct, indirect, special, -incidental, consequential, exemplary, or punitive damages arising as a result of -this License or out of the use or inability to use the Project Materials, even -if such Licensor or contributor has been advised of the possibility of such -damages. - -11. Accepting Warranty or Additional Liability. While redistributing the Project -Materials or Derivatives of the Project Materials, You may choose to offer, and -charge a fee for, acceptance of support, warranty, indemnity, or other liability -obligations or rights consistent with this License. However, in accepting such -obligations, You may act only on Your own behalf and on Your sole -responsibility, not on behalf of any Licensor or contributor, and only if You -agree to indemnify, defend, and hold each Licensor and contributor harmless for -any liability incurred by, or claims asserted against, such Licensor or -contributor by reason of Your accepting any such warranty or additional -liability. - -12. Termination. If You violate this License, Your rights under it terminate -automatically. For violations other than violations of Attachment A, Your rights -are reinstated if You cure the violation within 30 days after discovering it or -receiving written notice from Licensor. For violations of Attachment A involving -a Covered Model, Your Training Use rights terminate automatically as to the -affected Covered Model and may be reinstated only if Licensor provides written -reinstatement or waiver. - -13. Severability. If any provision of this License is held invalid, illegal, or -unenforceable, the remaining provisions remain valid as if the provision had not -been set forth. The unenforceable provision will be interpreted or reformed only -to the minimum extent necessary to make it enforceable while preserving its -purpose. - -14. Additional Permission. Licensor may grant additional permissions, -exceptions, waivers, commercial terms, or private licenses in writing. Those -permissions apply only to the recipient and scope stated in the written grant. - -End of Terms and Conditions - -Attachment A -Use Restriction: Researcher Reciprocity for Training Use - -You agree not to use the Project Materials or Derivatives of the Project -Materials for Training Use if You make the resulting Covered Model or -Derivatives of the Covered Model available under terms, policies, technical -measures, access rules, account restrictions, acceptable-use rules, or other -conditions that prohibit, penalize, or materially burden Authorized Researchers -from: - -1. generating outputs from the Covered Model; - -2. evaluating, auditing, red-teaming, or benchmarking the Covered Model; - -3. comparing the Covered Model to other systems; - -4. publishing research, criticism, measurements, benchmark results, or analysis -concerning the Covered Model; or - -5. using the Covered Model to explore, test, or develop their own research -ideas. - -This access must be available on materially equal terms to those offered to -Ordinary Users of the Covered Model, subject only to neutral limits that apply -equally to Ordinary Users, such as generally applicable rate limits, payment -terms, safety rules, security rules, and laws. - -Any terms, policies, technical measures, access rules, account restrictions, -acceptable-use rules, or other conditions that conflict with this Attachment A -make the Covered Model ineligible for the Training Use grant unless Licensor has -waived the conflict in writing. - -You may not suspend, ban, throttle, sue, threaten, or otherwise retaliate -against Authorized Researchers solely because they engage in the activities -listed in this Attachment A, provided that their activity complies with -generally applicable law and neutral safety or security rules that are also -applied to Ordinary Users. +MIT License + +Copyright (c) 2025 GPU MODE + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 1d25040a..f362097e 100644 --- a/README.md +++ b/README.md @@ -4,19 +4,16 @@ This repo holds reference kernels for the KernelBot which hosts regular competit You can see what's going on [gpumode.com](https://www.gpumode.com/) -## Competitions -1. [PMPP practice problems](https://github.com/gpu-mode/reference-kernels/tree/main/problems/pmpp_v2) +## Competition +1. [PMPP practice problems](https://gpu-mode.github.io/discord-cluster-manager/docs/active#practice-round-leaderboard): Starting on Sunday Feb 21, 2025. 2. [AMD $100K kernel competition](problems/amd) 3. [BioML kernels](problems/bioml) 4. [AMD $100K distributed kernel competition](problems/amd_distributed) -5. [NVIDIA Blackwell NVFP4 competition](problems/nvidia) -6. [AMD $1.1M competition](problems/amd_202602) -7. [Helion IRL hackathon](problems/helion) -8. [Linear Algebra Problems](problems/linalg) -We also work with universities on hosting the infrastructure for their classes: -- [Stanford CS149 assignment 5 kernels](https://github.com/stanford-cs149/asst5-kernels) -- [Tri Dao's Princeton parallel programming class](problems/princeton) +## Making a Leaderboard Submission + +Please take a look at `vectoradd_py` to see multiple examples of expected submisisons ranging from PyTorch code to Triton to inline CUDA. + ## Contributing New Problems @@ -25,14 +22,6 @@ To add a new problem, create a new folder in the `problems/glory` directory wher - `task.yml` - This is the problem specification that will be used to generate test cases for different shapes - `task.py` - Specifies the schema of the inputs and outputs for the problem -You can evaluate problems with your own Modal account (they give you a free $30) by borrowing this [neat script from @gau-nernst](https://github.com/gpu-mode/reference-kernels/pull/96#issue-3850136894) - -## License - -This project is licensed under the [June 9 Researcher Reciprocity License](LICENSE). -The license adapts the Open RAIL-S structure and adds one specific use restriction: training, fine-tuning, distillation, synthetic-data generation for training, embedding for training, or otherwise using this project to improve an AI model or AI service requires Researcher Reciprocity. -> If you train on it, you let us generate. -Covered AI model and service providers may not use this project while imposing terms that prevent GPU Mode, project contributors, or authorized researchers from generating outputs, evaluating models, benchmarking, publishing research, or exploring their own research ideas on materially equal terms to ordinary users. diff --git a/problems/amd_202602.yaml b/problems/amd_202602.yaml deleted file mode 100644 index e5d0d8aa..00000000 --- a/problems/amd_202602.yaml +++ /dev/null @@ -1,19 +0,0 @@ -name: AMD Developer Challenge February 2026 -deadline: "2026-04-07 07:59" -description: "AMD Developer Challenge: MXFP4 matrix multiplication, Mixture-of-Experts, and Multi-head Latent Attention optimized for MI355X." -problems: - - directory: amd_202602/mxfp4-mm - name: amd-mxfp4-mm - deadline: "2026-04-07 07:59" - gpus: - - MI355X - - directory: amd_202602/moe-mxfp4 - name: amd-moe-mxfp4 - deadline: "2026-04-07 07:59" - gpus: - - MI355X - - directory: amd_202602/mixed-mla - name: amd-mixed-mla - deadline: "2026-04-07 07:59" - gpus: - - MI355X diff --git a/problems/amd_202602/eval.py b/problems/amd_202602/eval.py deleted file mode 100644 index cc5d559b..00000000 --- a/problems/amd_202602/eval.py +++ /dev/null @@ -1,387 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional - -import torch.cuda - -from utils import set_seed, clear_l2_cache_large as clear_l2_cache -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, 'w') - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a+b)*(a+b+1)//2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - if val == "true": - val = True - elif val == "false": - val = False - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg)**2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), - worst=float(worst)) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def wrap_check_implementation(data, submission_output): - # Old version returned just a single string, new version - # returns (bool, str); this function ensures compatibility with old - # problem definitions. - result = check_implementation(data, submission_output) - if isinstance(result, tuple): - return result - else: - return not bool(result), result - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - data = generate_input(**test.args) - torch.cuda.synchronize() - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return wrap_check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - # first, one obligatory correctness check - output = custom_kernel(data) - good, message = wrap_check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(start_event.elapsed_time(end_event) * 1e6) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: - break - - return calculate_stats(durations) - - -def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, - max_time_ns: float): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 1000, 50e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(data) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - mp_context = multiprocessing.get_context('spawn') - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/amd_202602/mixed-mla/README.md b/problems/amd_202602/mixed-mla/README.md deleted file mode 100644 index 24cae9ce..00000000 --- a/problems/amd_202602/mixed-mla/README.md +++ /dev/null @@ -1,200 +0,0 @@ -# MLA (Multi-head Latent Attention) Decode Kernel - -## Description - -Implement a custom MLA attention decode kernel optimized for MI355X. - -This is the **inner attention kernel** from DeepSeek R1's `forward_absorb` MLA path. -The absorbed query and compressed KV cache are provided directly — you implement the -attention computation with variable-length batching. - -The reference uses **aiter MLA a8w8 decode kernel** (`mla_decode_fwd`, fp8 Q + fp8 KV, -persistent mode). On MI355X, a8w8 is ~2-3x faster than bf16 with negligible accuracy loss. -The reference quantizes Q to fp8 on-the-fly and uses pre-quantized fp8 KV from `kv_data["fp8"]`. - -## DeepSeek R1 Forward-Absorb MLA Config - -| Parameter | Value | Notes | -|---|---|---| -| num_heads | 16 | Query heads (after TP split) | -| num_kv_heads | 1 | Single shared latent KV head | -| kv_lora_rank | 512 | Latent dimension | -| qk_rope_head_dim | 64 | RoPE embedding dimension | -| qk_head_dim | 576 | kv_lora_rank + qk_rope_head_dim (absorbed q/k dim) | -| v_head_dim | 512 | = kv_lora_rank (output dim) | -| sm_scale | 1/sqrt(576) | | -| q dtype | bfloat16 | Input always bf16; reference quantizes to fp8 on-the-fly | -| kv dtype | bf16 / fp8 / mxfp4 | All three provided simultaneously | -| mode | decode | q_seq_len=1, kv_seq_len up to 8k | - -## Reference Kernel - -The reference (`ref_kernel`) is configurable via two globals in `reference.py`: - -| `Q_DTYPE` | `KV_DTYPE` | Aiter kernel dispatched | Description | -|---|---|---|---| -| `"fp8"` (default) | `"fp8"` (default) | `mla_a8w8_qh16_qseqlen1_gqaratio16_ps` | fp8 Q + fp8 KV — fastest | -| `"bf16"` | `"fp8"` | `mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + fp8 KV | -| `"bf16"` | `"bf16"` | `mla_a16w16_qh16_m16x4_n16x1_coex0_mask1_ps` | bf16 Q + bf16 KV — highest precision | - -**Note**: `Q_DTYPE="fp8"` + `KV_DTYPE="bf16"` is not a valid combination (no a8w16 kernel exists). - -### Reference Latency (MI355X) - -| Case | a8w8 (us) | a16w16 (us) | a8w8 speedup | -|---|---|---|---| -| bs=4, kv=1k | ~118 | ~162 | 1.4x | -| bs=4, kv=8k | ~113 | ~177 | 1.6x | -| bs=64, kv=8k | ~171 | ~353 | 2.1x | -| bs=256, kv=8k | ~349 | ~814 | 2.3x | - -## KV Buffer Format (forward_absorb) - -The compressed KV buffer has `qk_head_dim=576` dimensions: -- **Full 576 dims** are used as **keys** (for Q@K^T score computation) -- **First 512 dims** (kv_lora_rank) are used as **values** (for output computation) - -## KV Cache Quantization - -| dtype | kv_buffer | kv_scale | Quantization | Bandwidth | -|---|---|---|---|---| -| bf16 | bfloat16 `(total_kv, 1, 576)` | None | No quantization | 1x | -| fp8 | fp8 `(total_kv, 1, 576)` | scalar float32 | Dynamic per-tensor (sglang `scaled_fp8_quant`) | 2x savings | -| mxfp4 | fp4x2 `(total_kv, 1, 288)` | fp8_e8m0 `(total_kv, N_blocks)` | Block-32 MXFP4 (aiter `dynamic_mxfp4_quant`) | 4x savings | - -### FP8 quantization (sglang `scaled_fp8_quant`) - -- **Granularity**: per-tensor -- **Scale**: `kv_scale = max(abs(kv_bf16)) / fp8_max` -- **Quantize**: `kv_fp8 = (kv_bf16 / kv_scale).clamp(...).to(fp8)` -- **Dequantize**: `kv_bf16 ≈ kv_fp8.to(bf16) * kv_scale` -- **kv_scale**: scalar float32 tensor - -### MXFP4 quantization (aiter `dynamic_mxfp4_quant`) - -- **Granularity**: per-block of 32 elements -- **FP4 format**: E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 -- **Scale format**: E8M0 — exponent-only scale stored in `aiter.dtypes.fp8_e8m0` -- **Packing**: 2 FP4 values packed per byte (low nibble = even index, high nibble = odd index) -- **kv_buffer**: `(total_kv, 1, 288)` in `aiter.dtypes.fp4x2` — packed FP4 data -- **kv_scale**: `(total_kv, N_blocks)` in `aiter.dtypes.fp8_e8m0` — per-block E8M0 scale factors -- **Dequantize**: `aiter.utility.fp4_utils.mxfp4_to_f32` + `e8m0_to_f32` for block-wise scaling - -### aiter dtype reference - -| Logical type | aiter dtype | PyTorch native (if available) | Fallback | -|---|---|---|---| -| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | -| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | -| fp8 | `aiter.dtypes.fp8` | `torch.float8_e4m3fnuz` (gfx942) / `torch.float8_e4m3fn` (gfx950+) | `torch.uint8` | - -## Input - -A tuple `(q, kv_data, qo_indptr, kv_indptr, config)`: - -``` -q: (total_q, 16, 576) bfloat16 — absorbed queries -kv_data: dict with three KV cache formats (see below) -qo_indptr: (batch_size + 1,) int32 — query segment pointers -kv_indptr: (batch_size + 1,) int32 — KV segment pointers -config: dict — MLA parameters -``` - -### kv_data dict - -All three KV cache formats are provided simultaneously. Each entry is either a -`Tensor` (bf16) or a `(Tensor, Tensor)` tuple (quantized buffer + scale): - -```python -kv_data = { - "bf16": kv_buffer_bf16, # Tensor (total_kv, 1, 576) bfloat16 - "fp8": (kv_buffer_fp8, kv_scale_fp8), # (fp8 Tensor, scalar float32) - "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4), # (fp4x2 Tensor, fp8_e8m0 Tensor) -} -``` - -### config dict - -```python -config = { - "batch_size": int, - "num_heads": 16, - "num_kv_heads": 1, - "qk_head_dim": 576, - "kv_lora_rank": 512, - "qk_rope_head_dim": 64, - "v_head_dim": 512, - "q_seq_len": 1, - "kv_seq_len": int, # varies per test case (1024 or 8192) - "sm_scale": 0.04166..., # 1/sqrt(576) -} -``` - -## Output - -``` -attention_output: (total_q, 16, 512) bfloat16 -``` - -## Optimization Opportunities - -The reference is already a highly optimized aiter a8w8 persistent kernel. To beat it, consider: - -1. **MXFP4 KV cache**: 4x bandwidth savings over bf16, 2x over fp8. Two strategies: - - **Strategy A — Fuse dequantization with attention (keep Q in bf16/fp8):** - Load quantized KV tiles from HBM, dequantize in registers/LDS to bf16, and - immediately compute QK^T and softmax·V — never writing the decompressed KV back - to HBM. This eliminates the extra read/write of the bf16 intermediate buffer, - roughly quartering the memory traffic for mxfp4 compared to the naive - dequant-then-attend approach. - - **Strategy B — Quantize Q to match KV precision (full low-precision compute):** - Dynamically quantize Q from bf16 → mxfp4 (per-block scaling), then compute QK^T - entirely in fp4×fp4 using MFMA instructions on MI355X. The softmax is still done - in fp32 for numerical stability, and V accumulation uses fp4×fp4 → fp32. This - trades a small amount of accuracy for significantly higher throughput on the - matrix units. - -2. **Custom split-K / split-batch scheduling**: the aiter kernel uses 32-way KV splits - with reduce; a different split strategy or tile size may be more efficient for certain - batch/seq_len combinations. - -3. **MQA pattern**: 1 KV head shared across 16 query heads — minimize redundant KV loads - by loading KV once and broadcasting across all query heads in shared memory/LDS. - -4. **Variable-length batching**: indptr-based segmented attention across batch elements. - -5. **Split K/V from buffer**: full 576 dims for keys, first 512 for values — potential - for separate tiling strategies for the score and output stages. - -## Accuracy - -Submissions are checked against the a8w8 reference with `rtol=2e-02, atol=8e-03`. - -Measured accuracy of different approaches vs bf16 torch ground truth: - -| Approach | max abs diff | Notes | -|---|---|---| -| aiter a8w8 (reference) | 2.6e-05 — 8.0e-05 | fp8 quantization + kernel accumulation | -| torch fp8 (scaled_mm) | 2e-06 — 1.5e-05 | Closest to bf16 | -| torch mxfp4 | 2.1e-04 — 8.3e-04 | 4-bit quantization noise | - -All approaches are well within the tolerance. - -## Benchmark Cases - -All three KV formats (bf16, fp8, mxfp4) are provided in every test case. - -| batch_size | q_seq_len | kv_seq_len | -|---|---|---| -| 4 | 1 | 1024 | -| 4 | 1 | 8192 | -| 32 | 1 | 1024 | -| 32 | 1 | 8192 | -| 64 | 1 | 1024 | -| 64 | 1 | 8192 | -| 256 | 1 | 1024 | -| 256 | 1 | 8192 | - -Ranking is by **geometric mean** of benchmark latencies. diff --git a/problems/amd_202602/mixed-mla/reference.py b/problems/amd_202602/mixed-mla/reference.py deleted file mode 100644 index 9bddf10f..00000000 --- a/problems/amd_202602/mixed-mla/reference.py +++ /dev/null @@ -1,372 +0,0 @@ -""" -Reference implementation for MLA (Multi-head Latent Attention) decode kernel. - -Uses aiter MLA kernels (mla_decode_fwd) as the reference. -DeepSeek R1 forward_absorb MLA: absorbed q (576), compressed kv_buffer (576), -output v_head_dim = kv_lora_rank = 512. - -The input provides: - q: (total_q, 16, 576) bfloat16 — absorbed query - kv_data: dict with KV cache in three formats: - "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision - "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized - "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 + fp8_e8m0 — block-32 quantized - The reference quantizes Q to fp8 on-the-fly inside ref_kernel. - -The reference kernel quantizes Q to fp8 on-the-fly and uses fp8 KV (a8w8 kernel), -which is ~2-3x faster than bf16 on MI355X with negligible accuracy loss. - -Decode only — persistent mode with get_mla_metadata_v1. -""" - -import torch -import torch.nn.functional as F -from task import input_t, output_t -from utils import make_match_reference - -from aiter.mla import mla_decode_fwd -from aiter import dtypes as aiter_dtypes -from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 -from aiter.utility.fp4_utils import ( - dynamic_mxfp4_quant, - mxfp4_to_f32, - e8m0_to_f32, -) - -# --------------------------------------------------------------------------- -# DeepSeek R1 latent MQA constants (forward_absorb path) -# https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json -# --------------------------------------------------------------------------- -NUM_HEADS = 16 -NUM_KV_HEADS = 1 -KV_LORA_RANK = 512 -QK_ROPE_HEAD_DIM = 64 -QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # 576 -V_HEAD_DIM = KV_LORA_RANK # 512 -SM_SCALE = 1.0 / (QK_HEAD_DIM ** 0.5) - -PAGE_SIZE = 1 -NUM_KV_SPLITS = 32 - -# FP8 dtype (platform-specific via aiter) -FP8_DTYPE = aiter_dtypes.fp8 - -# Query dtype for the reference kernel: "fp8" or "bf16" -Q_DTYPE = "fp8" - -# KV cache dtype for the reference kernel: "fp8" or "bf16" -KV_DTYPE = "fp8" - - -# --------------------------------------------------------------------------- -# FP8 quantization (sglang style: dynamic per-tensor) -# --------------------------------------------------------------------------- - -def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Dynamic per-tensor FP8 quantization (following sglang scaled_fp8_quant). - - Args: - tensor: bf16 tensor to quantize - - Returns: - (fp8_tensor, scale) where scale is a scalar float32 tensor. - Dequantize: fp8_tensor.to(bf16) * scale - """ - finfo = torch.finfo(FP8_DTYPE) - amax = tensor.abs().amax().clamp(min=1e-12) - scale = amax / finfo.max - fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) - return fp8_tensor, scale.to(torch.float32).reshape(1) - - -# --------------------------------------------------------------------------- -# MXFP4 quantization (aiter native: block-32, fp4x2 + fp8_e8m0 dtypes) -# Uses aiter.utility.fp4_utils.dynamic_mxfp4_quant -# --------------------------------------------------------------------------- - -def quantize_mxfp4(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - MXFP4 block-wise quantization using aiter's dynamic_mxfp4_quant. - - Block size = 32. Each block gets an E8M0 scale factor. - Two FP4 E2M1 values are packed per byte. - - Args: - tensor: bf16 tensor of shape [B, M, N] (N must be divisible by 32) - - Returns: - (fp4_data, scale_e8m0) - - fp4_data: shape [B, M, N//2] in aiter_dtypes.fp4x2 - - scale_e8m0: shape [B*M, ceil(N/32)] padded, in aiter_dtypes.fp8_e8m0 - """ - orig_shape = tensor.shape # (B, M, N) - B, M, N = orig_shape - - # dynamic_mxfp4_quant expects 2D: (B*M, N) - tensor_2d = tensor.reshape(B * M, N) - fp4_data_2d, scale_e8m0 = dynamic_mxfp4_quant(tensor_2d) - - # Reshape fp4_data back to 3D: (B, M, N//2) - fp4_data = fp4_data_2d.view(B, M, N // 2) - - return fp4_data, scale_e8m0 - - -def dequantize_mxfp4( - fp4_data: torch.Tensor, - scale_e8m0: torch.Tensor, - orig_shape: tuple, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - Dequantize MXFP4 tensor using aiter utilities. - - Note: dynamic_mxfp4_quant may pad both row and block dimensions in scale_e8m0. - We trim scales to match the actual data dimensions. - - Args: - fp4_data: packed FP4 data, shape [B, M, N//2] in fp4x2 or uint8 - scale_e8m0: E8M0 block scale factors (possibly padded) in fp8_e8m0 - orig_shape: original (B, M, N) for reshaping - dtype: output dtype - - Returns: - Dequantized tensor of shape orig_shape. - """ - B, M, N = orig_shape - num_rows = B * M - block_size = 32 - num_blocks = N // block_size # actual blocks needed (e.g. 576/32 = 18) - - # Unpack FP4 to float32: mxfp4_to_f32 expects (..., N//2) -> (..., N) - fp4_data_2d = fp4_data.reshape(num_rows, N // 2) - float_vals = mxfp4_to_f32(fp4_data_2d) # (num_rows, N) - - # Convert E8M0 scales to float32 and trim padded dimensions - scale_f32 = e8m0_to_f32(scale_e8m0) # (padded_rows, padded_blocks) - scale_f32 = scale_f32[:num_rows, :num_blocks] # (num_rows, num_blocks) - - # Apply block scales - float_vals_blocked = float_vals.view(num_rows, num_blocks, block_size) - scaled = float_vals_blocked * scale_f32.unsqueeze(-1) - - return scaled.view(B, M, N).to(dtype) - - -# --------------------------------------------------------------------------- -# Persistent mode metadata helpers -# --------------------------------------------------------------------------- - -def _make_mla_decode_metadata( - batch_size: int, - max_q_len: int, - nhead: int, - nhead_kv: int, - q_dtype: torch.dtype, - kv_dtype: torch.dtype, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, - num_kv_splits: int = NUM_KV_SPLITS, -): - """Allocate and populate work buffers for persistent mla_decode_fwd.""" - info = get_mla_metadata_info_v1( - batch_size, max_q_len, nhead, q_dtype, kv_dtype, - is_sparse=False, fast_mode=False, - num_kv_splits=num_kv_splits, intra_batch_mode=True, - ) - work = [torch.empty(s, dtype=t, device="cuda") for s, t in info] - (work_metadata, work_indptr, work_info_set, - reduce_indptr, reduce_final_map, reduce_partial_map) = work - - # Populate the metadata buffers - get_mla_metadata_v1( - qo_indptr, kv_indptr, kv_last_page_len, - nhead // nhead_kv, # num_heads_per_head_k - nhead_kv, # num_heads_k - True, # is_causal - work_metadata, work_info_set, work_indptr, - reduce_indptr, reduce_final_map, reduce_partial_map, - page_size=PAGE_SIZE, - kv_granularity=max(PAGE_SIZE, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, - max_split_per_batch=num_kv_splits, - intra_batch_mode=True, - dtype_q=q_dtype, - dtype_kv=kv_dtype, - ) - - return { - "work_meta_data": work_metadata, - "work_indptr": work_indptr, - "work_info_set": work_info_set, - "reduce_indptr": reduce_indptr, - "reduce_final_map": reduce_final_map, - "reduce_partial_map": reduce_partial_map, - } - - -# --------------------------------------------------------------------------- -# Aiter reference kernel (decode only) -# --------------------------------------------------------------------------- - -def _aiter_mla_decode( - q: torch.Tensor, - kv_buffer: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - config: dict, - q_scale: torch.Tensor | None = None, - kv_scale: torch.Tensor | None = None, -) -> torch.Tensor: - """ - MLA decode attention using aiter persistent-mode kernel. - - Supports multiple Q/KV dtype combinations: - - Q_DTYPE="fp8": fp8 Q + fp8 KV (a8w8) — fastest on MI355X - - Q_DTYPE="bf16": bf16 Q + bf16 KV (a16w16) — highest precision - - q: (total_q, num_heads, 576) fp8 or bf16 - kv_buffer: (total_kv, 1, 576) fp8 or bf16 - q_scale: scalar float32 (required for fp8 Q, None for bf16) - kv_scale: scalar float32 (required for fp8 KV, None for bf16) - """ - batch_size = config["batch_size"] - nq = config["num_heads"] - nkv = config["num_kv_heads"] - dq = config["qk_head_dim"] - dv = config["v_head_dim"] - q_seq_len = config["q_seq_len"] - - total_kv_len = int(kv_indptr[-1].item()) - kv_indices = torch.arange(total_kv_len, dtype=torch.int32, device="cuda") - - # Reshape kv_buffer to 4D for aiter: (total_kv, page_size, nhead_kv, dim) - kv_buffer_4d = kv_buffer.view(kv_buffer.shape[0], PAGE_SIZE, nkv, kv_buffer.shape[-1]) - - max_q_len = q_seq_len - kv_last_page_len = (kv_indptr[1:] - kv_indptr[:-1]).to(torch.int32) - - # Build persistent-mode metadata - meta = _make_mla_decode_metadata( - batch_size, max_q_len, nq, nkv, - q.dtype, kv_buffer.dtype, - qo_indptr, kv_indptr, kv_last_page_len, - num_kv_splits=NUM_KV_SPLITS, - ) - - o = torch.empty((q.shape[0], nq, dv), dtype=torch.bfloat16, device="cuda") - mla_decode_fwd( - q.view(-1, nq, dq), - kv_buffer_4d, - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - max_q_len, - page_size=PAGE_SIZE, - nhead_kv=nkv, - sm_scale=SM_SCALE, - logit_cap=0.0, - num_kv_splits=NUM_KV_SPLITS, - q_scale=q_scale, - kv_scale=kv_scale, - intra_batch_mode=True, - **meta, - ) - return o - - -# --------------------------------------------------------------------------- -# generate_input / ref_kernel / check_implementation -# --------------------------------------------------------------------------- - -def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> input_t: - """ - Generate absorbed q and compressed kv_buffer for MLA decode. - - Returns all three KV cache formats in kv_data dict: - kv_data = { - "bf16": Tensor — (total_kv, 1, 576) bfloat16 - "fp8": (Tensor, Tensor) — kv_buffer fp8 + scalar scale - "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale - } - """ - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - - total_q = batchsize * qseqlen - total_kv = batchsize * kvseqlen - - # Absorbed query: (total_q, num_heads, 576) bf16 - q = torch.randn( - (total_q, NUM_HEADS, QK_HEAD_DIM), - dtype=torch.bfloat16, device="cuda", generator=gen, - ) - - # Compressed KV buffer: (total_kv, 1, 576) bf16 — the source of truth - kv_buffer_bf16 = torch.randn( - (total_kv, NUM_KV_HEADS, QK_HEAD_DIM), - dtype=torch.bfloat16, device="cuda", generator=gen, - ) - - # Quantize KV to fp8 - kv_buffer_fp8, kv_scale_fp8 = quantize_fp8(kv_buffer_bf16) - - # Quantize KV to mxfp4 - kv_buffer_mxfp4, kv_scale_mxfp4 = quantize_mxfp4(kv_buffer_bf16) - - # All three KV formats: bf16 is a Tensor, fp8/mxfp4 are (Tensor, Tensor) tuples - kv_data = { - "bf16": kv_buffer_bf16, - "fp8": (kv_buffer_fp8, kv_scale_fp8), - "mxfp4": (kv_buffer_mxfp4, kv_scale_mxfp4), - } - - qo_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * qseqlen - kv_indptr = torch.arange(0, batchsize + 1, dtype=torch.int32, device="cuda") * kvseqlen - - config = { - "batch_size": batchsize, - "num_heads": NUM_HEADS, - "num_kv_heads": NUM_KV_HEADS, - "qk_head_dim": QK_HEAD_DIM, - "kv_lora_rank": KV_LORA_RANK, - "qk_rope_head_dim": QK_ROPE_HEAD_DIM, - "v_head_dim": V_HEAD_DIM, - "q_seq_len": qseqlen, - "kv_seq_len": kvseqlen, - "sm_scale": SM_SCALE, - } - - return (q, kv_data, qo_indptr, kv_indptr, config) - - -def ref_kernel(data: input_t) -> output_t: - """Reference MLA decode attention. Uses Q_DTYPE and KV_DTYPE to select kernel variant.""" - q, kv_data, qo_indptr, kv_indptr, config = data - - # Resolve Q - if Q_DTYPE == "fp8": - q_input, q_scale = quantize_fp8(q) - else: - q_input, q_scale = q, None - - # Resolve KV - if KV_DTYPE == "fp8": - kv_buffer_fp8, kv_scale = kv_data["fp8"] - kv_input = kv_buffer_fp8 - else: - kv_input, kv_scale = kv_data["bf16"], None - - return _aiter_mla_decode( - q_input, kv_input, qo_indptr, kv_indptr, config, - q_scale=q_scale, kv_scale=kv_scale, - ) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-01) diff --git a/problems/amd_202602/mixed-mla/submission.py b/problems/amd_202602/mixed-mla/submission.py deleted file mode 100644 index fba8b760..00000000 --- a/problems/amd_202602/mixed-mla/submission.py +++ /dev/null @@ -1,299 +0,0 @@ -# gpumode leaderboard reference -""" -Reference implementation for MLA (Multi-head Latent Attention) decode kernel. - -Uses aiter MLA kernels (mla_decode_fwd) as the reference. -DeepSeek R1 forward_absorb MLA: absorbed q (576), compressed kv_buffer (576), -output v_head_dim = kv_lora_rank = 512. - -The input provides: - q: (total_q, 16, 576) bfloat16 — absorbed query - kv_data: dict with KV cache in three formats: - "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision - "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized - "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 + fp8_e8m0 — block-32 quantized - The reference quantizes Q to fp8 on-the-fly inside ref_kernel. - -The reference kernel quantizes Q to fp8 on-the-fly and uses fp8 KV (a8w8 kernel), -which is ~2-3x faster than bf16 on MI355X with negligible accuracy loss. - -Decode only — persistent mode with get_mla_metadata_v1. -""" - -import torch -import torch.nn.functional as F -from task import input_t, output_t -from utils import make_match_reference - -from aiter.mla import mla_decode_fwd -from aiter import dtypes as aiter_dtypes -from aiter import get_mla_metadata_info_v1, get_mla_metadata_v1 -from aiter.utility.fp4_utils import ( - dynamic_mxfp4_quant, - mxfp4_to_f32, - e8m0_to_f32, -) - -# --------------------------------------------------------------------------- -# DeepSeek R1 latent MQA constants (forward_absorb path) -# https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json -# --------------------------------------------------------------------------- -NUM_HEADS = 16 -NUM_KV_HEADS = 1 -KV_LORA_RANK = 512 -QK_ROPE_HEAD_DIM = 64 -QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # 576 -V_HEAD_DIM = KV_LORA_RANK # 512 -SM_SCALE = 1.0 / (QK_HEAD_DIM ** 0.5) - -PAGE_SIZE = 1 -NUM_KV_SPLITS = 32 - -# FP8 dtype (platform-specific via aiter) -FP8_DTYPE = aiter_dtypes.fp8 - -# Query dtype for the reference kernel: "fp8" or "bf16" -Q_DTYPE = "fp8" - -# KV cache dtype for the reference kernel: "fp8" or "bf16" -KV_DTYPE = "fp8" - - -# --------------------------------------------------------------------------- -# FP8 quantization (sglang style: dynamic per-tensor) -# --------------------------------------------------------------------------- -def quantize_fp8(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Dynamic per-tensor FP8 quantization (following sglang scaled_fp8_quant). - - Args: - tensor: bf16 tensor to quantize - - Returns: - (fp8_tensor, scale) where scale is a scalar float32 tensor. - Dequantize: fp8_tensor.to(bf16) * scale - """ - finfo = torch.finfo(FP8_DTYPE) - amax = tensor.abs().amax().clamp(min=1e-12) - scale = amax / finfo.max - fp8_tensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(FP8_DTYPE) - return fp8_tensor, scale.to(torch.float32).reshape(1) - - -# --------------------------------------------------------------------------- -# MXFP4 quantization (aiter native: block-32, fp4x2 + fp8_e8m0 dtypes) -# Uses aiter.utility.fp4_utils.dynamic_mxfp4_quant -# --------------------------------------------------------------------------- - -def quantize_mxfp4(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - MXFP4 block-wise quantization using aiter's dynamic_mxfp4_quant. - - Block size = 32. Each block gets an E8M0 scale factor. - Two FP4 E2M1 values are packed per byte. - - Args: - tensor: bf16 tensor of shape [B, M, N] (N must be divisible by 32) - - Returns: - (fp4_data, scale_e8m0) - - fp4_data: shape [B, M, N//2] in aiter_dtypes.fp4x2 - - scale_e8m0: shape [B*M, ceil(N/32)] padded, in aiter_dtypes.fp8_e8m0 - """ - orig_shape = tensor.shape # (B, M, N) - B, M, N = orig_shape - - # dynamic_mxfp4_quant expects 2D: (B*M, N) - tensor_2d = tensor.reshape(B * M, N) - fp4_data_2d, scale_e8m0 = dynamic_mxfp4_quant(tensor_2d) - - # Reshape fp4_data back to 3D: (B, M, N//2) - fp4_data = fp4_data_2d.view(B, M, N // 2) - - return fp4_data, scale_e8m0 - - -def dequantize_mxfp4( - fp4_data: torch.Tensor, - scale_e8m0: torch.Tensor, - orig_shape: tuple, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - Dequantize MXFP4 tensor using aiter utilities. - - Note: dynamic_mxfp4_quant may pad both row and block dimensions in scale_e8m0. - We trim scales to match the actual data dimensions. - - Args: - fp4_data: packed FP4 data, shape [B, M, N//2] in fp4x2 or uint8 - scale_e8m0: E8M0 block scale factors (possibly padded) in fp8_e8m0 - orig_shape: original (B, M, N) for reshaping - dtype: output dtype - - Returns: - Dequantized tensor of shape orig_shape. - """ - B, M, N = orig_shape - num_rows = B * M - block_size = 32 - num_blocks = N // block_size # actual blocks needed (e.g. 576/32 = 18) - - # Unpack FP4 to float32: mxfp4_to_f32 expects (..., N//2) -> (..., N) - fp4_data_2d = fp4_data.reshape(num_rows, N // 2) - float_vals = mxfp4_to_f32(fp4_data_2d) # (num_rows, N) - - # Convert E8M0 scales to float32 and trim padded dimensions - scale_f32 = e8m0_to_f32(scale_e8m0) # (padded_rows, padded_blocks) - scale_f32 = scale_f32[:num_rows, :num_blocks] # (num_rows, num_blocks) - - # Apply block scales - float_vals_blocked = float_vals.view(num_rows, num_blocks, block_size) - scaled = float_vals_blocked * scale_f32.unsqueeze(-1) - - return scaled.view(B, M, N).to(dtype) - - -# --------------------------------------------------------------------------- -# Persistent mode metadata helpers -# --------------------------------------------------------------------------- - -def _make_mla_decode_metadata( - batch_size: int, - max_q_len: int, - nhead: int, - nhead_kv: int, - q_dtype: torch.dtype, - kv_dtype: torch.dtype, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - kv_last_page_len: torch.Tensor, - num_kv_splits: int = NUM_KV_SPLITS, -): - """Allocate and populate work buffers for persistent mla_decode_fwd.""" - info = get_mla_metadata_info_v1( - batch_size, max_q_len, nhead, q_dtype, kv_dtype, - is_sparse=False, fast_mode=False, - num_kv_splits=num_kv_splits, intra_batch_mode=True, - ) - work = [torch.empty(s, dtype=t, device="cuda") for s, t in info] - (work_metadata, work_indptr, work_info_set, - reduce_indptr, reduce_final_map, reduce_partial_map) = work - - # Populate the metadata buffers - get_mla_metadata_v1( - qo_indptr, kv_indptr, kv_last_page_len, - nhead // nhead_kv, # num_heads_per_head_k - nhead_kv, # num_heads_k - True, # is_causal - work_metadata, work_info_set, work_indptr, - reduce_indptr, reduce_final_map, reduce_partial_map, - page_size=PAGE_SIZE, - kv_granularity=max(PAGE_SIZE, 16), - max_seqlen_qo=max_q_len, - uni_seqlen_qo=max_q_len, - fast_mode=False, - max_split_per_batch=num_kv_splits, - intra_batch_mode=True, - dtype_q=q_dtype, - dtype_kv=kv_dtype, - ) - - return { - "work_meta_data": work_metadata, - "work_indptr": work_indptr, - "work_info_set": work_info_set, - "reduce_indptr": reduce_indptr, - "reduce_final_map": reduce_final_map, - "reduce_partial_map": reduce_partial_map, - } - - -# --------------------------------------------------------------------------- -# Aiter reference kernel (decode only) -# --------------------------------------------------------------------------- - -def _aiter_mla_decode( - q: torch.Tensor, - kv_buffer: torch.Tensor, - qo_indptr: torch.Tensor, - kv_indptr: torch.Tensor, - config: dict, - q_scale: torch.Tensor | None = None, - kv_scale: torch.Tensor | None = None, -) -> torch.Tensor: - """ - MLA decode attention using aiter persistent-mode kernel. - - Supports multiple Q/KV dtype combinations: - - Q_DTYPE="fp8": fp8 Q + fp8 KV (a8w8) — fastest on MI355X - - Q_DTYPE="bf16": bf16 Q + bf16 KV (a16w16) — highest precision - - q: (total_q, num_heads, 576) fp8 or bf16 - kv_buffer: (total_kv, 1, 576) fp8 or bf16 - q_scale: scalar float32 (required for fp8 Q, None for bf16) - kv_scale: scalar float32 (required for fp8 KV, None for bf16) - """ - batch_size = config["batch_size"] - nq = config["num_heads"] - nkv = config["num_kv_heads"] - dq = config["qk_head_dim"] - dv = config["v_head_dim"] - q_seq_len = config["q_seq_len"] - total_kv_len = int(kv_indptr[-1].item()) - - # Reshape kv_buffer to 4D for aiter: (total_kv, page_size, nhead_kv, dim) - kv_buffer_4d = kv_buffer.view(kv_buffer.shape[0], PAGE_SIZE, nkv, kv_buffer.shape[-1]) - - max_q_len = q_seq_len - kv_indices = torch.arange(total_kv_len, dtype=torch.int32, device="cuda") - kv_last_page_len = (kv_indptr[1:] - kv_indptr[:-1]).to(torch.int32) - meta = _make_mla_decode_metadata( - batch_size, max_q_len, nq, nkv, - q.dtype, kv_buffer.dtype, - qo_indptr, kv_indptr, kv_last_page_len, - num_kv_splits=NUM_KV_SPLITS, - ) - - o = torch.empty((q.shape[0], nq, dv), dtype=torch.bfloat16, device="cuda") - mla_decode_fwd( - q.view(-1, nq, dq), - kv_buffer_4d, - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - max_q_len, - page_size=PAGE_SIZE, - nhead_kv=nkv, - sm_scale=SM_SCALE, - logit_cap=0.0, - num_kv_splits=NUM_KV_SPLITS, - q_scale=q_scale, - kv_scale=kv_scale, - intra_batch_mode=True, - **meta, - ) - return o - -def custom_kernel(data: input_t) -> output_t: - """Reference MLA decode attention. Uses Q_DTYPE and KV_DTYPE to select kernel variant.""" - q, kv_data, qo_indptr, kv_indptr, config = data - - # Resolve Q - if Q_DTYPE == "fp8": - q_input, q_scale = quantize_fp8(q) - else: - q_input, q_scale = q, None - - # Resolve KV - if KV_DTYPE == "fp8": - kv_buffer_fp8, kv_scale = kv_data["fp8"] - kv_input = kv_buffer_fp8 - else: - kv_input, kv_scale = kv_data["bf16"], None - return _aiter_mla_decode( - q_input, kv_input, qo_indptr, kv_indptr, config, - q_scale=q_scale, kv_scale=kv_scale, - ) \ No newline at end of file diff --git a/problems/amd_202602/mixed-mla/task.py b/problems/amd_202602/mixed-mla/task.py deleted file mode 100644 index 7aff7b6a..00000000 --- a/problems/amd_202602/mixed-mla/task.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from typing import TypeVar, TypedDict, Union - -# DeepSeek R1 MLA forward_absorb format: -# -# Input: (q, kv_data, qo_indptr, kv_indptr, config) -# q: (total_q, num_heads, qk_head_dim) bfloat16 -# kv_data: dict with three KV cache formats: -# "bf16": Tensor (total_kv, 1, 576) bfloat16 -# "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale -# "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale -# qo_indptr: (batch_size + 1,) int32 -# kv_indptr: (batch_size + 1,) int32 -# config: dict with MLA parameters -# -# where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 -# -# Output: attention output tensor (total_q, num_heads, v_head_dim) bfloat16 -# where v_head_dim = kv_lora_rank = 512 -# -# The kv_buffer stores the compressed KV representation: -# - Full 576 dims used as keys (for Q@K^T score computation) -# - First 512 dims (kv_lora_rank) used as values (for output computation) - -input_t = TypeVar( - "input_t", - bound=tuple[torch.Tensor, dict, torch.Tensor, torch.Tensor, dict], -) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec(TypedDict): - batchsize: int - qseqlen: int - kvseqlen: int - seed: int diff --git a/problems/amd_202602/mixed-mla/task.yml b/problems/amd_202602/mixed-mla/task.yml deleted file mode 100644 index c0a5d5a6..00000000 --- a/problems/amd_202602/mixed-mla/task.yml +++ /dev/null @@ -1,95 +0,0 @@ -# name: mla-py - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement a custom MLA (Multi-head Latent Attention) decode kernel optimized for MI355X. - - This is the inner attention kernel from DeepSeek R1's forward_absorb MLA path. - The absorbed query and compressed KV cache are provided directly — you only need to - implement the **attention** computation with variable-length batching (indptr). - - The reference uses aiter a8w8 MLA decode kernel (mla_decode_fwd, fp8 Q + fp8 KV, - persistent mode), which is ~2-3x faster than bf16 on MI355X. - - DeepSeek R1 forward_absorb MLA config: - - num_heads = 16 (query heads, after TP split) - - num_kv_heads = 1 (shared latent KV head) - - kv_lora_rank = 512 - - qk_rope_head_dim = 64 - - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim) - - v_head_dim = 512 (= kv_lora_rank, output dim) - - sm_scale = 1/sqrt(576) - - dtype: q=bfloat16 - - decode only (q_seq_len=1, kv_seq_len up to 8k) - - KV buffer format (forward_absorb): - - Full 576 dims are used as keys (for Q@K^T score computation) - - First 512 dims (kv_lora_rank) are used as values (for output computation) - - Input tuple: (q, kv_data, qo_indptr, kv_indptr, config) - - q: (total_q, 16, 576) bfloat16 — absorbed query - - kv_data: dict with three KV cache formats: - kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16 - kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale - kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale - - qo_indptr: (batch_size+1,) int32 — query segment pointers - - kv_indptr: (batch_size+1,) int32 — KV segment pointers - - config: dict with MLA parameters - - Return: - - attention output: (total_q, 16, 512) bfloat16 - - Key optimization opportunities: - 1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16) - - Fuse dequantization with attention to skip bf16 materialization - 2. Custom kernel with tighter memory access patterns - 3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads - 4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload - 5. Variable-length batching: indptr-based segmented attention - 6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values - - The ranking criteria is the geometric mean of the benchmark results. - -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 900 -benchmark_timeout: 900 -ranked_timeout: 1200 - -tests: - # bs=4 - - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220} - # bs=32 - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412} - # bs=64 - - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360} - # bs=256 - - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826} - -benchmarks: - # bs=4 - - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217} - - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220} - # bs=32 - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412} - - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415} - # bs=64 - - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357} - - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360} - # bs=256 - - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823} - - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826} - -ranking_by: "geom" diff --git a/problems/amd_202602/moe-mxfp4/README.md b/problems/amd_202602/moe-mxfp4/README.md deleted file mode 100644 index ab664f7d..00000000 --- a/problems/amd_202602/moe-mxfp4/README.md +++ /dev/null @@ -1,198 +0,0 @@ -# MXFP4 Mixture-of-Experts (MoE) Fused Kernel - -## Description - -Implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. - -The kernel fuses the complete MoE forward pass into a 2-stage pipeline: -1. **Stage 1**: MXFP4 GEMM (gate+up projection) + SwiGLU activation -2. **Stage 2**: MXFP4 GEMM (down projection) + weighted reduction across top-k experts - -The reference uses **AITER `fused_moe`** with `QuantType.per_1x32` (MXFP4 block scaling, block_size=32). - -## DeepSeek-R1 MoE Architecture - -| Parameter | Value | Notes | -|---|---|---| -| hidden_size | 7168 | Model hidden dimension | -| moe_intermediate_size | 2048 | Per-expert intermediate dimension | -| n_routed_experts | 256 | Routed experts (EP-off) or 32 (EP-on, 8-way split) | -| n_shared_experts | 1 | Always selected with weight=1.0 | -| top_k (routed) | 8 | Routed experts per token | -| total_top_k | 9 | 8 routed + 1 shared | -| MoE layers | 58 | Layers 3–60 | - -## Kernel Flow - -For each token `i` and each assigned expert `j`: - -``` -(1) Quant activations: hidden_states -> MXFP4 (aiter per-1x32 dynamic quantization) - -(2) Stage 1 GEMM + SwiGLU activation: - gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - intermediate = SiLU(gate) * up # SwiGLU activation -> [d_expert] - (W_gate and W_up are fused as gate_up_weight: one a4w4 GEMM + fused activation) - -(3) Stage 2 GEMM: - expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] - -(4) Weighted reduction: - output_i += w_ij * expert_out # accumulate across top_k experts -``` - -All weight GEMMs are **a4w4** (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). -The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. - -## Weight Layout & Pre-shuffling - -Weights are provided in two layouts: - -| Layout | Description | Use case | -|---|---|---| -| **Raw** | Original MXFP4 quantized weights | PyTorch reference / custom kernels | -| **Pre-shuffled** | `shuffle_weight(w, layout=(16,16))` + `e8m0_shuffle(scale)` | AITER CK kernel (tile-coalesced layout) | - -The (16,16) shuffle rearranges weight tiles for coalesced memory access by CK GEMM instructions. -Scale shuffling (`e8m0_shuffle`) reorders E8M0 block scales to match the shuffled weight layout. - -You may use either layout — raw weights if you implement your own tiling, or pre-shuffled weights -for direct use with AITER/CK kernels. - -## MXFP4 Quantization Details - -| Property | Value | -|---|---| -| FP4 format | E2M1 — values `[0, 0.5, 1, 1.5, 2, 3, 4, 6]`, max = 6.0 | -| Scale format | E8M0 — exponent-only (power-of-2 scale) | -| Block size | 32 elements per scale | -| Packing | 2 FP4 values per byte (`fp4x2`): low nibble = even index, high nibble = odd index | -| Padding | Dimensions padded to 256-alignment for CK kernel | - -### aiter dtype reference - -| Logical type | aiter dtype | PyTorch native (if available) | Fallback | -|---|---|---|---| -| fp4x2 | `aiter.dtypes.fp4x2` | `torch.float4_e2m1fn_x2` | `torch.uint8` | -| fp8_e8m0 | `aiter.dtypes.fp8_e8m0` | `torch.float8_e8m0fnu` | `torch.uint8` | - -## Input - -A tuple of tensors and a config dict: - -``` -(hidden_states, - gate_up_weight, down_weight, # fp4x2 raw - gate_up_weight_scale, down_weight_scale, # e8m0 raw - gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled - gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled - topk_weights, topk_ids, - config) -``` - -### Tensor shapes - -| Tensor | Shape | Dtype | Notes | -|---|---|---|---| -| `hidden_states` | `[M, d_hidden]` | bfloat16 | Input activations (M = batch of tokens) | -| `gate_up_weight` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Fused gate+up weights, raw | -| `down_weight` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Down projection weights, raw | -| `gate_up_weight_scale` | `[E, 2*d_expert_pad, d_hidden_pad//32]` | e8m0 | Block scales for gate_up, raw | -| `down_weight_scale` | `[E, d_hidden_pad, d_expert_pad//32]` | e8m0 | Block scales for down, raw | -| `gate_up_weight_shuffled` | `[E, 2*d_expert_pad, d_hidden_pad//2]` | fp4x2 | Pre-shuffled for CK | -| `down_weight_shuffled` | `[E, d_hidden_pad, d_expert_pad//2]` | fp4x2 | Pre-shuffled for CK | -| `gate_up_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | -| `down_weight_scale_shuffled` | `[padded, flat]` | e8m0 | Pre-shuffled for CK | -| `topk_weights` | `[M, total_top_k]` | float32 | Routing weights | -| `topk_ids` | `[M, total_top_k]` | int32 | Expert indices (see below) | - -### topk_ids format - -- First `n_experts_per_token` columns: routed expert IDs `[0, n_routed_experts)` -- Last `n_shared_experts` columns: shared expert IDs `[n_routed_experts, n_routed_experts + n_shared_experts)` -- Shared experts are always selected with weight = 1.0 - -### config dict - -```python -config = { - "d_hidden": int, # hidden dimension (e.g. 7168) - "d_expert": int, # expert intermediate dimension (e.g. 2048 or 256) - "d_hidden_pad": int, # d_hidden padded to 256-alignment - "d_expert_pad": int, # d_expert padded to 256-alignment - "n_routed_experts": int, # number of routed experts - "n_shared_experts": int, # number of shared experts (1) - "n_experts_per_token": int, # routed top-k (8) - "total_top_k": int, # routed + shared (9) - "bs": int, # batch size (number of tokens) -} -``` - -## Output - -``` -output: [M, d_hidden] bfloat16 -``` - -## Reference Performance - -AITER `fused_moe` with MXFP4 (E includes shared expert, top_k = routed + shared): - -| bs | E | d_hidden | d_expert | top_k | time (us) | -|---|---|---|---|---|---| -| 4 | 257 | 7168 | 256 | 9 | 46.9 | -| 64 | 257 | 7168 | 256 | 9 | 187.7 | -| 256 | 257 | 7168 | 256 | 9 | 245.7 | -| 64 | 33 | 7168 | 2048 | 9 | 220.6 | -| 256 | 33 | 7168 | 2048 | 9 | 276.4 | -| 1024 | 33 | 7168 | 2048 | 9 | 572.2 | - -## Optimization Opportunities - -The AITER CK `fused_moe` kernel is already well-optimized. To beat it, consider: - -1. **Custom tiling / scheduling**: The CK kernel uses a fixed tile strategy. For small batch sizes - (bs=4) or highly skewed expert distributions, a custom schedule may reduce idle waves. - -2. **Activation quantization fusion**: The reference quantizes activations separately before the - GEMM. Fusing dynamic MXFP4 quantization into the Stage 1 GEMM prologue saves one global - memory round-trip. - -3. **Inter-stage fusion**: The reference runs Stage 1 and Stage 2 as separate kernel launches. - Fusing both stages (gate_up GEMM → SwiGLU → down GEMM → accumulate) into a single kernel - eliminates the intermediate buffer write/read between stages. - -4. **Expert-parallel wave scheduling**: With 257 experts but only 9 active per token, most - expert slots are empty. A work-stealing or compact-dispatch strategy can minimize wasted - wavefronts. - -5. **Shared expert fusion**: The shared expert is always selected for all tokens. It could be - computed as a dense GEMM (no routing overhead) and fused with the routed expert reduction. - -6. **Split-K for large M**: For bs=1024 with EP-on (E=33, d_expert=2048), the GEMMs are large - enough to benefit from split-K parallelism within each expert. - -## Accuracy - -Submissions are checked against the AITER reference with `rtol=1e-2, atol=1e-2`. - -## Benchmark Cases - -### EP-off (all 257 experts on 1 GPU, d_expert=256) - -| bs | E | d_hidden | d_expert | top_k | -|---|---|---|---|---| -| 4 | 257 | 7168 | 256 | 9 | -| 64 | 257 | 7168 | 256 | 9 | -| 256 | 257 | 7168 | 256 | 9 | - -### EP-on (EP=8, 33 experts per GPU, d_expert=2048) - -| bs | E | d_hidden | d_expert | top_k | -|---|---|---|---|---| -| 64 | 33 | 7168 | 2048 | 9 | -| 256 | 33 | 7168 | 2048 | 9 | -| 1024 | 33 | 7168 | 2048 | 9 | - -Ranking is by **geometric mean** of benchmark latencies. diff --git a/problems/amd_202602/moe-mxfp4/eval.py b/problems/amd_202602/moe-mxfp4/eval.py deleted file mode 100644 index a03a3cc5..00000000 --- a/problems/amd_202602/moe-mxfp4/eval.py +++ /dev/null @@ -1,382 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional - -import torch.cuda - -from utils import set_seed, clear_l2_cache_large as clear_l2_cache -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, 'w') - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a+b)*(a+b+1)//2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - if val == "true": - val = True - elif val == "false": - val = False - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg)**2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), - worst=float(worst)) - - -def _clone_data(data): - """ - Return data as-is (no cloning). - - aiter's fused_moe produces incorrect results when weight tensors are - cloned to different memory addresses (same values, different output). - Since fused_moe does not mutate its inputs, skipping the clone is safe. - """ - return data - - -def wrap_check_implementation(data, submission_output): - # Old version returned just a single string, new version - # returns (bool, str); this function ensures compatibility with old - # problem definitions. - result = check_implementation(data, submission_output) - if isinstance(result, tuple): - return result - else: - return not bool(result), result - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - data = generate_input(**test.args) - torch.cuda.synchronize() - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return wrap_check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - # generate input data once - data = generate_input(**test.args) - check_copy = _clone_data(data) - # first, one obligatory correctness check - output = custom_kernel(data) - good, message = wrap_check_implementation(check_copy, output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - if recheck: - # ensure we use a different seed for every benchmark - if "seed" in test.args: - test.args["seed"] += 13 - - data = generate_input(**test.args) - check_copy = _clone_data(data) - torch.cuda.synchronize() - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - - if recheck: - good, message = check_implementation(check_copy, output) - if not good: - return message - - del output - durations.append(start_event.elapsed_time(end_event) * 1e6) - - if i > 1: - total_bm_duration = time.perf_counter_ns() - bm_start_time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: - break - - return calculate_stats(durations) - - -def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, - max_time_ns: float): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 1000, 50e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from torch.profiler import profile, record_function, ProfilerActivity - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - mp_context = multiprocessing.get_context('spawn') - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/amd_202602/moe-mxfp4/reference.py b/problems/amd_202602/moe-mxfp4/reference.py deleted file mode 100644 index e22ff34d..00000000 --- a/problems/amd_202602/moe-mxfp4/reference.py +++ /dev/null @@ -1,299 +0,0 @@ -from utils import make_match_reference -from task import input_t, output_t -import torch -import torch.nn.functional as F -from typing import Dict, Tuple, Optional -import math - -import aiter -from aiter import ActivationType, QuantType, dtypes -from aiter.fused_moe import fused_moe -from aiter.utility import fp4_utils -from aiter.ops.shuffle import shuffle_weight - - -# ────────────────────────────────────────────────────────────────────── -# Constants -# ────────────────────────────────────────────────────────────────────── -MXFP4_BLOCK_SIZE = 32 -PAD_ALIGN = 256 - - -def _pad_to(x: int, align: int) -> int: - return (x + align - 1) // align * align - - -# ────────────────────────────────────────────────────────────────────── -# generate_input: produce all tensors needed by ref_kernel -# -# Models DeepSeek-R1 MoE layer shapes: -# - d_hidden = 7168 -# - d_expert = moe_intermediate_size (full=2048, or TP-split) -# - E_total = n_routed_experts + n_shared_experts (257 or 33) -# - top_k_total = nexpertspertoken + nsharedexperts (8+1=9) -# -# ────────────────────────────────────────────────────────────────────── -def generate_input( - dhidden: int, - dexpert: int, - nroutedexperts: int, - nexpertspertoken: int, - nsharedexperts: int, - bs: int, - seed: int, -) -> input_t: - d_hidden = dhidden - d_expert = dexpert - n_routed_experts = nroutedexperts - n_shared_experts = nsharedexperts - routed_top_k = nexpertspertoken - total_top_k = routed_top_k + n_shared_experts # e.g. 8 + 1 = 9 - E_total = n_routed_experts + n_shared_experts # e.g. 256 + 1 = 257 - M = bs # number of tokens - - # Padded dimensions (AITER MXFP4 requires 256-alignment) - d_hidden_pad = _pad_to(d_hidden, PAD_ALIGN) - d_expert_pad = _pad_to(d_expert, PAD_ALIGN) - - config = { - "d_hidden": d_hidden, - "d_expert": d_expert, - "d_hidden_pad": d_hidden_pad, - "d_expert_pad": d_expert_pad, - "n_routed_experts": n_routed_experts, - "n_shared_experts": n_shared_experts, - "n_experts_per_token": routed_top_k, - "total_top_k": total_top_k, - "bs": M, - } - - gen = torch.Generator(device='cuda') - gen.manual_seed(seed) - - # ── hidden_states [M, d_hidden] ── - hidden_states = torch.randn( - (M, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, - ) - - # ── Router: softmax top-k (routed experts only) ── - router_weight = torch.randn( - (n_routed_experts, d_hidden), device='cuda', dtype=torch.bfloat16, generator=gen, - ) / math.sqrt(d_hidden) - router_logits = F.linear(hidden_states, router_weight) # [M, n_routed_experts] - scores = router_logits.softmax(dim=-1) - routed_weights, routed_ids = torch.topk( - scores, k=routed_top_k, dim=-1, sorted=False - ) - routed_weights = routed_weights.to(torch.float32) - routed_ids = routed_ids.to(torch.int32) - - # ── Append shared expert(s): always selected, weight = 1.0 ── - # Shared experts are indexed as n_routed_experts, n_routed_experts+1, ... - shared_ids = torch.arange( - n_routed_experts, E_total, device='cuda', dtype=torch.int32 - ).unsqueeze(0).expand(M, -1) # [M, n_shared_experts] - shared_weights = torch.ones( - (M, n_shared_experts), device='cuda', dtype=torch.float32 - ) - - topk_ids = torch.cat([routed_ids, shared_ids], dim=-1) # [M, total_top_k] - topk_weights = torch.cat([routed_weights, shared_weights], dim=-1) # [M, total_top_k] - - gate_up_bf16 = torch.randn( - (E_total, 2 * d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen, - ) / math.sqrt(d_hidden) - down_bf16 = torch.randn( - (E_total, d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen, - ) / math.sqrt(d_expert) - - torch_quant = aiter.get_torch_quant(QuantType.per_1x32) - gate_up_weight, gate_up_weight_scale = torch_quant(gate_up_bf16, quant_dtype=dtypes.fp4x2) - down_weight, down_weight_scale = torch_quant(down_bf16, quant_dtype=dtypes.fp4x2) - gate_up_weight = gate_up_weight.view(E_total, 2 * d_expert_pad, d_hidden_pad // 2) - down_weight = down_weight.view(E_total, d_hidden_pad, d_expert_pad // 2) - - gate_up_weight_shuffled = shuffle_weight(gate_up_weight, layout=(16, 16)) - down_weight_shuffled = shuffle_weight(down_weight, layout=(16, 16)) - gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale) - down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale) - - return ( - hidden_states, # [M, d_hidden] bf16 - gate_up_weight, # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - down_weight, # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - gate_up_weight_scale, # [E_total, 2*d_expert_pad, scale_K] e8m0 (raw) - down_weight_scale, # [E_total, d_hidden_pad, scale_K] e8m0 (raw) - gate_up_weight_shuffled, # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled) - down_weight_shuffled, # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled) - gate_up_weight_scale_shuffled, # [padded, flat] e8m0 (pre-shuffled) - down_weight_scale_shuffled, # [padded, flat] e8m0 (pre-shuffled) - topk_weights, # [M, total_top_k] float32 - topk_ids, # [M, total_top_k] int32 - config, - ) - - - - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel_pytorch: pure PyTorch implementation (dequant + matmul) -# ────────────────────────────────────────────────────────────────────── -def _dequant_mxfp4(weight_fp4, scale_e8m0): - """ - Dequantize MXFP4 weight to float32. - - weight_fp4: [N, K//2] fp4x2 (raw, not shuffled) - scale_e8m0: [padded_N, ceil(K/32)] e8m0 (M-dim padded to 256-align by dynamic_mxfp4_quant) - - Returns: [N, K] float32 - """ - # fp4x2 -> float32 lookup: [N, K] - w_f32 = fp4_utils.mxfp4_to_f32(weight_fp4) # [N, K] - # e8m0 -> float32 power-of-2 scale: [padded_N, scale_K] - s_f32 = fp4_utils.e8m0_to_f32(scale_e8m0) # [padded_N, scale_K] - N, K = w_f32.shape - # Trim scale rows to match weight rows (scale M-dim is padded to 256) - s_f32 = s_f32[:N, :] - # Broadcast scale across block_size=32 columns - s_f32 = s_f32.repeat_interleave(MXFP4_BLOCK_SIZE, dim=-1)[:, :K] # [N, K] - return w_f32 * s_f32 - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel_pytorch: pure PyTorch implementation (dequant + matmul) -# will not run. only for reference -# ────────────────────────────────────────────────────────────────────── -def ref_kernel_pytorch(data: input_t) -> output_t: - """ - Pure PyTorch reference: dequantize MXFP4 weights -> bf16 matmul -> SwiGLU -> matmul. - Uses the raw (un-shuffled) weights. - """ - ( - hidden_states, # [M, d_hidden] bf16 - gate_up_weight, # [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 - down_weight, # [E, d_hidden_pad, d_expert_pad//2] fp4x2 - gate_up_weight_scale, # [E, 2*d_expert_pad, scale_K] e8m0 - down_weight_scale, # [E, d_hidden_pad, scale_K] e8m0 - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, # [M, top_k] float32 - topk_ids, # [M, top_k] int32 - config, - ) = data - - d_hidden = config["d_hidden"] - d_expert = config["d_expert"] - d_hidden_pad = config["d_hidden_pad"] - d_expert_pad = config["d_expert_pad"] - M = hidden_states.shape[0] - top_k = topk_ids.shape[1] - E = gate_up_weight.shape[0] - - # Dequantize all expert weights to float32 - # gate_up: [E, 2*d_expert_pad, d_hidden_pad] -> trim to [E, 2*d_expert, d_hidden] - # down: [E, d_hidden_pad, d_expert_pad] -> trim to [E, d_hidden, d_expert] - gate_up_dq = torch.stack([ - _dequant_mxfp4(gate_up_weight[e], gate_up_weight_scale[e]) - for e in range(E) - ]) # [E, 2*d_expert_pad, d_hidden_pad] - gate_up_dq = gate_up_dq[:, :2 * d_expert, :d_hidden].to(torch.bfloat16) - - down_dq = torch.stack([ - _dequant_mxfp4(down_weight[e], down_weight_scale[e]) - for e in range(E) - ]) # [E, d_hidden_pad, d_expert_pad] - down_dq = down_dq[:, :d_hidden, :d_expert].to(torch.bfloat16) - - # Split gate_up -> gate [E, d_expert, d_hidden], up [E, d_expert, d_hidden] - gate_w, up_w = gate_up_dq.chunk(2, dim=1) # each [E, d_expert, d_hidden] - - # Per-token MoE forward - output = torch.zeros((M, d_hidden), dtype=torch.bfloat16, device=hidden_states.device) - - for i in range(M): - x = hidden_states[i] # [d_hidden] - for k in range(top_k): - eid = topk_ids[i, k].item() - w = topk_weights[i, k].item() - - # Stage 1: gate_proj + up_proj + SwiGLU - gate_out = F.silu(x @ gate_w[eid].T) # [d_expert] - up_out = x @ up_w[eid].T # [d_expert] - intermediate = gate_out * up_out # [d_expert] - - # Stage 2: down_proj - # down_dq[eid] is [d_hidden, d_expert], .T is [d_expert, d_hidden] - expert_out = intermediate @ down_dq[eid].T # [d_hidden] - - output[i] += w * expert_out - - return output - - - -# ────────────────────────────────────────────────────────────────────── -# ref_kernel: calls AITER fused_moe with MXFP4 quantized weights -# ────────────────────────────────────────────────────────────────────── -def ref_kernel(data: input_t) -> output_t: - """ - Reference implementation using AITER's fused_moe kernel with MXFP4 quantized weights. - - Input data tuple (E = n_routed_experts + n_shared_experts, total_top_k = routed + shared): - hidden_states: [M, d_hidden] bf16 - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) - gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw, before shuffle) - down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw, before shuffle) - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled) - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled) - gate_up_weight_scale_shuffled:[padded, flat] e8m0 (pre-shuffled) - down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled) - topk_weights: [M, total_top_k] float32 - topk_ids: [M, total_top_k] int32 - config: dict - - Returns: - output: [M, d_hidden] bf16 - """ - ( - hidden_states, - gate_up_weight, - down_weight, - gate_up_weight_scale, - down_weight_scale, - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, - topk_ids, - config, - ) = data - - hidden_pad = config["d_hidden_pad"] - config["d_hidden"] - intermediate_pad = config["d_expert_pad"] - config["d_expert"] - - output = fused_moe( - hidden_states, - gate_up_weight_shuffled, - down_weight_shuffled, - topk_weights, - topk_ids, - expert_mask=None, - activation=ActivationType.Silu, - quant_type=QuantType.per_1x32, # MXFP4 uses per_1x32 block scaling - doweight_stage1=False, - w1_scale=gate_up_weight_scale_shuffled, - w2_scale=down_weight_scale_shuffled, - a1_scale=None, - a2_scale=None, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ) - - return output - - - -check_implementation = make_match_reference(ref_kernel, rtol=5e-2, atol=5e-2) diff --git a/problems/amd_202602/moe-mxfp4/submission.py b/problems/amd_202602/moe-mxfp4/submission.py deleted file mode 100644 index a771b32c..00000000 --- a/problems/amd_202602/moe-mxfp4/submission.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -from typing import Dict -from task import input_t, output_t - -from aiter import ActivationType, QuantType -from aiter.fused_moe import fused_moe - - -def custom_kernel(data: input_t) -> output_t: - """ - Submission template for DeepSeek-R1 MXFP4 MoE kernel. - - Input data tuple: - hidden_states: [M, d_hidden] bf16 - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - gate_up_weight_scale: [E, 2*d_expert_pad, scale_K] e8m0 (raw) - down_weight_scale: [E, d_hidden_pad, scale_K] e8m0 (raw) - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) - gate_up_weight_scale_shuffled:[padded, flat] e8m0 (shuffled) - down_weight_scale_shuffled: [padded, flat] e8m0 (shuffled) - topk_weights: [M, total_top_k] float32 - topk_ids: [M, total_top_k] int32 - config: dict - - Returns: - output: [M, d_hidden] bf16 - """ - ( - hidden_states, - gate_up_weight, - down_weight, - gate_up_weight_scale, - down_weight_scale, - gate_up_weight_shuffled, - down_weight_shuffled, - gate_up_weight_scale_shuffled, - down_weight_scale_shuffled, - topk_weights, - topk_ids, - config, - ) = data - - hidden_pad = config["d_hidden_pad"] - config["d_hidden"] - intermediate_pad = config["d_expert_pad"] - config["d_expert"] - - output = fused_moe( - hidden_states, - gate_up_weight_shuffled, - down_weight_shuffled, - topk_weights, - topk_ids, - expert_mask=None, - activation=ActivationType.Silu, - quant_type=QuantType.per_1x32, - doweight_stage1=False, - w1_scale=gate_up_weight_scale_shuffled, - w2_scale=down_weight_scale_shuffled, - a1_scale=None, - a2_scale=None, - hidden_pad=hidden_pad, - intermediate_pad=intermediate_pad, - ) - - return output diff --git a/problems/amd_202602/moe-mxfp4/task.py b/problems/amd_202602/moe-mxfp4/task.py deleted file mode 100644 index a19edc83..00000000 --- a/problems/amd_202602/moe-mxfp4/task.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import TypeVar, Tuple, Dict -import torch - -input_t = TypeVar("input_t", bound=Tuple[ - torch.Tensor, # hidden_states [M, d_hidden] - torch.Tensor, # gate_up_weight [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw) - torch.Tensor, # down_weight [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw) - torch.Tensor, # gate_up_weight_scale [E, 2*d_expert_pad, scale_K] e8m0 (raw) - torch.Tensor, # down_weight_scale [E, d_hidden_pad, scale_K] e8m0 (raw) - torch.Tensor, # gate_up_weight_shuffled [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (shuffled) - torch.Tensor, # down_weight_shuffled [E, d_hidden_pad, d_expert_pad//2] fp4x2 (shuffled) - torch.Tensor, # gate_up_weight_scale_shuffled [padded, flat] e8m0 (shuffled) - torch.Tensor, # down_weight_scale_shuffled [padded, flat] e8m0 (shuffled) - torch.Tensor, # topk_weights [M, total_top_k] - torch.Tensor, # topk_ids [M, total_top_k] - Dict, # config -]) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec: - dhidden: int # hidden dimension (7168 for DeepSeek-R1) - dexpert: int # intermediate dimension per expert (per partition) - nroutedexperts: int # number of local routed experts on this GPU - nexpertspertoken: int # top-k routed experts per token (8 for DeepSeek-R1) - nsharedexperts: int # number of shared experts (1 for DeepSeek-R1), always selected - bs: int # batch size = number of tokens in this batch - seed: int diff --git a/problems/amd_202602/moe-mxfp4/task.yml b/problems/amd_202602/moe-mxfp4/task.yml deleted file mode 100644 index c3ac0e21..00000000 --- a/problems/amd_202602/moe-mxfp4/task.yml +++ /dev/null @@ -1,125 +0,0 @@ -# name: 3_moe_mxfp4 - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "eval.py"} - -lang: "py" - -description: | - You will implement a DeepSeek-R1 style MXFP4 Mixture-of-Experts (MoE) fused kernel optimized for AMD Instinct MI355X GPU. - - To be explicit, you will be given a tuple of tensors: - ``` - (hidden_states, - gate_up_weight, down_weight, # fp4x2 raw - gate_up_weight_scale, down_weight_scale, # e8m0 raw - gate_up_weight_shuffled, down_weight_shuffled, # fp4x2 pre-shuffled - gate_up_weight_scale_shuffled, down_weight_scale_shuffled, # e8m0 pre-shuffled - topk_weights, topk_ids, - config) - ``` - where: - * `hidden_states` is M x d_hidden in bfloat16 (the input activations, M = batch of tokens) - * `gate_up_weight` is [E, 2*d_expert_pad, d_hidden_pad//2] in MXFP4 (fp4x2), raw layout. - Fused gate + up projection weights for each expert. E = number of local experts. - * `down_weight` is [E, d_hidden_pad, d_expert_pad//2] in MXFP4 (fp4x2), raw layout. - Down projection weights for each expert. - * `gate_up_weight_scale` is [E, 2*d_expert_pad, d_hidden_pad//32] in E8M0, raw layout. - Block scales (block_size=32) for gate_up_weight. - * `down_weight_scale` is [E, d_hidden_pad, d_expert_pad//32] in E8M0, raw layout. - Block scales for down_weight. - * `gate_up_weight_shuffled` / `down_weight_shuffled` are the same weights shuffled to - (16,16) tile-coalesced layout for the CK kernel. - * `gate_up_weight_scale_shuffled` / `down_weight_scale_shuffled` are the scales after - e8m0_shuffle, flattened to [padded, flat]. - * `topk_weights` is [M, total_top_k] float32: routing weights (routed experts + shared experts). - * `topk_ids` is [M, total_top_k] int32: expert indices. First nexpertspertoken columns are - routed expert ids (0..n_routed-1), last nsharedexperts columns are shared expert ids - (n_routed..n_routed+n_shared-1). Shared experts are always selected with weight=1.0. - * `config` is a dict with: d_hidden, d_expert, d_hidden_pad, d_expert_pad, - n_routed_experts, n_shared_experts, n_experts_per_token, total_top_k, bs. - - Then, the fused_moe kernel flow is: - (1) Quant activations to MXFP4: aiter per-1x32 dynamic quantization of hidden_states. - (2) Stage 1 GEMM + activation (per token i, per assigned expert j): - - gate = x_i @ W_gate_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - - up = x_i @ W_up_j.T # [d_hidden] x [d_expert, d_hidden].T -> [d_expert] - - intermediate = SiLU(gate) * up # SwiGLU activation, -> [d_expert] - (W_gate and W_up are fused as gate_up_weight, so this is one a4w4 GEMM + fused activation) - (3) Stage 2 GEMM: - - expert_out = intermediate @ W_down_j.T # [d_expert] x [d_hidden, d_expert].T -> [d_hidden] - (4) Weighted reduction: - - output_i += w_ij * expert_out # accumulate across top_k experts - All weight GEMMs are a4w4 (MXFP4 activations x MXFP4 weights, per-1x32 block scaling). - The AITER CK kernel fuses all of the above into a 2-stage pipeline across all tokens and experts. - - DeepSeek-R1 MoE specs: - - hidden_size = 7168, moe_intermediate_size = 2048 - - 256 routed experts + 1 shared expert (total 257), top-8 routed + 1 shared = 9 per token - - 58 MoE layers (layer 3-60) - - The shared expert processes ALL tokens unconditionally (weight=1.0) - - d_hidden_pad and d_expert_pad are the dimensions padded to 256-alignment for the CK kernel. - - The ranking criteria is the geometric mean of the benchmark results. - - ``` - The AITER reference performance is (E includes shared expert, top_k = routed + shared): - bs E d_hidden d_expert top_k time[us] - 16 257 7168 256 9 152.7 - 128 257 7168 256 9 239.0 - 512 257 7168 256 9 336.5 - 16 33 7168 512 9 106.2 - 128 33 7168 512 9 141.1 - 512 33 7168 512 9 225.0 - 512 33 7168 2048 9 380.4 - ``` - - Input: - - hidden_states: [M, d_hidden] bf16 - - gate_up_weight: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (raw, before shuffle) - - down_weight: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (raw, before shuffle) - - gate_up_weight_scale: [E, 2*d_expert_pad, d_hidden_pad//32] e8m0 (raw, before shuffle) - - down_weight_scale: [E, d_hidden_pad, d_expert_pad//32] e8m0 (raw, before shuffle) - - gate_up_weight_shuffled: [E, 2*d_expert_pad, d_hidden_pad//2] fp4x2 (pre-shuffled for CK) - - down_weight_shuffled: [E, d_hidden_pad, d_expert_pad//2] fp4x2 (pre-shuffled for CK) - - gate_up_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) - - down_weight_scale_shuffled: [padded, flat] e8m0 (pre-shuffled for CK) - - topk_weights: [M, total_top_k] float32 - - topk_ids: [M, total_top_k] int32 - - config: dict with dimensions - - Output: - - output: [M, d_hidden] bf16 - -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 540 -benchmark_timeout: 540 -ranked_timeout: 840 -ranking_by: "geom" - -tests: - - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 8, "seed": 9371} - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 32, "seed": 2291} - - {"dhidden": 4096, "dexpert": 1536, "nroutedexperts": 64, "nexpertspertoken": 6, "nsharedexperts": 1, "bs": 128, "seed": 81934} - -benchmarks: - # TP=8 - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 9371} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 2291} - - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} - # TP=4 - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 2291} - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 81934} - - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} - # EP on - - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934} diff --git a/problems/amd_202602/mxfp4-mm/reference.py b/problems/amd_202602/mxfp4-mm/reference.py deleted file mode 100644 index e4439c84..00000000 --- a/problems/amd_202602/mxfp4-mm/reference.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. -Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). - -NOTE: Explicitly uses dynamic_mxfp4_quant from aiter.ops.triton.quant (patched in #975) - rather than going through aiter.get_triton_quant, which may dispatch to the - unpatched fp4_utils.py kernel. See ROCm/aiter#974, ROCm/aiter#975. -""" -import torch -from task import input_t, output_t -from utils import make_match_reference -from aiter import QuantType,dtypes -import aiter -from aiter.ops.shuffle import shuffle_weight -from aiter.ops.triton.quant import dynamic_mxfp4_quant # #975-patched kernel -from aiter.utility.fp4_utils import e8m0_shuffle -# K must be divisible by 64 (scale group 32 and fp4 pack 2) -SCALE_GROUP_SIZE = 32 - -def _quant_mxfp4(x, shuffle=True): - x_fp4, bs_e8m0 = dynamic_mxfp4_quant(x) - if shuffle: - bs_e8m0 = e8m0_shuffle(bs_e8m0) - return x_fp4.view(dtypes.fp4x2), bs_e8m0.view(dtypes.fp8_e8m0) - -def generate_input(m: int, n: int, k: int, seed: int):# -> input_t: - """ - Generate random bf16 inputs A [m, k], B [n, k] and quantized MXFP4 B, shuffled B and B_scale. - - Returns: - Tuple of (A, B), both bf16 on cuda. - """ - assert k % 64 == 0, "k must be divisible by 64 (scale group 32 and fp4 pack 2)" - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen) - B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen) - B_q, B_scale_sh = _quant_mxfp4(B, shuffle=True) - # shuffle B(weight) to (16,16) tile coalesced - B_shuffle = shuffle_weight(B_q, layout=(16, 16)) - return (A, B, B_q, B_shuffle, B_scale_sh) - -def run_torch_fp4_mm( - x: torch.Tensor, - w: torch.Tensor, - x_scales: torch.Tensor, - w_scales: torch.Tensor, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - PyTorch reference: dequant MXFP4 + E8M0 scale -> f32 -> mm -> dtype. - Same logic as aiter op_tests/test_gemm_a4w4.run_torch. - x: [m, k//2] fp4 packed, w: [n, k//2] fp4 packed - x_scales: [m, k//32] E8M0, w_scales: [n, k//32] E8M0 - Returns: [m, n] in dtype - """ - from aiter.utility import fp4_utils - - m, _ = x.shape - n, _ = w.shape - # fp4 packed -> f32 - x_f32 = fp4_utils.mxfp4_to_f32(x) - w_f32 = fp4_utils.mxfp4_to_f32(w) - # E8M0 scale: [*, k//32] -> repeat 32 along k -> f32 - x_scales = x_scales[:m].repeat_interleave(SCALE_GROUP_SIZE, dim=1) - x_scales_f32 = fp4_utils.e8m0_to_f32(x_scales) - x_f32 = x_f32 * x_scales_f32 - w_scales = w_scales[:n].repeat_interleave(SCALE_GROUP_SIZE, dim=1) - w_scales_f32 = fp4_utils.e8m0_to_f32(w_scales) - w_f32 = w_f32 * w_scales_f32 - return torch.mm(x_f32, w_f32.T).to(dtype)[:m, :n] - - -def ref_kernel(data: input_t) -> output_t: - """ - Reference: MXFP4 per-1x32 quant on A and B; both PyTorch ref and gemm_a4w4 are given. - Returns gemm_a4w4 for check_implementation. - """ - A, B, B_q, B_shuffle, B_scale_sh = data - A = A.contiguous() - B = B.contiguous() - m, k = A.shape - n, _ = B.shape - - # 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16 - # Per-1x32 MXFP4 quant - # A_q, A_scale = _quant_mxfp4(A, shuffle=False) - # B_q, B_scale = _quant_mxfp4(B, shuffle=False) - - # gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0 - # quant_func returns scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, no view to uint8 needed - # slice to exact shapes [m,k_scale] / [n,k_scale] (quant may return padded scale) - - # k_scale = k // SCALE_GROUP_SIZE - # A_scale = A_scale[:m, :k_scale].contiguous() - # B_scale = B_scale[:n, :k_scale].contiguous() - # out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16) - - # 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105) - A_q, A_scale_sh = _quant_mxfp4(A, shuffle=True) - # to be noted, aiter also has other a4w4 implements using triton, https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py - out_gemm = aiter.gemm_a4w4( - A_q, - B_shuffle, - A_scale_sh, - B_scale_sh, - dtype=dtypes.bf16, - bpreshuffle=True, - ) - return out_gemm - -check_implementation = make_match_reference(ref_kernel, rtol=1e-02, atol=1e-02) \ No newline at end of file diff --git a/problems/amd_202602/mxfp4-mm/submission.py b/problems/amd_202602/mxfp4-mm/submission.py deleted file mode 100644 index eaf4050f..00000000 --- a/problems/amd_202602/mxfp4-mm/submission.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C. -Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)). -""" -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference: MXFP4 per-1x32 quant on A; B_shuffle, B_scale_sh from generate_input. - gemm_a4w4 with bpreshuffle=True. - """ - import aiter - from aiter import QuantType, dtypes - from aiter.ops.triton.quant import dynamic_mxfp4_quant - from aiter.utility.fp4_utils import e8m0_shuffle - - def _quant_mxfp4(x, shuffle=True): - x_fp4, bs_e8m0 = dynamic_mxfp4_quant(x) - if shuffle: - bs_e8m0 = e8m0_shuffle(bs_e8m0) - return x_fp4.view(dtypes.fp4x2), bs_e8m0.view(dtypes.fp8_e8m0) - - A, B, B_q, B_shuffle, B_scale_sh = data - A = A.contiguous() - B = B.contiguous() - m, k = A.shape - n, _ = B.shape - - A_q, A_scale_sh = _quant_mxfp4(A, shuffle=True) - out_gemm = aiter.gemm_a4w4( - A_q, - B_shuffle, - A_scale_sh, - B_scale_sh, - dtype=dtypes.bf16, - bpreshuffle=True, - ) - return out_gemm diff --git a/problems/amd_202602/mxfp4-mm/task.py b/problems/amd_202602/mxfp4-mm/task.py deleted file mode 100644 index a258eac0..00000000 --- a/problems/amd_202602/mxfp4-mm/task.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -quant + FP4 GEMM: bf16 A, B -> MXFP4 1x32 per-block quant -> gemm_a4w4 -> bf16 C. -""" -import torch -from typing import TypeVar, TypedDict - -# Input: (A, B, B_q, B_shuffle, B_scale_sh) from generate_input. -# A [m,k], B [n,k] bf16; B_q quantized MXFP4; B_shuffle = shuffle_weight(B_q,(16,16)); B_scale_sh from quant(B, shuffle=True). -# Output: bf16 C [m, n]. -input_t = TypeVar( - "input_t", - bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], -) -output_t = TypeVar("output_t", bound=torch.Tensor) - - -class TestSpec(TypedDict): - m: int - n: int - k: int - seed: int diff --git a/problems/amd_202602/mxfp4-mm/task.yml b/problems/amd_202602/mxfp4-mm/task.yml deleted file mode 100644 index 457a7b47..00000000 --- a/problems/amd_202602/mxfp4-mm/task.yml +++ /dev/null @@ -1,66 +0,0 @@ -# name: mxfp4-mm - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - You will implement a quantize func and block scaled MXFP4 matrix-matrix multiplication kernel optimized for AMD Instinct MI355X GPU. - To be explicit, you will be given a tuple of tensors: - ``` - (A, B, B_q, B_shuffle, B_scale_sh) - ``` - where: - * `A` is M x K in K-major order in bfloat16 - * `B` is N x K in K-major order in bfloat16 - * `B_q` is N x K/2 in K-major order in MXFP4 - * `B_shuffle` is N x K/2 in shuffled order in MXFP4, shuffled to (16,16) tile coalesced - * `B_scale_sh` is * x K/32 in E8M0, * means it will be padded. - - Then, the kernel flow is bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> BF16 C [m,n]. - To be specific, the invocation flow is: - (1) Quant A to MXFP4: aiter.get_triton_quant(QuantType.per_1x32). - (2) GEMM: aiter.gemm_a4w4. - m, n divisible by 64; k divisible by 64. - - The ranking criteria is the geometric mean of the benchmark results. - Pls note that this is the elimination round, whoever rank top5 are selected into the next round, e2e optimization for deepseek-R1-MXFP4 and GPTOSS-MXFP4 mdoel - ``` - The aiter performance is: - M N K time[us] - 4 2880 512 8.198 - 16 2112 7168 20.873 - 32 4096 512 9.462 - 32 2880 512 9.173 - 64 7168 2048 12.738 - 256 3072 1536 12.219 - ``` -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 420 -benchmark_timeout: 420 -ranked_timeout: 600 -ranking_by: "geom" - -tests: - - {"m": 8, "n": 2112, "k": 7168, "seed": 124} - - {"m": 16, "n": 3072, "k": 1536, "seed": 6635} - - {"m": 64, "n": 3072, "k": 1536, "seed": 45} - - {"m": 256, "n": 2880, "k": 512, "seed": 78} - -benchmarks: - - {"m": 4, "n": 2880, "k": 512, "seed": 4565} - - {"m": 16, "n": 2112, "k": 7168, "seed": 15} - - {"m": 32, "n": 4096, "k": 512, "seed": 457} - - {"m": 32, "n": 2880, "k": 512, "seed": 54} - - {"m": 64, "n": 7168, "k": 2048, "seed": 687} - - {"m": 256, "n": 3072, "k": 1536, "seed": 7856} diff --git a/problems/amd_202602/utils.py b/problems/amd_202602/utils.py deleted file mode 100644 index 42f36d30..00000000 --- a/problems/amd_202602/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -import random -from typing import Tuple - -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> Tuple[bool, list[str]]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return False, ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received.to(torch.float32) - expected.to(torch.float32)) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return False, mismatch_details - - return True, [f"Maximum error: {torch.max(diff)}"] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int = 5) -> Tuple[bool, list[str]]: - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return False, mismatch_details - - return True, [] - -def _is_mla_case(data) -> bool: - """ - Detect mixed-MLA style input tuple: - (q, kv_data, qo_indptr, kv_indptr, config) - """ - if not isinstance(data, tuple) or len(data) < 5: - return False - config = data[4] - if not isinstance(config, dict): - return False - mla_keys = { - "num_heads", - "num_kv_heads", - "qk_head_dim", - "kv_lora_rank", - "qk_rope_head_dim", - "v_head_dim", - } - return mla_keys.issubset(config.keys()) - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08, tol_err_ratio=0.05): - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - good, reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - # Only for MLA: aligned with aiter - if (not good) and _is_mla_case(data) and output.shape == expected.shape: - mismatch_mask = ~torch.isclose(output, expected, rtol=rtol, atol=atol) - mismatch_ratio = (mismatch_mask.sum() / output.numel()).item() - if mismatch_ratio <= tol_err_ratio: - return True, ( - f"warning: mismatch_ratio={mismatch_ratio:.6f} " - f"(<= tol_err_ratio={tol_err_ratio}) with rtol={rtol}, atol={atol}" - ) - - if len(reasons) > 0: - return good, "\\n".join(reasons) - - return good, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - -def clear_l2_cache_large(): - dummy = torch.randn((16000, 1024, 1024), device="cuda") - del dummy diff --git a/problems/bioml.yaml b/problems/bioml.yaml index 03dab68f..3761aea6 100644 --- a/problems/bioml.yaml +++ b/problems/bioml.yaml @@ -8,7 +8,7 @@ description: "Popular and important kernels for BioML models like AlphaFold3" problems: - directory: bioml/trimul name: trimul - deadline: "2026-05-09" + deadline: "2025-09-30" gpus: - B200 - H100 diff --git a/problems/helion.yaml b/problems/helion.yaml deleted file mode 100644 index e978396c..00000000 --- a/problems/helion.yaml +++ /dev/null @@ -1,29 +0,0 @@ -name: Helion Kernel Challenge -deadline: "2026-03-14" -description: "GPU kernel challenges inspired by Helion kernel ideas — convolution, quantization, and gated deltanet operators from production LLM architectures." -problems: - - directory: helion/causal_conv1d_py - name: causal_conv1d - deadline: "2026-03-15 01:00" - gpus: - - B200_Nebius - - directory: helion/fp8_quant_py - name: fp8_quant - deadline: "2026-03-15 01:00" - gpus: - - B200_Nebius - - directory: helion/gated_deltanet_chunk_fwd_h_py - name: gated_deltanet_chunk_fwd_h - deadline: "2026-03-15 01:00" - gpus: - - B200_Nebius - - directory: helion/gated_deltanet_chunk_fwd_o_py - name: gated_deltanet_chunk_fwd_o - deadline: "2026-03-15 01:00" - gpus: - - B200_Nebius - - directory: helion/gated_deltanet_recompute_w_u_py - name: gated_deltanet_recompute_w_u - deadline: "2026-03-15 01:00" - gpus: - - B200_Nebius diff --git a/problems/helion/causal_conv1d_py/reference.py b/problems/helion/causal_conv1d_py/reference.py deleted file mode 100644 index 0d2ae2f5..00000000 --- a/problems/helion/causal_conv1d_py/reference.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import torch.nn.functional as F -from task import input_t, output_t -from utils import make_match_reference, DeterministicContext - - -def generate_input(B: int, D: int, S: int, W: int, seed: int) -> input_t: - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - x = torch.randn(B, D, S, dtype=torch.float32, device="cuda", generator=gen).contiguous() - weight = torch.randn(D, W, dtype=torch.float32, device="cuda", generator=gen).contiguous() - bias = torch.randn(D, dtype=torch.float32, device="cuda", generator=gen).contiguous() - return x, weight, bias - - -def ref_kernel(data: input_t) -> output_t: - with DeterministicContext(): - x, weight, bias = data - B, D, S = x.shape - W = weight.shape[1] - - # Causal (left) padding - x_padded = F.pad(x, (W - 1, 0)) - - # Depthwise conv1d (groups=D) - output = F.conv1d( - x_padded, - weight.unsqueeze(1), # [D, 1, W] - bias=bias, - groups=D, - ) - return output - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3) diff --git a/problems/helion/causal_conv1d_py/submission.py b/problems/helion/causal_conv1d_py/submission.py deleted file mode 100644 index 92716763..00000000 --- a/problems/helion/causal_conv1d_py/submission.py +++ /dev/null @@ -1,81 +0,0 @@ -from task import input_t, output_t - -import torch -import helion -import helion.language as hl - - -# Per-shape configs: map (B, D, S, W) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # Test shapes - (1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - # Benchmark shapes - (1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1), # TODO: replace with your autotuned config -} - - -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, config=config) - def kernel( - x_pad: torch.Tensor, # (B, D, L) zero-padded input - w: torch.Tensor, # (D, W) filter coefficients - b: torch.Tensor, # (D,) additive offset - ) -> torch.Tensor: - B = x_pad.size(0) - D = x_pad.size(1) - L = x_pad.size(2) - W = hl.specialize(w.size(1)) - N = L - W + 1 - - y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device) - - for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]): - bi = rb.begin - acc1 = hl.zeros([rd, rs], dtype=torch.float32) - acc2 = hl.zeros([rd, rs], dtype=torch.float32) - acc3 = hl.zeros([rd, rs], dtype=torch.float32) - for j in range(W): - c1 = w[rd, j].to(torch.float32) - x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc1 = acc1 + x1 * c1[:, None] - c2 = w[rd, j].to(torch.float32) - x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc2 = acc2 + x2 * c2[:, None] - c3 = w[rd, j].to(torch.float32) - x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32) - acc3 = acc3 + x3 * c3[:, None] - acc = (acc1 + acc2 + acc3) / 3.0 - acc = acc + b[rd].to(torch.float32)[:, None] - y[rb, rd, rs] = acc[None, :, :].to(y.dtype) - - return y - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - x, weight, bias = data - B, D, S = x.shape - W = weight.shape[1] - kernel = _KERNELS[(B, D, S, W)] - pad_zeros = torch.zeros(B, D, W - 1, dtype=x.dtype, device=x.device) - padded = torch.cat([pad_zeros, x], dim=2) - return kernel(padded, weight, bias) diff --git a/problems/helion/causal_conv1d_py/task.py b/problems/helion/causal_conv1d_py/task.py deleted file mode 100644 index 00a02fe6..00000000 --- a/problems/helion/causal_conv1d_py/task.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import TypedDict, TypeVar -import torch - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=torch.Tensor) - -class TestSpec(TypedDict): - B: int - D: int - S: int - W: int - seed: int diff --git a/problems/helion/causal_conv1d_py/task.yml b/problems/helion/causal_conv1d_py/task.yml deleted file mode 100644 index 1f4e8f0b..00000000 --- a/problems/helion/causal_conv1d_py/task.yml +++ /dev/null @@ -1,49 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement a causal depthwise 1D convolution kernel. - - This is a core component of Mamba/Mamba-2 architectures. Each channel is - convolved independently (depthwise) with causal (left) zero-padding so that - output[t] depends only on input[t-W+1:t+1]. - - For each batch b, channel d, and time t: - out[b, d, t] = bias[d] + sum_{k=0}^{W-1} weight[d, k] * x[b, d, t - W + 1 + k] - where out-of-bounds values are treated as zero. - - Input: tuple(x, weight, bias) where: - - x: torch.Tensor of shape [B, D, S] (float32) - - weight: torch.Tensor of shape [D, W] (float32) - - bias: torch.Tensor of shape [D] (float32) - - Output: torch.Tensor of shape [B, D, S] (float32) - -config: - main: "eval.py" - -templates: - Python: "../template.py" - -tests: - - {"B": 1, "D": 64, "S": 64, "W": 4, "seed": 4242} - - {"B": 2, "D": 128, "S": 128, "W": 4, "seed": 5236} - - {"B": 1, "D": 256, "S": 256, "W": 3, "seed": 1001} - - {"B": 1, "D": 128, "S": 64, "W": 8, "seed": 5531} - - {"B": 4, "D": 64, "S": 128, "W": 4, "seed": 9173} - -benchmarks: - - {"B": 1, "D": 1536, "S": 2048, "W": 4, "seed": 2146} - - {"B": 1, "D": 2560, "S": 2048, "W": 4, "seed": 3129} - - {"B": 1, "D": 2560, "S": 4096, "W": 4, "seed": 54352} - -test_timeout: 180 -benchmark_timeout: 180 -ranked_timeout: 420 -ranking_by: "geom" diff --git a/problems/helion/eval.py b/problems/helion/eval.py deleted file mode 100644 index cbc0f1d6..00000000 --- a/problems/helion/eval.py +++ /dev/null @@ -1,578 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math -from pathlib import Path -from typing import Any, Optional - -import torch.cuda - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, 'w') - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a+b)*(a+b+1)//2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z_]\w*):\s*([a-zA-Z_]\w*|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - if val == "true": - val = True - elif val == "false": - val = False - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg)**2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), - worst=float(worst)) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _copy_data_inplace(dst, src): - """ - Recursively copy tensor data from src into dst (same structure, same shapes). - Used to feed new inputs into CUDA graph buffers without recapturing. - """ - if isinstance(dst, torch.Tensor): - dst.copy_(src) - elif isinstance(dst, (tuple, list)): - for d, s in zip(dst, src): - _copy_data_inplace(d, s) - elif isinstance(dst, dict): - for k in dst: - _copy_data_inplace(dst[k], src[k]) - - -def _do_bench_cudagraph(fn, rep_ms=100, return_mode="mean", clear_l2=True): - """ - Benchmark fn using CUDA graphs with optional L2 cache clearing. - Based on triton.testing.do_bench_cudagraph + triton-lang/triton#8384. - - :param fn: Callable to benchmark (no args). - :param rep_ms: Target repetition time per measurement in milliseconds. - :param return_mode: "min", "max", "mean", "median", or "all" (list of ms). - :param clear_l2: If True, flush L2 cache before each invocation and subtract - the flushing overhead from reported times. - :return: Time(s) in milliseconds. - """ - assert return_mode in ["min", "max", "mean", "median", "all"] - - # 256 MB cache tensor — larger than any current GPU L2 - cache = torch.empty(32 * 1024 * 1024, dtype=torch.int64, device="cuda") if clear_l2 else None - - def maybe_clear_cache(): - if cache is not None: - cache.zero_() - - with torch.cuda.stream(torch.cuda.Stream()): - # warmup - maybe_clear_cache() - fn() - - # step 1 — estimate per-call time - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for _ in range(5): - maybe_clear_cache() - fn() - end_event.record() - torch.cuda.synchronize() - estimate_ms = start_event.elapsed_time(end_event) / 5 - - n_repeat = max(1, int(rep_ms / estimate_ms)) if estimate_ms > 0 else 1000 - - # step 2 — capture graph with n_repeat unrolled calls - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for _ in range(n_repeat): - maybe_clear_cache() - fn() - torch.cuda.synchronize() - - # step 3 — if L2 clearing enabled, capture a separate graph to measure - # the clearing overhead so we can subtract it - cache_clear_graph = None - if clear_l2: - cache_clear_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(cache_clear_graph): - for _ in range(n_repeat): - maybe_clear_cache() - torch.cuda.synchronize() - - # step 4 — measure - n_retries = 10 - cache_clear_times = [] - total_times = [] - for _ in range(n_retries): - if cache_clear_graph is not None: - s = torch.cuda.Event(enable_timing=True) - e = torch.cuda.Event(enable_timing=True) - s.record() - cache_clear_graph.replay() - e.record() - torch.cuda.synchronize() - cache_clear_times.append(s.elapsed_time(e) / n_repeat) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - g.replay() - end_event.record() - torch.cuda.synchronize() - total_times.append(start_event.elapsed_time(end_event) / n_repeat) - - if clear_l2: - ret = [max(0, t - c) for t, c in zip(total_times, cache_clear_times)] - else: - ret = total_times - - if return_mode == "all": - return ret - elif return_mode == "min": - return min(ret) - elif return_mode == "max": - return max(ret) - elif return_mode == "mean": - return sum(ret) / len(ret) - elif return_mode == "median": - return sorted(ret)[len(ret) // 2] - - -def _run_single_test(test: TestCase): - """ - Runs a single test case via CUDA graph capture + replay. - This validates that the kernel is capturable and produces correct output. - """ - from submission import custom_kernel - from reference import check_implementation, generate_input - - data = generate_input(**test.args) - check_copy = _clone_data(data) - - # Warmup call to trigger JIT compilation (outside graph capture) - _ = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - - # Capture and replay through CUDA graph - input_data = _clone_data(data) - try: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = custom_kernel(input_data) - except Exception as e: - return False, f"Failed to capture kernel in CUDA graph: {e}" - g.replay() - torch.cuda.synchronize() - - return check_implementation(check_copy, output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark(test: TestCase, recheck: bool, rep_ms: int) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - - Correctness is verified via CUDA graph capture + replay first. - Timing only runs if all correctness checks pass. - - :param test: Test case with input arguments. - :param recheck: If True, run additional correctness checks with varying seeds. - :param rep_ms: Target repetition time per measurement in milliseconds. - """ - from submission import custom_kernel - from reference import check_implementation, generate_input - - data = generate_input(**test.args) - check_copy = _clone_data(data) - - # Warmup (JIT compilation) - _ = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - - # Capture in CUDA graph and run initial correctness check - input_data = _clone_data(data) - try: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = custom_kernel(input_data) - except Exception as e: - return f"Failed to capture kernel in CUDA graph: {e}" - g.replay() - torch.cuda.synchronize() - good, message = check_implementation(check_copy, output) - if not good: - return message - - if recheck: - # Reuse the captured graph with new input data for each seed - for i in range(10): - if "seed" in test.args: - test.args["seed"] += 13 - new_data = generate_input(**test.args) - check_copy = _clone_data(new_data) - _copy_data_inplace(input_data, new_data) - g.replay() - torch.cuda.synchronize() - good, message = check_implementation(check_copy, output) - if not good: - return message - - # Timing (only reached if all correctness checks passed) - data = generate_input(**test.args) - fn = lambda: custom_kernel(data) - times_ms = _do_bench_cudagraph(fn, rep_ms=rep_ms, return_mode="all", clear_l2=True) - time.sleep(10) # GPU cooldown to avoid thermal throttling - durations = [t * 1e6 for t in times_ms] # convert ms to ns - return calculate_stats(durations) - - -def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, rep_ms: int): - """ - Run a benchmark in a subprocess. - - :param pool: Process pool. - :param test: TestCase object. - :param recheck: Flag for whether to explicitly check functional correctness. - :param rep_ms: Target repetition time per measurement in milliseconds. - :return: A Stats object or an error string. - """ - return pool.apply(_run_single_benchmark, (test, recheck, rep_ms)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - # warm up - run_single_benchmark(pool, tests[0], False, 20) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 100) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def run_single_profile(test: TestCase) -> str: - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - from reference import generate_input - from torch.profiler import profile, record_function, ProfilerActivity - data = generate_input(**test.args) - torch.cuda.synchronize() - - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - submission_output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def run_profiling(logger: PopcornOutput, tests: list[TestCase]): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test) - logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) - logger.log("check", "pass") - return 0 - - -def run_local(): - """ - Local eval mode: reads task.yml from a problem directory, runs correctness tests - and benchmarks, prints results to stdout. No Popcorn infrastructure needed. - - Usage: python eval.py - mode: test, benchmark, or both - problem_dir: path to the problem directory containing task.yml - """ - import yaml - - if len(sys.argv) < 3: - print("Usage: python eval.py ", file=sys.stderr) - print(" mode: test, benchmark, or both", file=sys.stderr) - print(" problem_dir: path to problem directory containing task.yml", file=sys.stderr) - return 1 - - mode = sys.argv[1] - problem_dir = Path(sys.argv[2]) - - if mode not in ("test", "benchmark", "both"): - print(f"Unknown mode '{mode}'. Use 'test', 'benchmark', or 'both'.", file=sys.stderr) - return 1 - - problem_dir = problem_dir.resolve() - task_path = problem_dir / "task.yml" - if not task_path.exists(): - print(f"Error: task.yml not found in {problem_dir}", file=sys.stderr) - return 1 - - task = yaml.safe_load(task_path.read_text()) - - # chdir into the problem directory so that `from submission import ...` works - os.chdir(problem_dir) - sys.path.insert(0, str(problem_dir)) - - from utils import set_seed - - set_seed(42) - exit_code = 0 - - # --- Correctness tests --- - if mode in ("test", "both"): - tests = [TestCase(args=dict(t), spec=str(t)) for t in task.get("tests", [])] - print(f"Running {len(tests)} correctness tests...") - all_passed = True - for idx, test in enumerate(tests): - good, message = _run_single_test(test) - status = "PASS" if good else "FAIL" - print(f" Test {idx}: {status} {test.spec}") - if not good: - print(f" {message}") - all_passed = False - if all_passed: - print("All tests passed.") - else: - print("Some tests FAILED.") - exit_code = 1 - - # --- Benchmarks --- - if mode in ("benchmark", "both"): - benchmarks = [TestCase(args=dict(t), spec=str(t)) for t in task.get("benchmarks", [])] - print(f"\nRunning {len(benchmarks)} benchmarks...") - - # Warmup - _run_single_benchmark(benchmarks[0], False, 20) - - for idx, bench in enumerate(benchmarks): - result = _run_single_benchmark(bench, False, 100) - if isinstance(result, Stats): - mean_ms = result.mean / 1e6 # Stats stores ns - min_ms = result.best / 1e6 - max_ms = result.worst / 1e6 - print(f" Benchmark {idx}: {mean_ms:.4f} ms (min={min_ms:.4f}, max={max_ms:.4f}) {bench.spec}") - else: - print(f" Benchmark {idx}: FAIL (correctness) {bench.spec}") - print(f" {result}") - exit_code = 1 - - return exit_code - - -def main(): - os.environ["HELION_DISALLOW_AUTOTUNING"] = "1" - fd = os.getenv("POPCORN_FD") - if not fd: - return run_local() - - if len(sys.argv) < 3: - return 2 - - from utils import set_seed - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - mp_context = multiprocessing.get_context('spawn') - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 20) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 200) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log(f"benchmark.{i}.error", str(result)) - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/helion/fp8_quant_py/reference.py b/problems/helion/fp8_quant_py/reference.py deleted file mode 100644 index bcad6943..00000000 --- a/problems/helion/fp8_quant_py/reference.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from task import input_t, output_t -from utils import verbose_allclose - -FP8_MAX = 448.0 -FP8_MIN = -448.0 -FP8_EPS = 1e-10 - - -def generate_input(num_tokens: int, hidden_dim: int, group_size: int, seed: int) -> input_t: - gen = torch.Generator(device="cuda") - gen.manual_seed(seed) - x = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda", generator=gen).contiguous() - x_q = torch.empty(num_tokens, hidden_dim, dtype=torch.float32, device="cuda").contiguous() - x_s = torch.empty(num_tokens, hidden_dim // group_size, dtype=torch.float32, device="cuda").contiguous() - return x, x_q, x_s - - -def ref_kernel(data: input_t) -> output_t: - x, x_q, x_s = data - num_tokens, hidden_dim = x.shape - num_groups = x_s.shape[1] - group_size = hidden_dim // num_groups - - x_f32 = x.float() - x_grouped = x_f32.reshape(num_tokens, num_groups, group_size) - - # Per-group absmax - absmax = x_grouped.abs().amax(dim=-1).clamp(min=FP8_EPS) - - # Scale = absmax / fp8_max - scale = absmax / FP8_MAX - - # Quantize - quantized = (x_grouped / scale.unsqueeze(-1)).clamp(FP8_MIN, FP8_MAX) - quantized = quantized.reshape(num_tokens, hidden_dim) - - x_q[...] = quantized - x_s[...] = scale - return x_q, x_s - - -def check_implementation(data, output): - expected = ref_kernel(data) - expected_q, expected_s = expected - received_q, received_s = output - - reasons_q = verbose_allclose(received_q, expected_q, rtol=1e-3, atol=1e-3) - reasons_s = verbose_allclose(received_s, expected_s, rtol=1e-3, atol=1e-3) - - reasons = [] - if reasons_q: - reasons.append("quantized values mismatch: " + " ".join(reasons_q)) - if reasons_s: - reasons.append("scales mismatch: " + " ".join(reasons_s)) - - if reasons: - return False, " | ".join(reasons) - return True, "" diff --git a/problems/helion/fp8_quant_py/submission.py b/problems/helion/fp8_quant_py/submission.py deleted file mode 100644 index 4b562fa9..00000000 --- a/problems/helion/fp8_quant_py/submission.py +++ /dev/null @@ -1,88 +0,0 @@ -from task import input_t, output_t - -import torch -import helion -import helion.language as hl -from pathlib import Path - - -# Per-shape configs: map (num_tokens, hidden_dim, group_size) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # Test shapes - (1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - # Benchmark shapes - # (1, 4096, 128) already covered above - (16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1), # TODO: replace with your autotuned config -} - - -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, config=config) - def kernel( - data: torch.Tensor, # [N, G] input rows - scales_out: torch.Tensor, # [N] output normalization factors - ) -> torch.Tensor: - nrows = data.size(0) - ncols = hl.specialize(data.size(1)) - MAX_VAL = 448.0 - - qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device) - - for rr in hl.tile(nrows): - row = data[rr, :].to(torch.float32) - - abs1 = torch.abs(row) - amax1 = torch.amax(abs1, -1) - abs2 = torch.abs(row) - amax2 = torch.amax(abs2, -1) - abs3 = torch.abs(row) - amax3 = torch.amax(abs3, -1) - amax = (amax1 + amax2 + amax3) / 3.0 - amax = torch.clamp(amax, min=1e-10) - scale = amax / MAX_VAL - - q1 = row / scale[:, None] - q2 = row / scale[:, None] - q3 = row / scale[:, None] - qout[rr, :] = (q1 + q2 + q3) / 3.0 - scales_out[rr] = scale - - return qout - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - x, x_q, x_s = data - T, H = x.shape - G = x_s.shape[1] - gsz = H // G - N = T * G - - kernel = _KERNELS[(T, H, gsz)] - - flat_in = x.reshape(N, gsz) - flat_s = x_s.reshape(N) - - flat_q = kernel(flat_in, flat_s) - - x_q[...] = flat_q.reshape(T, H) - x_s[...] = flat_s.reshape(T, G) - return x_q, x_s diff --git a/problems/helion/fp8_quant_py/task.py b/problems/helion/fp8_quant_py/task.py deleted file mode 100644 index 8fb6c1f0..00000000 --- a/problems/helion/fp8_quant_py/task.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import TypedDict, TypeVar -import torch - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) - -class TestSpec(TypedDict): - num_tokens: int - hidden_dim: int - group_size: int - seed: int diff --git a/problems/helion/fp8_quant_py/task.yml b/problems/helion/fp8_quant_py/task.yml deleted file mode 100644 index df7c36d5..00000000 --- a/problems/helion/fp8_quant_py/task.yml +++ /dev/null @@ -1,56 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement a per-token-group FP8 E4M3 quantization kernel. - - This is THE standard activation quantization method in production LLM inference - (DeepSeek-V3, Llama 3, Qwen3). It dynamically quantizes activations to FP8 - format with per-group scale factors for W8A8 quantized inference. - - For each group of `group_size` contiguous elements: - 1. absmax = max(|x_group|) - 2. scale = max(absmax, eps) / 448.0 - 3. x_q = clamp(x / scale, -448.0, 448.0) - - Where 448.0 is the max representable value in FP8 E4M3 format. - - NOTE: Output is float32 clamped to FP8 range (for broad GPU compatibility). - - Input: tuple(x, x_q, x_s) where: - - x: torch.Tensor of shape [num_tokens, hidden_dim] (float32) - - x_q: pre-allocated output [num_tokens, hidden_dim] (float32) - - x_s: pre-allocated scales [num_tokens, hidden_dim // group_size] (float32) - - Output: tuple(x_q, x_s) where: - - x_q: quantized values [num_tokens, hidden_dim] (float32, clamped to FP8 range) - - x_s: per-group scale factors [num_tokens, hidden_dim // group_size] (float32) - -config: - main: "eval.py" - -templates: - Python: "../template.py" - -tests: - - {"num_tokens": 1, "hidden_dim": 256, "group_size": 64, "seed": 4242} - - {"num_tokens": 4, "hidden_dim": 512, "group_size": 128, "seed": 5236} - - {"num_tokens": 16, "hidden_dim": 1024, "group_size": 64, "seed": 1001} - - {"num_tokens": 1, "hidden_dim": 4096, "group_size": 128, "seed": 5531} - - {"num_tokens": 8, "hidden_dim": 4096, "group_size": 128, "seed": 9173} - -benchmarks: - - {"num_tokens": 256, "hidden_dim": 4096, "group_size": 128, "seed": 2146} - - {"num_tokens": 256, "hidden_dim": 8192, "group_size": 128, "seed": 3129} - - {"num_tokens": 4096, "hidden_dim": 7168, "group_size": 128, "seed": 54352} - -test_timeout: 180 -benchmark_timeout: 180 -ranked_timeout: 420 -ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py deleted file mode 100644 index 9d9b7204..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -from task import input_t, output_t -from utils import verbose_allclose - -CHUNK_SIZE = 64 - - -def _chunk_local_cumsum_eager(g, chunk_size): - B, T, H = g.shape - C = chunk_size - return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H) - - -def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): - B, T, H, K = k.shape - C = chunk_size - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - kkt = k_c @ k_c.transpose(-1, -2) - strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) - g_diff = g_diff * strict_lower - A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower - return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32) - - -def _solve_tril_eager(A, output_dtype): - B, T, H, C = A.shape - NT = T // C - A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - eye = torch.eye(C, device=A.device).expand_as(A_mat) - result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False) - return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype) - - -def _recompute_w_u_fwd_eager(k, v, beta, A, g): - B, T, H, K = k.shape - V = v.shape[-1] - C = A.shape[-1] - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - u_c = A_c @ (v_c * beta_c.unsqueeze(-1)) - w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1)) - w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype) - u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype) - return w, u - - -def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t: - torch.manual_seed(seed) - device = "cuda" - k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5 - v = torch.randn(B, T, H, V, dtype=torch.float32, device=device) - beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g = g_inc.cumsum(dim=1) - g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE) - A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE) - A = _solve_tril_eager(A=A, output_dtype=k.dtype) - w, u = _recompute_w_u_fwd_eager(k=k, v=v, beta=beta, A=A, g=g_cumsum) - return k.contiguous(), w.contiguous(), u.contiguous(), g_cumsum.contiguous() - - -def ref_kernel(data: input_t) -> output_t: - k, w, u, g = data - B, T, H, K = k.shape - V = u.shape[-1] - C = CHUNK_SIZE - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - w_c = w.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - u_c = u.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - h_all = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device) - v_new_c = torch.zeros_like(u_c) - h = torch.zeros(B, H, K, V, dtype=torch.float32, device=k.device) - for c in range(NT): - h_all[:, c] = h - v_new_c[:, c] = u_c[:, c] - w_c[:, c] @ h - g_last = g_c[:, c, :, -1] - gate = torch.exp(g_last.unsqueeze(-1) - g_c[:, c]) - v_gated = v_new_c[:, c] * gate.unsqueeze(-1) - h = h * torch.exp(g_last).unsqueeze(-1).unsqueeze(-1) + k_c[:, c].transpose(-1, -2) @ v_gated - v_new_out = v_new_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(u.dtype) - return h_all.to(k.dtype), v_new_out - - -def check_implementation(data, output): - expected = ref_kernel(data) - exp_h, exp_v = expected - got_h, got_v = output - - reasons_h = verbose_allclose(got_h.float(), exp_h.float(), rtol=1e-3, atol=1e-3) - reasons_v = verbose_allclose(got_v.float(), exp_v.float(), rtol=1e-3, atol=1e-3) - - reasons = [] - if reasons_h: - reasons.append("h mismatch: " + " ".join(reasons_h)) - if reasons_v: - reasons.append("v_new mismatch: " + " ".join(reasons_v)) - - if reasons: - return False, " | ".join(reasons) - return True, "" diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py deleted file mode 100644 index 04e0ecfc..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py +++ /dev/null @@ -1,97 +0,0 @@ -from task import input_t, output_t - -import torch -import helion -import helion.language as hl - - -# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config -} - - -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_h_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, dot_precision="ieee", config=config) - def kernel( - k: torch.Tensor, # [B, T, H, K] - w: torch.Tensor, # [B, T, H, K] - u: torch.Tensor, # [B, T, H, V] - g: torch.Tensor, # [B, T, H] - ) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K = k.shape - V = u.shape[-1] - C = 64 - K = hl.specialize(K) - V = hl.specialize(V) - - NT = (T + C - 1) // C - h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device) - v_out = torch.empty_like(u) - - BH = B * H - - for flat, tv in hl.tile([BH, V], block_size=[1, 8]): - b_idx = flat.begin // H - h_idx = flat.begin % H - state = hl.zeros([K, tv], dtype=torch.float32) - - for tc in hl.tile(T, block_size=C): - chunk_idx = tc.begin // C - t_end = min(tc.begin + C, T) - 1 - - h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype) - - proj1 = hl.dot( - w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 - ) - proj2 = hl.dot( - w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32 - ) - proj = (proj1 + proj2) * 0.5 - diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj - v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype) - - g_end = g[b_idx, t_end, h_idx].to(torch.float32) - g_t = g[b_idx, tc, h_idx].to(torch.float32) - valid = tc.index < T - alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0) - k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None] - - state = state * torch.exp(g_end) - upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) - upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32) - state = state + (upd1 + upd2) * 0.5 - - return h_out, v_out - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - k, w, u, g = data - B, T, H, K = k.shape - V = u.shape[-1] - kernel = _KERNELS[(B, T, H, K, V)] - return kernel(k, w, u, g) diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py deleted file mode 100644 index 248a342e..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypedDict, TypeVar -import torch - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) - -class TestSpec(TypedDict): - B: int - T: int - H: int - K: int - V: int - seed: int diff --git a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml b/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml deleted file mode 100644 index 217db171..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml +++ /dev/null @@ -1,62 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement the chunk_fwd_h (inter-chunk state recurrence) kernel for Gated DeltaNet. - - This kernel maintains a hidden state h of shape [K, V] across chunks and computes - v_new (corrected values) for each chunk. It is the sequential bottleneck in the - chunkwise parallel forward pass of Gated DeltaNet (arXiv:2412.06464, ICLR 2025). - - The sequence is divided into chunks of BT=64 timesteps. Processing is sequential - across chunks but parallel across (B, H) and within each chunk: - - For each (b, h) pair, starting with h_state = zeros(K, V): - For each chunk c = 0, 1, ..., NT-1: - 1. Store: h_out[b, c, h] = h_state - 2. Compute: v_new = u - w @ h_state - 3. Gate: v_gated[t] = v_new[t] * exp(g[last_t] - g[t]) - 4. Decay: h_state = h_state * exp(g[last_t]) - 5. Update: h_state = h_state + k^T @ v_gated - - Input: tuple(k, w, u, g) where: - - k: torch.Tensor of shape [B, T, H, K] (float32) — keys - - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys - - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values - - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate - - Output: tuple(h, v_new) where: - - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk hidden states - - v_new: torch.Tensor of shape [B, T, H, V] (float32) — corrected values - - Constraint: T must be a multiple of 64. NT = T // 64. - - See also: Helion examples/gdn_fwd_h.py for a related implementation - (simpler variant that returns only h, without v_new output). - -config: - main: "eval.py" - -templates: - Python: "../template.py" - -tests: - - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242} - - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236} - - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001} - -benchmarks: - - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232} - - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052} - - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146} - -test_timeout: 180 -benchmark_timeout: 180 -ranked_timeout: 420 -ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py deleted file mode 100644 index 54be0f2f..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -from task import input_t, output_t -from utils import make_match_reference - -CHUNK_SIZE = 64 - - -def _chunk_local_cumsum_eager(g, chunk_size): - B, T, H = g.shape - C = chunk_size - return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H) - - -def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): - B, T, H, K = k.shape - C = chunk_size - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - kkt = k_c @ k_c.transpose(-1, -2) - strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) - g_diff = g_diff * strict_lower - A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower - return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32) - - -def _solve_tril_eager(A, output_dtype): - B, T, H, C = A.shape - NT = T // C - A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - eye = torch.eye(C, device=A.device).expand_as(A_mat) - result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False) - return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype) - - -def _recompute_w_u_fwd_eager(k, v, beta, A, g): - B, T, H, K = k.shape - V = v.shape[-1] - C = A.shape[-1] - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - u_c = A_c @ (v_c * beta_c.unsqueeze(-1)) - w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1)) - w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype) - u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype) - return w, u - - -def _chunk_fwd_h_eager(k, w, u, g): - B, T, H, K = k.shape - V = u.shape[-1] - C = CHUNK_SIZE - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - w_c = w.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - u_c = u.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - h_all = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device) - v_new_c = torch.zeros_like(u_c) - h = torch.zeros(B, H, K, V, dtype=torch.float32, device=k.device) - for c in range(NT): - h_all[:, c] = h - v_new_c[:, c] = u_c[:, c] - w_c[:, c] @ h - g_last = g_c[:, c, :, -1] - gate = torch.exp(g_last.unsqueeze(-1) - g_c[:, c]) - v_gated = v_new_c[:, c] * gate.unsqueeze(-1) - h = h * torch.exp(g_last).unsqueeze(-1).unsqueeze(-1) + k_c[:, c].transpose(-1, -2) @ v_gated - v_new_out = v_new_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(u.dtype) - return h_all.to(k.dtype), v_new_out - - -def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t: - torch.manual_seed(seed) - device = "cuda" - q = torch.randn(B, T, H, K, dtype=torch.float32, device=device) - k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5 - v = torch.randn(B, T, H, V, dtype=torch.float32, device=device) - beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g = g_inc.cumsum(dim=1) - g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE) - A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE) - A = _solve_tril_eager(A=A, output_dtype=k.dtype) - w, u = _recompute_w_u_fwd_eager(k=k, v=v, beta=beta, A=A, g=g_cumsum) - h, v_new = _chunk_fwd_h_eager(k=k, w=w, u=u, g=g_cumsum) - return q.contiguous(), k.contiguous(), v_new.contiguous(), h.contiguous(), g_cumsum.contiguous() - - -def ref_kernel(data: input_t) -> output_t: - q, k, v_new, h, g = data - B, T, H, K = q.shape - V = v_new.shape[-1] - C = CHUNK_SIZE - NT = T // C - scale = K ** -0.5 - q_c = q.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1) - causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device)) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) - g_diff = torch.where(causal, g_diff, torch.zeros_like(g_diff)) - qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_diff) * causal - o = (o_inter + qk @ v_c) * scale - return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py deleted file mode 100644 index eb4de947..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py +++ /dev/null @@ -1,89 +0,0 @@ -from task import input_t, output_t - -import torch -import helion -import helion.language as hl - - -# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check - # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config -} - - -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_o_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, dot_precision="ieee", config=config) - def kernel( - q: torch.Tensor, # [B, T, H, K] - k: torch.Tensor, # [B, T, H, K] - v: torch.Tensor, # [B, T, H, V] - h: torch.Tensor, # [B, NT, H, K, V] - g: torch.Tensor, # [B, T, H] - scale: float, - ) -> torch.Tensor: - B, T, H, K = q.shape - V = v.shape[-1] - C = 64 - K = hl.specialize(K) - V = hl.specialize(V) - - out = torch.empty_like(v) - - BH = B * H - for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]): - b_idx = flat_bh.begin // H - h_idx = flat_bh.begin % H - c_idx = tile_t.begin // C - - g_vals = g[b_idx, tile_t, h_idx] - q_tile = q[b_idx, tile_t, h_idx, :] - k_tile = k[b_idx, tile_t, h_idx, :] - v_tile = v[b_idx, tile_t, h_idx, :] - - # intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask - qk = hl.dot(q_tile, k_tile.T) - idx = hl.arange(tile_t.block_size) - g_diff = g_vals[:, None] - g_vals[None, :] - causal_mask = idx[:, None] >= idx[None, :] - sim = torch.where(causal_mask, qk * torch.exp(g_diff), 0.0) - local_out = hl.dot(sim.to(v.dtype), v_tile) - - # inter-chunk: (q @ h) * exp(g) - q_s = q_tile * torch.exp(g_vals)[:, None] - global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :]) - - out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype) - - return out - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - q, k, v_new, h, g = data - B, T, H, K = q.shape - V = v_new.shape[-1] - scale = K ** -0.5 - kernel = _KERNELS[(B, T, H, K, V)] - return kernel(q, k, v_new, h, g, scale) diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/task.py b/problems/helion/gated_deltanet_chunk_fwd_o_py/task.py deleted file mode 100644 index 08d4b4f6..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/task.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypedDict, TypeVar -import torch - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=torch.Tensor) - -class TestSpec(TypedDict): - B: int - T: int - H: int - K: int - V: int - seed: int diff --git a/problems/helion/gated_deltanet_chunk_fwd_o_py/task.yml b/problems/helion/gated_deltanet_chunk_fwd_o_py/task.yml deleted file mode 100644 index 7b8e2a08..00000000 --- a/problems/helion/gated_deltanet_chunk_fwd_o_py/task.yml +++ /dev/null @@ -1,55 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement the chunk_fwd_o (output computation) kernel for Gated DeltaNet. - - This kernel computes the final output by combining inter-chunk (state-based) - and intra-chunk (attention-based) contributions for the chunkwise parallel - forward pass of Gated DeltaNet (arXiv:2412.06464, ICLR 2025). - - The sequence is divided into chunks of BT=64 timesteps. For each chunk - independently: - inter = q @ h * exp(g) - intra = causal_mask(q @ k^T * exp(g[:, None] - g[None, :])) @ v_new - output = (inter + intra) * scale - - where scale = K^(-0.5), and causal_mask zeros out entries where row < col. - - Input: tuple(q, k, v_new, h, g) where: - - q: torch.Tensor of shape [B, T, H, K] (float32) — queries - - k: torch.Tensor of shape [B, T, H, K] (float32) — keys - - v_new: torch.Tensor of shape [B, T, H, V] (float32) — corrected values - - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk states - - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate - - Output: torch.Tensor of shape [B, T, H, V] (float32) - - Constraint: T must be a multiple of 64. NT = T // 64. scale = K^(-0.5). - -config: - main: "eval.py" - -templates: - Python: "../template.py" - -tests: - - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242} - - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236} - - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001} - -benchmarks: - - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232} - - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052} - - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146} - -test_timeout: 180 -benchmark_timeout: 180 -ranked_timeout: 420 -ranking_by: "geom" diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py b/problems/helion/gated_deltanet_recompute_w_u_py/reference.py deleted file mode 100644 index bd7c1507..00000000 --- a/problems/helion/gated_deltanet_recompute_w_u_py/reference.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from task import input_t, output_t -from utils import verbose_allclose - -CHUNK_SIZE = 64 - - -def _chunk_local_cumsum_eager(g, chunk_size): - B, T, H = g.shape - C = chunk_size - return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H) - - -def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size): - B, T, H, K = k.shape - C = chunk_size - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - kkt = k_c @ k_c.transpose(-1, -2) - strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1) - g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2) - g_diff = g_diff * strict_lower - A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower - return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32) - - -def _solve_tril_eager(A, output_dtype): - B, T, H, C = A.shape - NT = T // C - A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - eye = torch.eye(C, device=A.device).expand_as(A_mat) - result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False) - return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype) - - -def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t: - torch.manual_seed(seed) - device = "cuda" - k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5 - v = torch.randn(B, T, H, V, dtype=torch.float32, device=device) - beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device)) - g = g_inc.cumsum(dim=1) - g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE) - A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE) - A = _solve_tril_eager(A=A, output_dtype=k.dtype) - return k.contiguous(), v.contiguous(), beta.contiguous(), A.contiguous(), g_cumsum.contiguous() - - -def ref_kernel(data: input_t) -> output_t: - k, v, beta, A, g = data - B, T, H, K = k.shape - V = v.shape[-1] - C = A.shape[-1] - NT = T // C - k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4) - v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4) - beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2) - A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4) - u_c = A_c @ (v_c * beta_c.unsqueeze(-1)) - w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1)) - w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype) - u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype) - return w, u - - -def check_implementation(data, output): - expected = ref_kernel(data) - exp_w, exp_u = expected - got_w, got_u = output - - reasons_w = verbose_allclose(got_w, exp_w, rtol=1e-3, atol=1e-3) - reasons_u = verbose_allclose(got_u, exp_u, rtol=1e-3, atol=1e-3) - - reasons = [] - if reasons_w: - reasons.append("w mismatch: " + " ".join(reasons_w)) - if reasons_u: - reasons.append("u mismatch: " + " ".join(reasons_u)) - - if reasons: - return False, " | ".join(reasons) - return True, "" diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py b/problems/helion/gated_deltanet_recompute_w_u_py/submission.py deleted file mode 100644 index 07fb0691..00000000 --- a/problems/helion/gated_deltanet_recompute_w_u_py/submission.py +++ /dev/null @@ -1,100 +0,0 @@ -from task import input_t, output_t - -import torch -import helion -import helion.language as hl - - -# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # Test shapes - (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check - # Benchmark shapes - (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config - (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config -} - - -# Optional: add advanced_controls_file to your Config for extra performance (see docs). -# Autotune with autotune_search_acf to find the best ACF, then hardcode it: -# helion.Config(..., advanced_controls_file="/opt/booster_pack/recompute_w_u_fwd_0.acf") - - -# NOTE: This is an intentionally inefficient baseline implementation. -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, dot_precision="ieee", config=config) - def kernel( - k: torch.Tensor, # [B, T, H, K] - v: torch.Tensor, # [B, T, H, V] - beta: torch.Tensor, # [B, T, H] - A: torch.Tensor, # [B, T, H, BT] - g: torch.Tensor, # [B, T, H] - ) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K = k.shape - V = v.shape[-1] - C = hl.specialize(A.shape[-1]) - K = hl.specialize(K) - V = hl.specialize(V) - - w_out = torch.empty_like(k) - u_out = torch.empty_like(v) - - BH = B * H - for flat_bh, rt in hl.tile([BH, T], block_size=[1, C]): - b_idx = flat_bh.begin // H - h_idx = flat_bh.begin % H - - w_acc1 = hl.zeros([rt, K], dtype=torch.float32) - u_acc1 = hl.zeros([rt, V], dtype=torch.float32) - w_acc2 = hl.zeros([rt, K], dtype=torch.float32) - u_acc2 = hl.zeros([rt, V], dtype=torch.float32) - - for ci in range(C): - t_ci = rt.begin + ci - a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) - coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) - decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - - k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) - v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - - w_acc1 = w_acc1 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] - u_acc1 = u_acc1 + a_col[:, None] * (v_ci * coeff_ci)[None, :] - - for ci in range(C - 1, -1, -1): - t_ci = rt.begin + ci - a_col = A[b_idx, rt, h_idx, ci].to(torch.float32) - coeff_ci = beta[b_idx, t_ci, h_idx].to(torch.float32) - decay_ci = torch.exp(g[b_idx, t_ci, h_idx].to(torch.float32)) - - k_ci = k[b_idx, t_ci, h_idx, :].to(torch.float32) - v_ci = v[b_idx, t_ci, h_idx, :].to(torch.float32) - - w_acc2 = w_acc2 + a_col[:, None] * (k_ci * coeff_ci * decay_ci)[None, :] - u_acc2 = u_acc2 + a_col[:, None] * (v_ci * coeff_ci)[None, :] - - w_out[b_idx, rt, h_idx, :] = ((w_acc1 + w_acc2) * 0.5).to(k.dtype) - u_out[b_idx, rt, h_idx, :] = ((u_acc1 + u_acc2) * 0.5).to(v.dtype) - - return w_out, u_out - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - k, v, beta, A, g = data - B, T, H, K = k.shape - V = v.shape[-1] - kernel = _KERNELS[(B, T, H, K, V)] - return kernel(k, v, beta, A, g) diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/task.py b/problems/helion/gated_deltanet_recompute_w_u_py/task.py deleted file mode 100644 index 2887eb89..00000000 --- a/problems/helion/gated_deltanet_recompute_w_u_py/task.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypedDict, TypeVar -import torch - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) - -class TestSpec(TypedDict): - B: int - T: int - H: int - K: int - V: int - seed: int diff --git a/problems/helion/gated_deltanet_recompute_w_u_py/task.yml b/problems/helion/gated_deltanet_recompute_w_u_py/task.yml deleted file mode 100644 index 3a8820fc..00000000 --- a/problems/helion/gated_deltanet_recompute_w_u_py/task.yml +++ /dev/null @@ -1,60 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval.py"} - -lang: "py" - -description: | - Implement the recompute_w_u forward kernel for Gated DeltaNet. - - This kernel computes WY-transformed keys (w) and values (u) for the chunkwise - parallel forward pass of Gated DeltaNet (arXiv:2412.06464, ICLR 2025). It is - one of three per-chunk kernels in the forward pipeline. - - The sequence is divided into non-overlapping chunks of BT=64 timesteps. - For each chunk independently: - u = A @ diag(beta) @ v (WY-transformed values) - w = A @ diag(beta * exp(g)) @ k (WY-transformed keys) - - Equivalently: - u = A @ (v * beta[:, None]) - w = A @ (k * beta[:, None] * exp(g)[:, None]) - - where A is a [BT, BT] WY representation matrix per chunk. - - Input: tuple(k, v, beta, A, g) where: - - k: torch.Tensor of shape [B, T, H, K] (float32) — keys - - v: torch.Tensor of shape [B, T, H, V] (float32) — values - - beta: torch.Tensor of shape [B, T, H] (float32) — gating coefficients - - A: torch.Tensor of shape [B, T, H, BT] (float32) — WY matrix (BT=64) - - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate - - Output: tuple(w, u) where: - - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys - - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values - - Constraint: T must be a multiple of 64. - -config: - main: "eval.py" - -templates: - Python: "../template.py" - -tests: - - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242} - - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236} - - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001} - -benchmarks: - - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232} - - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052} - - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146} - -test_timeout: 180 -benchmark_timeout: 180 -ranked_timeout: 420 -ranking_by: "geom" diff --git a/problems/helion/template.py b/problems/helion/template.py deleted file mode 100644 index 37d04820..00000000 --- a/problems/helion/template.py +++ /dev/null @@ -1,31 +0,0 @@ -from task import input_t, output_t -import torch -import helion -import helion.language as hl - - -# Per-shape configs: map input shape tuples to optimized helion.Config objects. -# Autotune locally for each shape, then paste the best config here. -# Include all test and benchmark shapes from task.yml. -SHAPE_CONFIGS: dict[tuple, helion.Config] = { - # (shape_dim_1, shape_dim_2, ...): helion.Config(...), # TODO: replace with your config -} - - -def _make_kernel(config: helion.Config): - @helion.kernel(static_shapes=True, config=config) - def kernel(...) -> ...: - # Your Helion kernel implementation - ... - - return kernel - - -_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()} - - -def custom_kernel(data: input_t) -> output_t: - # Extract shape key from input tensors to select the right kernel - # shape_key = (...) - # kernel = _KERNELS[shape_key] - pass diff --git a/problems/linalg.yaml b/problems/linalg.yaml deleted file mode 100644 index 34ed8ce2..00000000 --- a/problems/linalg.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Linear Algebra - -deadline: "" - -description: "Core linear algebra kernels for modern accelerator workloads." - -problems: - - directory: linalg/qr_py - name: qr - deadline: "2026-06-30" - gpus: - - B200 - - directory: linalg/qr_v2 - name: qr_v2 - deadline: "2026-06-30" - gpus: - - B200 diff --git a/problems/linalg/qr_py/eval.py b/problems/linalg/qr_py/eval.py deleted file mode 100644 index cd2c6bd3..00000000 --- a/problems/linalg/qr_py/eval.py +++ /dev/null @@ -1,311 +0,0 @@ -import dataclasses -import math -import multiprocessing -import os -import re -import sys -import time -from pathlib import Path -from typing import Any, Optional - -import torch - -from reference import check_implementation, generate_input -from utils import clear_l2_cache, set_seed - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - - -MAX_ITERATIONS_PER_BENCHMARK = 50 -BENCHMARK_INPUT_BYTES_TARGET = 256 * 1024 * 1024 - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def _combine(a: int, b: int) -> int: - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as exc: - print(f"Could not open test file `{file_name}`: {exc}", file=sys.stderr) - exit(113) - - tests = [] - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in content.splitlines(): - case = {} - for part in line.split(";"): - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - return tests - - -def calculate_stats(durations: list[float]) -> Stats: - runs = len(durations) - total = sum(durations) - avg = total / runs - variance = sum((x - avg) ** 2 for x in durations) - std = math.sqrt(variance / (runs - 1)) if runs > 1 else 0.0 - err = std / math.sqrt(runs) if runs > 0 else 0.0 - return Stats( - runs=runs, - mean=avg, - std=std, - err=err, - best=float(min(durations)), - worst=float(max(durations)), - ) - - -def _clone_data(data): - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - if isinstance(data, list): - return [_clone_data(x) for x in data] - if isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - if isinstance(data, torch.Tensor): - return data.clone() - return data - - -def _run_single_test(test: TestCase): - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return check_implementation(data, output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if good: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - else: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def _make_data_batch(test: TestCase, count: int): - args = dict(test.args) - data_list = [] - for _ in range(count): - if "seed" in args: - args["seed"] += 42 - data_list.append(generate_input(**args)) - return data_list - - -def _benchmark_batch_count(test: TestCase) -> int: - batch = int(test.args.get("batch", 1)) - n = int(test.args.get("n", 1)) - # Input storage is A. Keep the generated batch modest - # because large QR cases are already batched inside a single input. - bytes_per_input = (batch * n * n) * 4 - if bytes_per_input <= 0: - return 1 - return max(1, min(MAX_ITERATIONS_PER_BENCHMARK, BENCHMARK_INPUT_BYTES_TARGET // bytes_per_input)) - - -def _run_single_benchmark( - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -) -> Stats | Any: - from submission import custom_kernel - - data_list = _make_data_batch(test, _benchmark_batch_count(test)) - check_copy = _clone_data(data_list) - - outputs = [custom_kernel(_clone_data(data)) for data in data_list] - for reference_data, output in zip(check_copy, outputs): - good, message = check_implementation(reference_data, output) - if not good: - return message - - durations = [] - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - torch.cuda.synchronize() - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - outputs = [custom_kernel(data) for data in data_list] - end_event.record() - torch.cuda.synchronize() - durations.append(start_event.elapsed_time(end_event) * 1e6 / len(data_list)) - - if recheck: - for reference_data, output in zip(check_copy, outputs): - good, message = check_implementation(reference_data, output) - if not good: - return message - - total_bm_duration = time.perf_counter_ns() - bm_start_time - if i > 1 and total_bm_duration > 1e8: - stats = calculate_stats(durations) - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - run_single_benchmark(pool, tests[0], False, 200, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - # recheck=True: re-validate the output of every timed iteration, not just - # the pre-timing warmup. Without this, the timed loop (which for the - # low-`count` shapes reuses one input object across all repeats) never - # re-checks its outputs, so a kernel that diverges only inside the timed - # region -- e.g. one that caches and replays an output keyed on the - # reused input -- is scored as fast without ever being caught locally. - # `leaderboard` mode already rechecks; this brings `benchmark` mode in - # line so a wrong timed output fails here too. - result = run_single_benchmark(pool, test, True, 200, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - passed = False - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - if mode == "leaderboard": - for test in tests: - run_single_benchmark(pool, test, False, 1000, 5e8) - logger.log("benchmark-count", len(tests)) - passed = True - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, True, 1000, 30e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", str(result)) - passed = False - break - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - if mode == "profile": - logger.log("check", "fail") - logger.log("error", "profile mode is not implemented for qr eval.py") - return 2 - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/linalg/qr_py/reference.py b/problems/linalg/qr_py/reference.py deleted file mode 100644 index fc8ace77..00000000 --- a/problems/linalg/qr_py/reference.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch -from task import input_t, output_t - - -_FACTOR_RTOL_FACTOR = 20.0 -_ORTH_RTOL_FACTOR = 100.0 - - -def _apply_column_scaling(a: torch.Tensor, cond: int) -> torch.Tensor: - # `cond` is a deterministic dynamic-range knob, not an exact condition number. - if cond: - n = a.shape[-1] - scales = torch.logspace(0.0, -float(cond), n, device=a.device, dtype=torch.float32) - return a * scales - return a.contiguous() - - -def _band_mask(n: int, bandwidth: int, device: torch.device) -> torch.Tensor: - idx = torch.arange(n, device=device) - return (idx[:, None] - idx[None, :]).abs() <= bandwidth - - -def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t: - assert batch > 0, "batch must be positive" - assert n > 0, "n must be positive" - assert cond >= 0, "cond must be non-negative" - - device = "cuda" if torch.cuda.is_available() else "cpu" - gen = torch.Generator(device=device) - gen.manual_seed(seed) - - case = case.lower() - a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen) - - if case == "dense": - a = _apply_column_scaling(a, cond) - elif case == "upper": - diag_boost = torch.linspace(1.0, 0.25, n, device=device, dtype=torch.float32) - a = torch.triu(a) - a.diagonal(dim1=-2, dim2=-1).add_(diag_boost) - a = _apply_column_scaling(a, cond) - elif case == "diagonal": - diag = torch.randn((batch, n), device=device, dtype=torch.float32, generator=gen) - diag = diag.sign().clamp(min=0.0).mul(2.0).sub(1.0) * torch.logspace( - 0.0, -float(max(cond, 2)), n, device=device, dtype=torch.float32 - ) - a = torch.diag_embed(diag) - elif case == "rankdef": - rank = max(1, (3 * n) // 4) - a[:, :, rank:] = 0.0 - a = _apply_column_scaling(a, cond) - elif case == "nearrank": - rank = max(1, (3 * n) // 4) - tail = n - rank - if tail > 0: - noise = torch.randn( - (batch, n, tail), device=device, dtype=torch.float32, generator=gen - ) - a[:, :, rank:] = a[:, :, :tail] + 1.0e-5 * noise - a = _apply_column_scaling(a, cond) - elif case == "clustered": - scales = torch.ones((n,), device=device, dtype=torch.float32) - scales[n // 2 :] = 4.0 * torch.finfo(torch.float32).eps - if n >= 8: - lo = max(0, n // 2 - 2) - hi = min(n, n // 2 + 2) - scales[lo:hi] = torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps, device=device)) - a = a * scales - elif case == "band": - bandwidth = max(2, min(32, n // 32)) - a = a * _band_mask(n, bandwidth, device) - diag_boost = torch.linspace(1.0, 0.5, n, device=device, dtype=torch.float32) - a.diagonal(dim1=-2, dim2=-1).add_(diag_boost) - a = _apply_column_scaling(a, cond) - elif case == "nearcollinear": - base = torch.randn((batch, n, 1), device=device, dtype=torch.float32, generator=gen) - noise = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen) - a = base.expand(batch, n, n) + 1.0e-4 * noise - a = _apply_column_scaling(a, cond) - elif case == "rowscale": - row_cond = max(cond, 4) - scales = torch.logspace(0.0, -float(row_cond), n, device=device, dtype=torch.float32) - a = scales.reshape(1, n, 1) * a - else: - raise ValueError(f"unknown QR test case: {case}") - - return a.contiguous() - - -def ref_kernel(data: input_t) -> output_t: - # Starter/reference path: correctness first; submissions compete on speed. - return torch.geqrf(data) - - -def _property_rtol(n: int, factor: float) -> float: - eps = torch.finfo(torch.float32).eps - return factor * max(n, 1) * eps - - -def _scaled_residual( - residual: torch.Tensor, - scale: torch.Tensor, - n: int, -) -> torch.Tensor: - eps = torch.finfo(torch.float32).eps - return residual / (eps * max(n, 1) * scale.clamp_min(1e-30)) - - -def _matrix_l1_norm(value: torch.Tensor) -> torch.Tensor: - return torch.linalg.matrix_norm(value.double(), ord=1, dim=(-2, -1)) - - -def _check_tensor(name: str, value: torch.Tensor, shape: tuple[int, ...], device: torch.device) -> str | None: - if not isinstance(value, torch.Tensor): - return f"{name} must be a torch.Tensor" - if value.shape != shape: - return f"{name} shape must be {shape}, got {tuple(value.shape)}" - if value.dtype != torch.float32: - return f"{name} dtype must be torch.float32, got {value.dtype}" - if value.device != device: - return f"{name} must be on {device}, got {value.device}" - if not torch.isfinite(value).all().item(): - return f"{name} contains NaN or Inf" - return None - - -def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: - a = data - batch, n, _ = a.shape - factor_rtol = _property_rtol(n, _FACTOR_RTOL_FACTOR) - orth_rtol = _property_rtol(n, _ORTH_RTOL_FACTOR) - - if not isinstance(output, tuple) or len(output) != 2: - return False, "output must be a tuple `(H, tau)`" - - h, tau = output - error = _check_tensor("H", h, (batch, n, n), a.device) - if error is not None: - return False, error - error = _check_tensor("tau", tau, (batch, n), a.device) - if error is not None: - return False, error - - q = torch.linalg.householder_product(h, tau) - r = torch.triu(h) - a_check = a.double() - q_check = q.double() - r_check = r.double() - projected = q_check.transpose(-1, -2) @ a_check - factor_residual = _matrix_l1_norm(r_check - projected).amax() - factor_scale = _matrix_l1_norm(a_check).amax() - factor_allowed = factor_rtol * factor_scale - factor_scaled = _scaled_residual(factor_residual, factor_scale, n) - if factor_residual.item() > factor_allowed.item(): - return False, ( - "R - Q.T @ A is too large: " - f"residual={factor_residual.item():.3g}, allowed={factor_allowed.item():.3g}, " - f"scaled={factor_scaled.item():.3g}" - ) - - eye = torch.eye(n, device=a.device, dtype=torch.float64).expand(batch, n, n) - qtq = q_check.transpose(-1, -2) @ q_check - orth_residual = _matrix_l1_norm(qtq - eye).amax() - orth_scale = _matrix_l1_norm(eye).amax() - orth_allowed = orth_rtol * orth_scale - orth_scaled = _scaled_residual(orth_residual, orth_scale, n) - if orth_residual.item() > orth_allowed.item(): - return False, ( - "Q is not orthogonal enough: " - f"residual={orth_residual.item():.3g}, allowed={orth_allowed.item():.3g}, " - f"scaled={orth_scaled.item():.3g}" - ) - - lower = torch.tril(projected, diagonal=-1) - tri_residual = _matrix_l1_norm(lower).amax() - tri_scale = _matrix_l1_norm(a_check).amax() - tri_scaled = _scaled_residual(tri_residual, tri_scale, n) - - recon = q_check @ r_check - recon_residual = _matrix_l1_norm(recon - a_check).amax() - recon_scale = _matrix_l1_norm(a_check).amax() - recon_scaled = _scaled_residual(recon_residual, recon_scale, n) - - return True, ( - f"factor_rtol={factor_rtol:.3g}; " - f"orth_rtol={orth_rtol:.3g}; " - f"scaled_factor_residual={factor_scaled.item():.3g}; " - f"scaled_reconstruction_residual={recon_scaled.item():.3g}; " - f"scaled_triangular_residual={tri_scaled.item():.3g}; " - f"scaled_orthogonality_residual={orth_scaled.item():.3g}; " - f"batch={batch}; n={n}" - ) diff --git a/problems/linalg/qr_py/submission.py b/problems/linalg/qr_py/submission.py deleted file mode 100644 index ac92e0ac..00000000 --- a/problems/linalg/qr_py/submission.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - return torch.geqrf(data) diff --git a/problems/linalg/qr_py/task.py b/problems/linalg/qr_py/task.py deleted file mode 100644 index e0547dcc..00000000 --- a/problems/linalg/qr_py/task.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -from typing import NotRequired, TypeVar, TypedDict - -input_t = TypeVar("input_t", bound=torch.Tensor) -output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) - - -class TestSpec(TypedDict): - batch: int - n: int - cond: int - seed: int - case: NotRequired[str] diff --git a/problems/linalg/qr_py/task.yml b/problems/linalg/qr_py/task.yml deleted file mode 100644 index 8e935eba..00000000 --- a/problems/linalg/qr_py/task.yml +++ /dev/null @@ -1,100 +0,0 @@ -# name: qr - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../../pmpp_v2/utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "eval.py"} - -lang: "py" - -description: | - Implement batched square compact-Householder QR factorization. - - Input is `A`, a `batch x n x n` CUDA tensor in `torch.float32`. - - Return `(H, tau)` in the same compact Householder convention as - `torch.geqrf(A)`. `H` is a `batch x n x n` FP32 tensor containing `R` in its - upper triangle and Householder vectors below the diagonal. `tau` is a - `batch x n` FP32 tensor containing reflector coefficients. The checker - materializes `Q = torch.linalg.householder_product(H, tau)`, uses - `R_factor = triu(H)`, and validates the LAPACK-style QR factorization - residual `R_factor - Q.T @ A` and orthogonality of `Q`. Since `R_factor` - is extracted with `triu`, triangularity is part of the factorization check: - if `Q.T @ A` has meaningful lower-triangular leakage, then it cannot match - `R_factor`. The checker reports that lower-triangular leakage and the - reconstruction residual as diagnostics. - - This shape set targets optimizer-style matrix statistics where gradients are - viewed as `[for_each..., basis_dim, contracted_dim]`, statistics are formed as - `G @ G.T`, and QR is run on square `basis_dim x basis_dim` matrices. Batched - `512 x 512` is especially important, while `1024`, `2048`, and `4096` cover - larger square factors. - - Test and benchmark specs include a `cond` field. In this task `cond` is a - deterministic input-scaling knob, not an exact requested condition number: - dense cases multiply columns by `logspace(0, -cond, n)`, so larger `cond` - creates a wider dynamic range across columns. Some stress cases use their own - structure, such as rank-deficient, near-rank-deficient, banded, row-scaled, - near-collinear, upper-triangular, or clustered-scale inputs. - - Correctness is a hard gate against the original FP32 input and the FP32 - `torch.geqrf` compact-factor contract. Low-bit FP16, FP8, or NVFP4 work is - allowed only as an internal implementation strategy: returned factors must - still be FP32 and must satisfy the same QR invariants as an FP32 - factorization. Residuals are measured in FP64 to reduce checker noise, but - the target tolerance is still FP32 accuracy. The numerical property tolerance - is purely relative, with no QR `atol`. - The hard gates are the LAPACK-style factor residual, which uses - `rtol = 20 * n * eps32`, and orthogonality, which uses - `rtol = 100 * n * eps32`, each applied to the corresponding matrix L1 norm. - Triangularity is reported as lower-triangular leakage in `Q.T @ A` and is - already implied by the factor residual against `triu(H)`. - - Among passing submissions, ranking is by runtime using the geometric mean of - benchmark cases. We will also celebrate notable submissions beyond the main - leaderboard: the fastest, the most elegant, and the strangest working kernels. - -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 240 -benchmark_timeout: 480 -ranked_timeout: 900 -ranking_by: "geom" -gpus: - - B200 - -tests: - - {"batch": 20, "n": 32, "cond": 1, "seed": 53124} - - {"batch": 40, "n": 176, "cond": 1, "seed": 3321} - - {"batch": 40, "n": 352, "cond": 1, "seed": 1200} - - {"batch": 16, "n": 512, "cond": 2, "seed": 32523} - - {"batch": 4, "n": 1024, "cond": 2, "seed": 4327} - - {"batch": 1, "n": 4096, "cond": 1, "seed": 75342} - - {"batch": 16, "n": 512, "cond": 4, "seed": 32524, "case": "dense"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32525, "case": "rankdef"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32526, "case": "clustered"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32527, "case": "band"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32528, "case": "rowscale"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32529, "case": "nearcollinear"} - - {"batch": 4, "n": 1024, "cond": 4, "seed": 4328, "case": "dense"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4329, "case": "rankdef"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4330, "case": "nearrank"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4331, "case": "clustered"} - - {"batch": 2, "n": 2048, "cond": 2, "seed": 224466, "case": "dense"} - - {"batch": 2, "n": 2048, "cond": 0, "seed": 224467, "case": "rankdef"} - - {"batch": 1, "n": 4096, "cond": 0, "seed": 75343, "case": "upper"} - -benchmarks: - - {"batch": 20, "n": 32, "cond": 1, "seed": 43214} - - {"batch": 40, "n": 176, "cond": 1, "seed": 423011} - - {"batch": 40, "n": 352, "cond": 1, "seed": 123456} - - {"batch": 640, "n": 512, "cond": 2, "seed": 1029} - - {"batch": 60, "n": 1024, "cond": 2, "seed": 75342} - - {"batch": 8, "n": 2048, "cond": 1, "seed": 224466} - - {"batch": 2, "n": 4096, "cond": 1, "seed": 32412} diff --git a/problems/linalg/qr_v2/eval.py b/problems/linalg/qr_v2/eval.py deleted file mode 100644 index cd2c6bd3..00000000 --- a/problems/linalg/qr_v2/eval.py +++ /dev/null @@ -1,311 +0,0 @@ -import dataclasses -import math -import multiprocessing -import os -import re -import sys -import time -from pathlib import Path -from typing import Any, Optional - -import torch - -from reference import check_implementation, generate_input -from utils import clear_l2_cache, set_seed - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - - -MAX_ITERATIONS_PER_BENCHMARK = 50 -BENCHMARK_INPUT_BYTES_TARGET = 256 * 1024 * 1024 - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def _combine(a: int, b: int) -> int: - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as exc: - print(f"Could not open test file `{file_name}`: {exc}", file=sys.stderr) - exit(113) - - tests = [] - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in content.splitlines(): - case = {} - for part in line.split(";"): - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - return tests - - -def calculate_stats(durations: list[float]) -> Stats: - runs = len(durations) - total = sum(durations) - avg = total / runs - variance = sum((x - avg) ** 2 for x in durations) - std = math.sqrt(variance / (runs - 1)) if runs > 1 else 0.0 - err = std / math.sqrt(runs) if runs > 0 else 0.0 - return Stats( - runs=runs, - mean=avg, - std=std, - err=err, - best=float(min(durations)), - worst=float(max(durations)), - ) - - -def _clone_data(data): - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - if isinstance(data, list): - return [_clone_data(x) for x in data] - if isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - if isinstance(data, torch.Tensor): - return data.clone() - return data - - -def _run_single_test(test: TestCase): - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - output = custom_kernel(_clone_data(data)) - torch.cuda.synchronize() - return check_implementation(data, output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - return pool.apply(_run_single_test, (test,)) - - -def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if good: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - else: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def _make_data_batch(test: TestCase, count: int): - args = dict(test.args) - data_list = [] - for _ in range(count): - if "seed" in args: - args["seed"] += 42 - data_list.append(generate_input(**args)) - return data_list - - -def _benchmark_batch_count(test: TestCase) -> int: - batch = int(test.args.get("batch", 1)) - n = int(test.args.get("n", 1)) - # Input storage is A. Keep the generated batch modest - # because large QR cases are already batched inside a single input. - bytes_per_input = (batch * n * n) * 4 - if bytes_per_input <= 0: - return 1 - return max(1, min(MAX_ITERATIONS_PER_BENCHMARK, BENCHMARK_INPUT_BYTES_TARGET // bytes_per_input)) - - -def _run_single_benchmark( - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -) -> Stats | Any: - from submission import custom_kernel - - data_list = _make_data_batch(test, _benchmark_batch_count(test)) - check_copy = _clone_data(data_list) - - outputs = [custom_kernel(_clone_data(data)) for data in data_list] - for reference_data, output in zip(check_copy, outputs): - good, message = check_implementation(reference_data, output) - if not good: - return message - - durations = [] - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - torch.cuda.synchronize() - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - outputs = [custom_kernel(data) for data in data_list] - end_event.record() - torch.cuda.synchronize() - durations.append(start_event.elapsed_time(end_event) * 1e6 / len(data_list)) - - if recheck: - for reference_data, output in zip(check_copy, outputs): - good, message = check_implementation(reference_data, output) - if not good: - return message - - total_bm_duration = time.perf_counter_ns() - bm_start_time - if i > 1 and total_bm_duration > 1e8: - stats = calculate_stats(durations) - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): - run_single_benchmark(pool, tests[0], False, 200, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - # recheck=True: re-validate the output of every timed iteration, not just - # the pre-timing warmup. Without this, the timed loop (which for the - # low-`count` shapes reuses one input object across all repeats) never - # re-checks its outputs, so a kernel that diverges only inside the timed - # region -- e.g. one that caches and replays an output keyed on the - # reused input -- is scored as fast without ever being caught locally. - # `leaderboard` mode already rechecks; this brings `benchmark` mode in - # line so a wrong timed output fails here too. - result = run_single_benchmark(pool, test, True, 200, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - passed = False - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - if mode == "leaderboard": - for test in tests: - run_single_benchmark(pool, test, False, 1000, 5e8) - logger.log("benchmark-count", len(tests)) - passed = True - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, True, 1000, 30e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", str(result)) - passed = False - break - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - if mode == "profile": - logger.log("check", "fail") - logger.log("error", "profile mode is not implemented for qr eval.py") - return 2 - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/linalg/qr_v2/reference.py b/problems/linalg/qr_v2/reference.py deleted file mode 100644 index 539f9094..00000000 --- a/problems/linalg/qr_v2/reference.py +++ /dev/null @@ -1,266 +0,0 @@ -import torch -from task import input_t, output_t - - -_FACTOR_RTOL_FACTOR = 20.0 -_ORTH_RTOL_FACTOR = 100.0 - - -def _apply_column_scaling(a: torch.Tensor, cond: int) -> torch.Tensor: - # `cond` is a deterministic dynamic-range knob, not an exact condition number. - if cond: - n = a.shape[-1] - scales = torch.logspace(0.0, -float(cond), n, device=a.device, dtype=torch.float32) - return a * scales - return a.contiguous() - - -def _band_mask(n: int, bandwidth: int, device: torch.device) -> torch.Tensor: - idx = torch.arange(n, device=device) - return (idx[:, None] - idx[None, :]).abs() <= bandwidth - - -# Per-matrix conditioning profiles drawn for the "mixed" case. "dense" is the -# well-conditioned majority; the rest are the ill-conditioned stress structures. -_MIXED_PROFILES = ("dense", "rankdef", "nearrank", "clustered", "band", "rowscale", "nearcollinear") -# Relative sampling weights (normalized by torch.multinomial); dense ~= 50%. -_MIXED_WEIGHTS = (6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) - - -def _apply_case(a: torch.Tensor, case: str, cond: int, gen: torch.Generator) -> torch.Tensor: - # Apply one conditioning profile to an already-drawn base batch `a` of shape - # (m, n, n), drawing any case-specific extra randomness from `gen`. Factored - # out of generate_input so the homogeneous cases and the per-matrix "mixed" - # case share a single implementation. The draw order (base first in the - # caller, then the case extras here) matches the original code, so every - # homogeneous case produces bit-for-bit identical data to before. - m, n = a.shape[0], a.shape[-1] - device = a.device - if case == "dense": - a = _apply_column_scaling(a, cond) - elif case == "upper": - diag_boost = torch.linspace(1.0, 0.25, n, device=device, dtype=torch.float32) - a = torch.triu(a) - a.diagonal(dim1=-2, dim2=-1).add_(diag_boost) - a = _apply_column_scaling(a, cond) - elif case == "diagonal": - diag = torch.randn((m, n), device=device, dtype=torch.float32, generator=gen) - diag = diag.sign().clamp(min=0.0).mul(2.0).sub(1.0) * torch.logspace( - 0.0, -float(max(cond, 2)), n, device=device, dtype=torch.float32 - ) - a = torch.diag_embed(diag) - elif case == "rankdef": - rank = max(1, (3 * n) // 4) - a[:, :, rank:] = 0.0 - a = _apply_column_scaling(a, cond) - elif case == "nearrank": - rank = max(1, (3 * n) // 4) - tail = n - rank - if tail > 0: - noise = torch.randn( - (m, n, tail), device=device, dtype=torch.float32, generator=gen - ) - a[:, :, rank:] = a[:, :, :tail] + 1.0e-5 * noise - a = _apply_column_scaling(a, cond) - elif case == "clustered": - scales = torch.ones((n,), device=device, dtype=torch.float32) - scales[n // 2 :] = 4.0 * torch.finfo(torch.float32).eps - if n >= 8: - lo = max(0, n // 2 - 2) - hi = min(n, n // 2 + 2) - scales[lo:hi] = torch.sqrt(torch.tensor(torch.finfo(torch.float32).eps, device=device)) - a = a * scales - elif case == "band": - bandwidth = max(2, min(32, n // 32)) - a = a * _band_mask(n, bandwidth, device) - diag_boost = torch.linspace(1.0, 0.5, n, device=device, dtype=torch.float32) - a.diagonal(dim1=-2, dim2=-1).add_(diag_boost) - a = _apply_column_scaling(a, cond) - elif case == "nearcollinear": - base = torch.randn((m, n, 1), device=device, dtype=torch.float32, generator=gen) - noise = torch.randn((m, n, n), device=device, dtype=torch.float32, generator=gen) - a = base.expand(m, n, n) + 1.0e-4 * noise - a = _apply_column_scaling(a, cond) - elif case == "rowscale": - row_cond = max(cond, 4) - scales = torch.logspace(0.0, -float(row_cond), n, device=device, dtype=torch.float32) - a = scales.reshape(1, n, 1) * a - else: - raise ValueError(f"unknown QR test case: {case}") - return a - - -def _generate_mixed(a: torch.Tensor, cond: int, gen: torch.Generator) -> torch.Tensor: - # Heterogeneous batch: assign each matrix an independent conditioning profile - # at a RANDOM position in the batch (seeded, so still deterministic), so - # well- and ill-conditioned matrices are interleaved rather than uniform - # across the batch. This matches the real optimizer-statistics regime (the - # per-layer / per-block factors have wildly different conditioning) and it - # removes the loophole where a kernel samples a few matrices, concludes the - # whole batch is well-conditioned, and routes it all to a fast path that is - # only numerically valid for well-conditioned inputs. With a mix present, - # passing the correctness gate requires handling each matrix on its merits. - m = a.shape[0] - device = a.device - weights = torch.tensor(_MIXED_WEIGHTS, dtype=torch.float32, device=device) - labels = torch.multinomial(weights, m, replacement=True, generator=gen) - # Guarantee both a well-conditioned and an ill-conditioned matrix are present. - # (Only relevant for tiny batches; large batches get both with high prob.) - if m >= 2: - is_dense = labels == 0 - if not bool(is_dense.any()): - labels[int(torch.randint(0, m, (1,), device=device, generator=gen))] = 0 - elif bool(is_dense.all()): - pos = int(torch.randint(0, m, (1,), device=device, generator=gen)) - labels[pos] = int(torch.randint(1, len(_MIXED_PROFILES), (1,), device=device, generator=gen)) - # Process profiles in fixed order over the present labels so the RNG draws - # inside _apply_case are deterministic for a given seed. - for k, prof in enumerate(_MIXED_PROFILES): - mask = labels == k - if bool(mask.any()): - a[mask] = _apply_case(a[mask], prof, cond, gen) - return a - - -def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t: - assert batch > 0, "batch must be positive" - assert n > 0, "n must be positive" - assert cond >= 0, "cond must be non-negative" - - device = "cuda" if torch.cuda.is_available() else "cpu" - gen = torch.Generator(device=device) - gen.manual_seed(seed) - - case = case.lower() - a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen) - - if case == "mixed": - a = _generate_mixed(a, cond, gen) - else: - a = _apply_case(a, case, cond, gen) - - return a.contiguous() - - -def ref_kernel(data: input_t) -> output_t: - # Starter/reference path: correctness first; submissions compete on speed. - return torch.geqrf(data) - - -def _property_rtol(n: int, factor: float) -> float: - eps = torch.finfo(torch.float32).eps - return factor * max(n, 1) * eps - - -def _scaled_residual( - residual: torch.Tensor, - scale: torch.Tensor, - n: int, -) -> torch.Tensor: - eps = torch.finfo(torch.float32).eps - return residual / (eps * max(n, 1) * scale.clamp_min(1e-30)) - - -def _matrix_l1_norm(value: torch.Tensor) -> torch.Tensor: - return torch.linalg.matrix_norm(value.double(), ord=1, dim=(-2, -1)) - - -def _check_tensor(name: str, value: torch.Tensor, shape: tuple[int, ...], device: torch.device) -> str | None: - if not isinstance(value, torch.Tensor): - return f"{name} must be a torch.Tensor" - if value.shape != shape: - return f"{name} shape must be {shape}, got {tuple(value.shape)}" - if value.dtype != torch.float32: - return f"{name} dtype must be torch.float32, got {value.dtype}" - if value.device != device: - return f"{name} must be on {device}, got {value.device}" - if not torch.isfinite(value).all().item(): - return f"{name} contains NaN or Inf" - return None - - -def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: - a = data - batch, n, _ = a.shape - factor_rtol = _property_rtol(n, _FACTOR_RTOL_FACTOR) - orth_rtol = _property_rtol(n, _ORTH_RTOL_FACTOR) - - if not isinstance(output, tuple) or len(output) != 2: - return False, "output must be a tuple `(H, tau)`" - - h, tau = output - error = _check_tensor("H", h, (batch, n, n), a.device) - if error is not None: - return False, error - error = _check_tensor("tau", tau, (batch, n), a.device) - if error is not None: - return False, error - - q = torch.linalg.householder_product(h, tau) - r = torch.triu(h) - if not torch.isfinite(q).all().item(): - return False, "Q materialized from `(H, tau)` contains NaN or Inf" - if not torch.isfinite(r).all().item(): - return False, "R extracted from `triu(H)` contains NaN or Inf" - - a_check = a.double() - q_check = q.double() - r_check = r.double() - projected = q_check.transpose(-1, -2) @ a_check - if not torch.isfinite(projected).all().item(): - return False, "Q.T @ A contains NaN or Inf" - - factor_residual = _matrix_l1_norm(r_check - projected) - factor_scale = _matrix_l1_norm(a_check) - factor_allowed = factor_rtol * factor_scale - factor_scaled = _scaled_residual(factor_residual, factor_scale, n) - if not torch.isfinite(factor_scaled).all().item(): - return False, "R - Q.T @ A residual produced NaN or Inf" - factor_failed = factor_residual > factor_allowed - if bool(factor_failed.any().item()): - worst = int(factor_scaled.argmax().item()) - return False, ( - "R - Q.T @ A is too large: " - f"matrix={worst}, residual={factor_residual[worst].item():.3g}, " - f"allowed={factor_allowed[worst].item():.3g}, " - f"scaled={factor_scaled[worst].item():.3g}" - ) - - eye = torch.eye(n, device=a.device, dtype=torch.float64).expand(batch, n, n) - qtq = q_check.transpose(-1, -2) @ q_check - if not torch.isfinite(qtq).all().item(): - return False, "Q.T @ Q contains NaN or Inf" - orth_residual = _matrix_l1_norm(qtq - eye).amax() - orth_scale = _matrix_l1_norm(eye).amax() - orth_allowed = orth_rtol * orth_scale - orth_scaled = _scaled_residual(orth_residual, orth_scale, n) - if not torch.isfinite(orth_scaled).all().item(): - return False, "Q.T @ Q residual produced NaN or Inf" - if orth_residual.item() > orth_allowed.item(): - return False, ( - "Q is not orthogonal enough: " - f"residual={orth_residual.item():.3g}, allowed={orth_allowed.item():.3g}, " - f"scaled={orth_scaled.item():.3g}" - ) - - lower = torch.tril(projected, diagonal=-1) - tri_residual = _matrix_l1_norm(lower).amax() - tri_scale = _matrix_l1_norm(a_check).amax() - tri_scaled = _scaled_residual(tri_residual, tri_scale, n) - - recon = q_check @ r_check - if not torch.isfinite(recon).all().item(): - return False, "Q @ R contains NaN or Inf" - recon_residual = _matrix_l1_norm(recon - a_check).amax() - recon_scale = _matrix_l1_norm(a_check).amax() - recon_scaled = _scaled_residual(recon_residual, recon_scale, n) - - return True, ( - f"factor_rtol={factor_rtol:.3g}; " - f"orth_rtol={orth_rtol:.3g}; " - f"scaled_factor_residual={factor_scaled.amax().item():.3g}; " - f"scaled_reconstruction_residual={recon_scaled.item():.3g}; " - f"scaled_triangular_residual={tri_scaled.item():.3g}; " - f"scaled_orthogonality_residual={orth_scaled.item():.3g}; " - f"batch={batch}; n={n}" - ) diff --git a/problems/linalg/qr_v2/submission.py b/problems/linalg/qr_v2/submission.py deleted file mode 100644 index ac92e0ac..00000000 --- a/problems/linalg/qr_v2/submission.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - return torch.geqrf(data) diff --git a/problems/linalg/qr_v2/task.py b/problems/linalg/qr_v2/task.py deleted file mode 100644 index e0547dcc..00000000 --- a/problems/linalg/qr_v2/task.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -from typing import NotRequired, TypeVar, TypedDict - -input_t = TypeVar("input_t", bound=torch.Tensor) -output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor]) - - -class TestSpec(TypedDict): - batch: int - n: int - cond: int - seed: int - case: NotRequired[str] diff --git a/problems/linalg/qr_v2/task.yml b/problems/linalg/qr_v2/task.yml deleted file mode 100644 index 0da22d88..00000000 --- a/problems/linalg/qr_v2/task.yml +++ /dev/null @@ -1,121 +0,0 @@ -# name: qr_v2 - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../../pmpp_v2/utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "eval.py"} - -lang: "py" - -description: | - Implement batched square compact-Householder QR factorization. - - Input is `A`, a `batch x n x n` CUDA tensor in `torch.float32`. - - Return `(H, tau)` in the same compact Householder convention as - `torch.geqrf(A)`. `H` is a `batch x n x n` FP32 tensor containing `R` in its - upper triangle and Householder vectors below the diagonal. `tau` is a - `batch x n` FP32 tensor containing reflector coefficients. The checker - materializes `Q = torch.linalg.householder_product(H, tau)`, uses - `R_factor = triu(H)`, and validates the LAPACK-style QR factorization - residual `R_factor - Q.T @ A` and orthogonality of `Q`. Since `R_factor` - is extracted with `triu`, triangularity is part of the factorization check: - if `Q.T @ A` has meaningful lower-triangular leakage, then it cannot match - `R_factor`. The checker reports that lower-triangular leakage and the - reconstruction residual as diagnostics. - - This shape set targets optimizer-style matrix statistics where gradients are - viewed as `[for_each..., basis_dim, contracted_dim]`, statistics are formed as - `G @ G.T`, and QR is run on square `basis_dim x basis_dim` matrices. Batched - `512 x 512` is especially important, while `1024`, `2048`, and `4096` cover - larger square factors. - - Test and benchmark specs include a `cond` field. In this task `cond` is a - deterministic input-scaling knob, not an exact requested condition number: - dense cases multiply columns by `logspace(0, -cond, n)`, so larger `cond` - creates a wider dynamic range across columns. Some stress cases use their own - structure, such as rank-deficient, near-rank-deficient, banded, row-scaled, - near-collinear, upper-triangular, or clustered-scale inputs. - - The `mixed` case builds a heterogeneous batch: each matrix is independently - assigned a conditioning profile (a well-conditioned dense majority interleaved - with the ill-conditioned stress structures above) at a random position in the - batch. This mirrors the real optimizer-statistics regime, where the per-layer - or per-block factors batched into one call have widely varying conditioning, - rather than all sharing one structure. The benchmark set (not just the test - set) now includes both `mixed` batches and fully ill-conditioned homogeneous - batches, so conditioning robustness is ranked, not only gated: an - implementation cannot inspect a few matrices, decide the whole batch is - well-conditioned, and route it to a path that is only valid for well-conditioned - inputs, and the runtime cost of the accurate path on hard inputs is part of the - score. Each matrix must be factored correctly on its own merits. - - Correctness is a hard gate against the original FP32 input and the FP32 - `torch.geqrf` compact-factor contract. Low-bit FP16, FP8, or NVFP4 work is - allowed only as an internal implementation strategy: returned factors must - still be FP32 and must satisfy the same QR invariants as an FP32 - factorization. Residuals are measured in FP64 to reduce checker noise, but - the target tolerance is still FP32 accuracy. The numerical property tolerance - is purely relative, with no QR `atol`. - The hard gates are the LAPACK-style factor residual, which uses - `rtol = 20 * n * eps32`, and orthogonality, which uses - `rtol = 100 * n * eps32`, each applied to the corresponding matrix L1 norm. - Triangularity is reported as lower-triangular leakage in `Q.T @ A` and is - already implied by the factor residual against `triu(H)`. - - Among passing submissions, ranking is by runtime using the geometric mean of - benchmark cases. We will also celebrate notable submissions beyond the main - leaderboard: the fastest, the most elegant, and the strangest working kernels. - -config: - main: "eval.py" - -templates: - Python: "submission.py" - -test_timeout: 240 -benchmark_timeout: 480 -ranked_timeout: 900 -ranking_by: "geom" -gpus: - - B200 - -tests: - - {"batch": 20, "n": 32, "cond": 1, "seed": 53124} - - {"batch": 40, "n": 176, "cond": 1, "seed": 3321} - - {"batch": 40, "n": 352, "cond": 1, "seed": 1200} - - {"batch": 16, "n": 512, "cond": 2, "seed": 32523} - - {"batch": 4, "n": 1024, "cond": 2, "seed": 4327} - - {"batch": 1, "n": 4096, "cond": 1, "seed": 75342} - - {"batch": 16, "n": 512, "cond": 4, "seed": 32524, "case": "dense"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32525, "case": "rankdef"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32526, "case": "clustered"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32527, "case": "band"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32528, "case": "rowscale"} - - {"batch": 16, "n": 512, "cond": 0, "seed": 32529, "case": "nearcollinear"} - - {"batch": 4, "n": 1024, "cond": 4, "seed": 4328, "case": "dense"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4329, "case": "rankdef"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4330, "case": "nearrank"} - - {"batch": 4, "n": 1024, "cond": 0, "seed": 4331, "case": "clustered"} - - {"batch": 2, "n": 2048, "cond": 2, "seed": 224466, "case": "dense"} - - {"batch": 2, "n": 2048, "cond": 0, "seed": 224467, "case": "rankdef"} - - {"batch": 1, "n": 4096, "cond": 0, "seed": 75343, "case": "upper"} - - {"batch": 16, "n": 512, "cond": 2, "seed": 32530, "case": "mixed"} - - {"batch": 4, "n": 1024, "cond": 2, "seed": 4332, "case": "mixed"} - - {"batch": 2, "n": 2048, "cond": 2, "seed": 224468, "case": "mixed"} - -benchmarks: - - {"batch": 20, "n": 32, "cond": 1, "seed": 43214} - - {"batch": 40, "n": 176, "cond": 1, "seed": 423011} - - {"batch": 40, "n": 352, "cond": 1, "seed": 123456} - - {"batch": 640, "n": 512, "cond": 2, "seed": 1029} - - {"batch": 60, "n": 1024, "cond": 2, "seed": 75342} - - {"batch": 8, "n": 2048, "cond": 1, "seed": 224466} - - {"batch": 2, "n": 4096, "cond": 1, "seed": 32412} - - {"batch": 640, "n": 512, "cond": 2, "seed": 770001, "case": "mixed"} - - {"batch": 60, "n": 1024, "cond": 2, "seed": 770002, "case": "mixed"} - - {"batch": 640, "n": 512, "cond": 0, "seed": 770003, "case": "rankdef"} - - {"batch": 640, "n": 512, "cond": 0, "seed": 770004, "case": "clustered"} - - {"batch": 60, "n": 1024, "cond": 0, "seed": 770005, "case": "nearrank"} diff --git a/problems/nvidia.yaml b/problems/nvidia.yaml index f38e8bda..64ca9f71 100644 --- a/problems/nvidia.yaml +++ b/problems/nvidia.yaml @@ -7,27 +7,6 @@ description: "NVIDIA Blackwell NVFP4 Kernel Hackathon" problems: - directory: nvidia/nvfp4_gemv name: nvfp4_gemv - deadline: "2025-11-29 6:59" + deadline: "2025-11-28" gpus: - NVIDIA - - directory: nvidia/nvfp4_gemm - name: nvfp4_gemm - deadline: "2025-12-21 7:59" - gpus: - - NVIDIA - - directory: nvidia/nvfp4_dual_gemm - name: nvfp4_dual_gemm - deadline: "2026-01-20 7:59" - gpus: - - NVIDIA - - directory: nvidia/modal_nvfp4_dual_gemm - name: modal_nvfp4_dual_gemm - deadline: "2026-01-20 7:59" - gpus: - - B200 - - directory: nvidia/nvfp4_group_gemm - name: nvfp4_group_gemm - deadline: "2026-02-21 7:30" - gpus: - - B200 - - NVIDIA diff --git a/problems/nvidia/eval.py b/problems/nvidia/eval.py index 252f35e4..ed370157 100644 --- a/problems/nvidia/eval.py +++ b/problems/nvidia/eval.py @@ -254,8 +254,8 @@ def _run_single_benchmark( del output durations.append(duration) - total_bm_duration = time.perf_counter_ns() - bm_start_time - if i > 1 and total_bm_duration > 1e8: # at least 2 runs, and at least 100 ms total time + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time stats = calculate_stats(durations) # stop if either # a) relative error dips below 0.1% @@ -453,11 +453,11 @@ def main(): return run_benchmarking(logger, pool, tests) if mode == "leaderboard": - run_single_benchmark(pool, tests[0], False, 1000, 5e8) + run_single_benchmark(pool, tests[0], False, 200, 1e7) logger.log("benchmark-count", len(tests)) passed = True for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 1000, 30e9) + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) logger.log(f"benchmark.{i}.spec", tests[i].spec) if isinstance(result, Stats): for field in dataclasses.fields(Stats): diff --git a/problems/nvidia/eval_better_bench.py b/problems/nvidia/eval_better_bench.py deleted file mode 100644 index 007781ed..00000000 --- a/problems/nvidia/eval_better_bench.py +++ /dev/null @@ -1,510 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math - -# Disable CuTe DSL file caching for more stable benchmarking -os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" - - -def _init_worker(): - """Initialize worker process with correct env vars.""" - os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" - - -from pathlib import Path -from typing import Any, Optional - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError -from torch.cuda.nvtx import range as nvtx_range - -from utils import set_seed, clear_l2_cache_large as clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - -NUM_ITERATIONS_PER_BENCHMARK = 50 - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except OpError as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - data_list = [] - # generate input data once - - for i in range(NUM_ITERATIONS_PER_BENCHMARK): - if "seed" in test.args: - test.args["seed"] += 42 - data = generate_input(**test.args) - data_list.append(data) - - check_copy = _clone_data(data_list) - - # first, one obligatory correctness check - outputs = [] - try: - for data in data_list: - output = custom_kernel(_clone_data(data)) - outputs.append(output) - except OpError as E: - return f"Encountered {E}" - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - torch.cuda.synchronize() - - outputs = [] - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for data in data_list: - output = custom_kernel(data) - outputs.append(output) - end_event.record() - torch.cuda.synchronize() - duration = ( - start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK - ) * 1e6 # Convert ms to ns - - if recheck: - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) - if not good: - return message - - durations.append(duration) - - total_bm_duration = time.perf_counter_ns() - bm_start_time - if ( - i > 1 and total_bm_duration > 1e8 - ): # at least 2 runs, and at least 100 ms total time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - - run_single_benchmark(pool, tests[0], False, 200, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 200, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_profile_torch(test: TestCase) -> str: - """ - Profiles a single benchmark using the torch profiler. - """ - from submission import custom_kernel - from torch.profiler import profile, ProfilerActivity - - with nvtx_range("generate input"): - data = generate_input(**test.args) - torch.cuda.synchronize() - - cloned = _clone_data(data) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - with nvtx_range("custom_kernel"): - submission_output = custom_kernel(cloned) - torch.cuda.synchronize() - - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def _run_single_profile_ncu(test: TestCase) -> str: - """ - Profiles a single benchmark using ncu. Note: this does not - invoke NCU; instead, it is expected that eval is launched - under NCU, and this function will rurnthe kernel excactly - once in the 'custom_kernel' nvtx range. - """ - from submission import custom_kernel - - with nvtx_range("generate input"): - data = generate_input(**test.args) - torch.cuda.synchronize() - - cloned = _clone_data(data) - with nvtx_range("custom_kernel"): - submission_output = custom_kernel(cloned) - torch.cuda.synchronize() - - return "" - - -def _combine_traces(traces: list["EventList"]) -> "EventList": - """ - Combine multiple event traces obtained from multiple (distributed) torch.profiler - activities. This function simply aggregates the data as like `prof.key_averages()`, - except over multiple traces. Most of this function is reimplemented - from `torch.autograd.profiler_util.EventList.key_averages()`. - """ - from torch.autograd.profiler_util import FunctionEventAvg, EventList - from collections import defaultdict - - def get_key(event) -> tuple[str, ...]: - return ( - str(event.key), - str(event.node_id), - str(event.device_type), - str(event.is_legacy), - str(event.is_user_annotation), - ) - - stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) - - for events in traces: - for event in events: - stats[get_key(event)].add(event) - - avg_list = EventList(stats.values()) - for event in avg_list: - event.stack = [] - event.input_shapes = "" - event.overload_name = "" - - return avg_list - - -def run_single_profile(test: TestCase, pool: multiprocessing.Pool) -> str: - """ - Runs a single profiling activity in another process. - """ - if bool(os.getenv("POPCORN_NCU", "0")): - return pool.apply(_run_single_profile_ncu, (test,)) - else: - return pool.apply(_run_single_profile_torch, (test,)) - - -def run_profiling( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test, pool) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1, initializer=_init_worker) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # Warmup all test shapes to ensure consistent benchmarking - for test in tests: - run_single_benchmark(pool, test, False, 1000, 5e8) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 1000, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, pool, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py deleted file mode 100644 index 09b52790..00000000 --- a/problems/nvidia/eval_better_bench_grouped_gemm.py +++ /dev/null @@ -1,529 +0,0 @@ -import base64 -import dataclasses -import multiprocessing -import re -import time -import os -import sys -import math - -# Disable CuTe DSL file caching for more stable benchmarking -os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" - - -def _init_worker(): - """Initialize worker process with correct env vars.""" - os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" - - -from pathlib import Path -from typing import Any, Optional - -import torch.cuda -from cutlass.cute.nvgpu.common import OpError -from cutlass._mlir.ir import MLIRError - -from torch.cuda.nvtx import range as nvtx_range - -from utils import set_seed, clear_l2_cache_large as clear_l2_cache - -try: - from task import TestSpec -except ImportError: - TestSpec = dict - -from reference import check_implementation, generate_input - -NUM_ITERATIONS_PER_BENCHMARK = 15 -UNSERIALIZABLE_EXCEPTIONS = (OpError, MLIRError) - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -def _combine(a: int, b: int) -> int: - # combine two integers into one: - # we need this to generate a secret seed based on the test-level seed and - # the global secret seed. - # the test-level seeds are public knowledge, and typically relatively small numbers, - # so we need to make sure they don't provide any useful info for the full seed. - # This Cantor construction ensures that if the secret seed is a large number, - # then so is the overall seed. - return int(a + (a + b) * (a + b + 1) // 2) - - -def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as E: - print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) - exit(113) - - tests = [] - lines = content.splitlines() - # Match key: value pairs where value can be: - # - a list like [1, 2, 3] (needed for group gemm which has per-group dimensions) - # - a tuple like (1, 2, 3) - # - an integer - # - an alphabetic string - match = r"\s*([a-zA-Z_]+)\s*:\s*(\[[^\]]*\]|\([^)]*\)|[a-zA-Z_]+|[+-]?[0-9]+)\s*" - for line in lines: - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - exit(113) - key = matched[1] - val = matched[2] - try: - val = int(val) - except ValueError: - # Try parsing as tuple/list (e.g., [1, 2, 3] for group gemm dimensions) - if (val.startswith("(") and val.endswith(")")) or ( - val.startswith("[") and val.endswith("]") - ): - try: - inner = val[1:-1].strip() - if inner: - val = tuple(int(x.strip()) for x in inner.split(",")) - else: - val = tuple() - except ValueError: - pass - - case[key] = val - tests.append(TestCase(spec=line, args=case)) - - if seed is not None: - for test in tests: - if "seed" in test.args: - test.args["seed"] = _combine(test.args["seed"], seed) - - return tests - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - - -def calculate_stats(durations: list[int]): - """ - Calculate statistical data from a list of durations. - - @param durations: A list of durations in nanoseconds. - @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. - """ - runs = len(durations) - total = sum(durations) - best = min(durations) - worst = max(durations) - - avg = total / runs - variance = sum(map(lambda x: (x - avg) ** 2, durations)) - std = math.sqrt(variance / (runs - 1)) - err = std / math.sqrt(runs) - - return Stats( - runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) - ) - - -def _clone_data(data): - """ - Recursively goes through data and clones all tensors. - """ - if isinstance(data, tuple): - return tuple(_clone_data(x) for x in data) - elif isinstance(data, list): - return [_clone_data(x) for x in data] - elif isinstance(data, dict): - return {k: _clone_data(v) for k, v in data.items()} - elif isinstance(data, torch.Tensor): - return data.clone() - else: - return data - - -def _run_single_test(test: TestCase): - """ - Runs a single test case. Do not call directly - """ - from submission import custom_kernel - - data = generate_input(**test.args) - torch.cuda.synchronize() - try: - submission_output = custom_kernel(_clone_data(data)) - - except UNSERIALIZABLE_EXCEPTIONS as E: - print(f"Encountered {E}", file=sys.stderr) - return False, str(E) - torch.cuda.synchronize() - return check_implementation(data, submission_output) - - -def run_single_test(pool: multiprocessing.Pool, test: TestCase): - """ - Runs a single test in another process. - """ - return pool.apply(_run_single_test, (test,)) - - -def run_testing( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes the actual test case code and checks for correctness. - - @param logger: A PopcornOutput object used for logging test results. - @param tests: A list of TestCase objects representing the test cases to be executed. - @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. - """ - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"test.{idx}.spec", test.spec) - good, message = run_single_test(pool, test) - if not good: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", message) - passed = False - else: - logger.log(f"test.{idx}.status", "pass") - if message: - logger.log(f"test.{idx}.message", message) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_benchmark( - test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float -) -> Stats | Any: - """ - Runs one benchmark. Do not call directly. - """ - from submission import custom_kernel - - durations = [] - data_list = [] - # generate input data once - - for i in range(NUM_ITERATIONS_PER_BENCHMARK): - if "seed" in test.args: - test.args["seed"] += 42 - data = generate_input(**test.args) - data_list.append(data) - - check_copy = _clone_data(data_list) - - # first, one obligatory correctness check - outputs = [] - try: - for data in data_list: - output = custom_kernel(_clone_data(data)) - outputs.append(output) - except UNSERIALIZABLE_EXCEPTIONS as E: - return f"Encountered {E}" - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) - if not good: - return message - - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. - - bm_start_time = time.perf_counter_ns() - for i in range(max_repeats): - torch.cuda.synchronize() - - outputs = [] - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for data in data_list: - output = custom_kernel(data) - outputs.append(output) - end_event.record() - torch.cuda.synchronize() - duration = ( - start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK - ) * 1e6 # Convert ms to ns - - if recheck: - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) - if not good: - return message - - durations.append(duration) - - total_bm_duration = time.perf_counter_ns() - bm_start_time - if ( - i > 1 and total_bm_duration > 1e8 - ): # at least 2 runs, and at least 100 ms total time - stats = calculate_stats(durations) - # stop if either - # a) relative error dips below 0.1% - # b) we exceed the total time limit for benchmarking the kernel - # c) we exceed 2 minutes of total wallclock time. - if ( - stats.err / stats.mean < 0.001 - or stats.mean * stats.runs > max_time_ns - or total_bm_duration > 120e9 - ): - break - - return calculate_stats(durations) - - -def run_single_benchmark( - pool: multiprocessing.Pool, - test: TestCase, - recheck: bool, - max_repeats: int, - max_time_ns: float, -): - """ - For a particular test case, check correctness (if applicable) and grab runtime results. - - @param pool: Process on which the benchmark will be launched. - @param test: TestCase object. - @param recheck: Flag for whether to explicitly check functional correctness. - @param max_repeats: Number of trials to repeat. - @param max_time_ns: Timeout time in nanoseconds. - @return: A Stats object for this particular benchmark case or an error if the test fails. - """ - return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) - - -def run_benchmarking( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - """ - Executes benchmarking code for a CUDA Kernel and logs runtimes. - - @param logger: A PopcornOutput object used for logging benchmark results. - @param pool: Process on which the benchmarks will be launched. - @param tests: A list of TestCase objects representing the test cases to be benchmarked. - @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. - """ - - run_single_benchmark(pool, tests[0], False, 100, 10e7) - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 100, 10e9) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) - else: - passed = False - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", result) - - if passed: - logger.log("check", "pass") - return 0 - else: - logger.log("check", "fail") - return 112 - - -def _run_single_profile_torch(test: TestCase) -> str: - """ - Profiles a single benchmark using the torch profiler. - """ - from submission import custom_kernel - from torch.profiler import profile, ProfilerActivity - - with nvtx_range("generate input"): - data = generate_input(**test.args) - torch.cuda.synchronize() - - cloned = _clone_data(data) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - with nvtx_range("custom_kernel"): - submission_output = custom_kernel(cloned) - torch.cuda.synchronize() - - return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) - - -def _run_single_profile_ncu(test: TestCase) -> str: - """ - Profiles a single benchmark using ncu. Note: this does not - invoke NCU; instead, it is expected that eval is launched - under NCU, and this function will rurnthe kernel excactly - once in the 'custom_kernel' nvtx range. - """ - from submission import custom_kernel - - with nvtx_range("generate input"): - data = generate_input(**test.args) - torch.cuda.synchronize() - - cloned = _clone_data(data) - with nvtx_range("custom_kernel"): - submission_output = custom_kernel(cloned) - torch.cuda.synchronize() - - return "" - - -def _combine_traces(traces: list["EventList"]) -> "EventList": - """ - Combine multiple event traces obtained from multiple (distributed) torch.profiler - activities. This function simply aggregates the data as like `prof.key_averages()`, - except over multiple traces. Most of this function is reimplemented - from `torch.autograd.profiler_util.EventList.key_averages()`. - """ - from torch.autograd.profiler_util import FunctionEventAvg, EventList - from collections import defaultdict - - def get_key(event) -> tuple[str, ...]: - return ( - str(event.key), - str(event.node_id), - str(event.device_type), - str(event.is_legacy), - str(event.is_user_annotation), - ) - - stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) - - for events in traces: - for event in events: - stats[get_key(event)].add(event) - - avg_list = EventList(stats.values()) - for event in avg_list: - event.stack = [] - event.input_shapes = "" - event.overload_name = "" - - return avg_list - - -def run_single_profile(test: TestCase, pool: multiprocessing.Pool) -> str: - """ - Runs a single profiling activity in another process. - """ - if bool(os.getenv("POPCORN_NCU", "0")): - return pool.apply(_run_single_profile_ncu, (test,)) - else: - return pool.apply(_run_single_profile_torch, (test,)) - - -def run_profiling( - logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] -): - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - logger.log(f"benchmark.{idx}.spec", test.spec) - report = run_single_profile(test, pool) - logger.log( - f"benchmark.{idx}.report", - base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), - ) - logger.log("check", "pass") - return 0 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - mode = sys.argv[1] - seed = os.getenv("POPCORN_SEED") - os.unsetenv("POPCORN_SEED") - seed = int(seed) if seed else None - set_seed(seed or 42) - - tests = get_test_cases(sys.argv[2], seed) - - with PopcornOutput(int(fd)) as logger: - import multiprocessing - - mp_context = multiprocessing.get_context("spawn") - with mp_context.Pool(1, initializer=_init_worker) as pool: - if mode == "test": - return run_testing(logger, pool, tests) - if mode == "benchmark": - return run_benchmarking(logger, pool, tests) - - if mode == "leaderboard": - # Warmup all test shapes to ensure consistent benchmarking - for test in tests: - run_single_benchmark(pool, test, False, 50, 5e8) - logger.log("benchmark-count", len(tests)) - passed = True - for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) - logger.log(f"benchmark.{i}.spec", tests[i].spec) - if isinstance(result, Stats): - for field in dataclasses.fields(Stats): - logger.log( - f"benchmark.{i}.{field.name}", - getattr(result, field.name), - ) - else: - passed = False - logger.log(f"benchmark.{i}.status", "fail") - logger.log( - f"benchmark.{i}.error", str(result) - ) # TODO: Make sure result implements __str__? - break - - logger.log("check", "pass" if passed else "fail") - elif mode == "profile": - run_profiling(logger, pool, tests) - else: - # TODO: Implement script mode - return 2 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/reference.py b/problems/nvidia/modal_nvfp4_dual_gemm/reference.py deleted file mode 100644 index 95c6aacd..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/reference.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch -from task import input_t, output_t -from utils import make_match_reference - -# Scaling factor vector size -sf_vec_size = 16 - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - -# Helper function to convert scale factor tensor to blocked format -def to_blocked(input_matrix): - rows, cols = input_matrix.shape - - # Please ensure rows and cols are multiples of 128 and 4 respectively - n_row_blocks = ceil_div(rows, 128) - n_col_blocks = ceil_div(cols, 4) - - padded = input_matrix - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - - return rearranged.flatten() - - -def ref_kernel( - data: input_t, -) -> output_t: - """ - PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation, - C = silu(A @ B1) * (A @ B2). - """ - a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data - - # Get dimensions from MxNxL layout - m, n, l = c_ref.shape - - # Call torch._scaled_mm to compute the GEMV result - ref1 = torch.empty( - (l, m, n), - dtype=torch.float32, - device="cuda", - ).permute(1, 2, 0) - ref2 = torch.empty( - (l, m, n), - dtype=torch.float32, - device="cuda", - ).permute(1, 2, 0) - for l_idx in range(l): - # Convert the scale factor tensor to blocked format - scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) - scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx]) - scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx]) - # (m, k) @ (n, k).T -> (m, n) - res1 = torch._scaled_mm( - a_ref[:, :, l_idx], - b1_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b1.cuda(), - bias=None, - out_dtype=torch.float32, - ) - ref1[:, :, l_idx] = res1 - - res2 = torch._scaled_mm( - a_ref[:, :, l_idx], - b2_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b2.cuda(), - bias=None, - out_dtype=torch.float32, - ) - ref2[:, :, l_idx] = res2 - # Do silu on the first GEMM result and multiply with the second GEMM result - c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16) - return c_ref - - -def generate_input( - m: int, - n: int, - k: int, - l: int, - seed: int, -): - """ - Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation, - C = silu(A @ B1) * (A @ B2). - - Args: - m: Number of rows in matrix A - n: Number of columns in matrix B1 and B2 - k: Number of columns in A and rows of B1 and B2 - l: Batch size - seed: Random seed for reproducibility - - Returns: - Tuple of (a, b, scale_a, scale_b, c) where: - a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b1: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b2: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - c: [m, n, l] - Output matrix in torch.float16 data type - """ - torch.manual_seed(seed) - - def create_fp4_tensors(l, mn, k): - # generate uint8 tensor, then convert to float4e2m1fn_x2 data type - # generate all bit patterns - ref_i8 = torch.randint(255, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda") - - # for each nibble, only keep the sign bit and 2 LSBs - # the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5] - ref_i8 = ref_i8 & 0b1011_1011 - - return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2) - - # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type - a_ref = create_fp4_tensors(l, m, k) - b1_ref = create_fp4_tensors(l, n, k) - b2_ref = create_fp4_tensors(l, n, k) - a_ref = a_ref.view(torch.float4_e2m1fn_x2) - b1_ref = b1_ref.view(torch.float4_e2m1fn_x2) - b2_ref = b2_ref.view(torch.float4_e2m1fn_x2) - - # Create float16 output tensor - c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( - 1, 2, 0 - ) - - # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. The customized data layout can be found in: - # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout - def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, sf_k, l) on CPU. - ref_shape = (l, mn, sf_k) - ref_permute_order = (1, 2, 0) - # Init with fp32 tensor in [0,1), then convert to float8_e4m3fn - ref_f8_random_fp32 = torch.rand(ref_shape, dtype=torch.float32, device='cuda') - ref_f8_torch_tensor = ref_f8_random_fp32.to(dtype=torch.float8_e4m3fn) - # permute to match ref_permute_order - ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) - - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - - # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout - # Which is needed by the CuTe customized kernel - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.empty(mma_shape, dtype=torch.int8, device='cuda') - reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) - - # GPU-side vectorized reordering (replaces slow CPU nested loops) - # Create index grids for all dimensions - i_idx = torch.arange(mn, device='cuda') - j_idx = torch.arange(sf_k, device='cuda') - b_idx = torch.arange(l, device='cuda') - - # Create meshgrid for all combinations of (i, j, b) - i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') - - # Calculate target indices in vectorized manner - mm = i_grid // (atom_m[0] * atom_m[1]) - mm32 = i_grid % atom_m[0] - mm4 = (i_grid % 128) // atom_m[0] - kk = j_grid // atom_k - kk4 = j_grid % atom_k - - # Perform the reordering with advanced indexing (all on GPU) - reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] - - return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor - - sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) - sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - - return (a_ref, b1_ref, b2_ref, sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"), sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/submission.py b/problems/nvidia/modal_nvfp4_dual_gemm/submission.py deleted file mode 100644 index 739cc5a0..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/submission.py +++ /dev/null @@ -1,957 +0,0 @@ -from torch._higher_order_ops.torchbind import call_torchbind_fake -import cuda.bindings.driver as cuda - -import torch -from task import input_t, output_t - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -import cutlass.utils.blockscaled_layout as blockscaled_utils -from cutlass.cute.runtime import make_ptr - -# Kernel configuration parameters -# Tile sizes for M, N, K dimensions -mma_tiler_mnk= (128, 128, 256) -# Shape of the K dimension for the MMA instruction -mma_inst_shape_k = 64 -# FP4 data type for A and B -ab_dtype = cutlass.Float4E2M1FN -# FP8 data type for scale factors -sf_dtype = cutlass.Float8E4M3FN -# FP16 output type -c_dtype = cutlass.Float16 -# Scale factor block size (16 elements share one scale) -sf_vec_size = 16 -# Number of threads per CUDA thread block -threads_per_cta = 128 -# Stage numbers of shared memory and tmem -num_acc_stage = 1 -num_ab_stage = 1 -# Total number of columns in tmem -num_tmem_alloc_cols = 512 - - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# GPU device kernel -@cute.kernel -def kernel( - tiled_mma: cute.TiledMma, - tma_atom_a: cute.CopyAtom, - mA_mkl: cute.Tensor, - tma_atom_b1: cute.CopyAtom, - mB_nkl1: cute.Tensor, - tma_atom_b2: cute.CopyAtom, - mB_nkl2: cute.Tensor, - tma_atom_sfa: cute.CopyAtom, - mSFA_mkl: cute.Tensor, - tma_atom_sfb1: cute.CopyAtom, - mSFB_nkl1: cute.Tensor, - tma_atom_sfb2: cute.CopyAtom, - mSFB_nkl2: cute.Tensor, - mC_mnl: cute.Tensor, - a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, - sfa_smem_layout_staged: cute.Layout, - sfb_smem_layout_staged: cute.Layout, - num_tma_load_bytes: cutlass.Constexpr[int], - epilogue_op: cutlass.Constexpr = lambda x: x - * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), -): - """ - GPU device kernel performing the batched GEMM computation. - """ - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - tidx = cute.arch.thread_idx() - - # - # Setup cta/thread coordinates - # - # Coords inside cluster - bidx, bidy, bidz = cute.arch.block_idx() - mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) - - # Coords outside cluster - cta_coord = (bidx, bidy, bidz) - mma_tile_coord_mnl = ( - cta_coord[0] // cute.size(tiled_mma.thr_id.shape), - cta_coord[1], - cta_coord[2], - ) - # Coord inside cta - tidx, _, _ = cute.arch.thread_idx() - - # - # Define shared storage for kernel - # - @cute.struct - class SharedStorage: - ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] - acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] - tmem_holding_buf: cutlass.Int32 - - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - # (MMA, MMA_M, MMA_K, STAGE) - sA = smem.allocate_tensor( - element_type=ab_dtype, - layout=a_smem_layout_staged.outer, - byte_alignment=128, - swizzle=a_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB1 = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB2 = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_M, MMA_K, STAGE) - sSFA = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfa_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB1 = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB2 = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - - # - # Initialize mainloop ab_pipeline, acc_pipeline and their states - # - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.ab_mbar_ptr.data_ptr(), - num_stages=num_ab_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=ab_pipeline_consumer_group, - tx_count=num_tma_load_bytes, - ).make_participants() - acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( - barrier_storage=storage.acc_mbar_ptr.data_ptr(), - num_stages=num_acc_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, - threads_per_cta, - ), - ).make_participants() - - # - # Local_tile partition global tensors - # - # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl1 = cute.local_tile( - mB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl2 = cute.local_tile( - mB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - gSFB_nkl1 = cute.local_tile( - mSFB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl2 = cute.local_tile( - mSFB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) - ) - k_tile_cnt = cute.size(gA_mkl, mode=[3]) - - # - # Partition global tensor for TiledMMA_A/B/SFA/SFB/C - # - # (MMA, MMA_M, MMA_K, RestK) - thr_mma = tiled_mma.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB1 = thr_mma.partition_B(gB_nkl1) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB2 = thr_mma.partition_B(gB_nkl2) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB1 = thr_mma.partition_B(gSFB_nkl1) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB2 = thr_mma.partition_B(gSFB_nkl2) - # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) - tCgC = thr_mma.partition_C(gC_mnl) - - # - # Partition global/shared tensor for TMA load A/B/SFA/SFB - # - # TMA Partition_S/D for A - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsA, tAgA = cpasync.tma_partition( - tma_atom_a, - 0, - cute.make_layout(1), - cute.group_modes(sA, 0, 3), - cute.group_modes(tCgA, 0, 3), - ) - # TMA Partition_S/D for B1 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB1, tBgB1 = cpasync.tma_partition( - tma_atom_b1, - 0, - cute.make_layout(1), - cute.group_modes(sB1, 0, 3), - cute.group_modes(tCgB1, 0, 3), - ) - # TMA Partition_S/D for B2 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB2, tBgB2 = cpasync.tma_partition( - tma_atom_b2, - 0, - cute.make_layout(1), - cute.group_modes(sB2, 0, 3), - cute.group_modes(tCgB2, 0, 3), - ) - # TMA Partition_S/D for SFA - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsSFA, tAgSFA = cpasync.tma_partition( - tma_atom_sfa, - 0, - cute.make_layout(1), - cute.group_modes(sSFA, 0, 3), - cute.group_modes(tCgSFA, 0, 3), - ) - tAsSFA = cute.filter_zeros(tAsSFA) - tAgSFA = cute.filter_zeros(tAgSFA) - # TMA Partition_S/D for SFB1 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB1, tBgSFB1 = cpasync.tma_partition( - tma_atom_sfb1, - 0, - cute.make_layout(1), - cute.group_modes(sSFB1, 0, 3), - cute.group_modes(tCgSFB1, 0, 3), - ) - tBsSFB1 = cute.filter_zeros(tBsSFB1) - tBgSFB1 = cute.filter_zeros(tBgSFB1) - # TMA Partition_S/D for SFB2 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB2, tBgSFB2 = cpasync.tma_partition( - tma_atom_sfb2, - 0, - cute.make_layout(1), - cute.group_modes(sSFB2, 0, 3), - cute.group_modes(tCgSFB2, 0, 3), - ) - tBsSFB2 = cute.filter_zeros(tBsSFB2) - tBgSFB2 = cute.filter_zeros(tBgSFB2) - - # - # Partition shared/tensor memory tensor for TiledMMA_A/B/C - # - # (MMA, MMA_M, MMA_K, STAGE) - tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB1 = tiled_mma.make_fragment_B(sB1) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB2 = tiled_mma.make_fragment_B(sB2) - # (MMA, MMA_M, MMA_N) - acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) - # (MMA, MMA_M, MMA_N) - tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) - - # - # Alloc tensor memory buffer - # Make ACC1 and ACC2 tmem tensor - # ACC1 += A @ B1 - # ACC2 += A @ B2 - # - tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=1, - num_threads=threads_per_cta, - ) - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=tmem_alloc_barrier, - ) - tmem.allocate(num_tmem_alloc_cols) - tmem.wait_for_alloc() - acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) - tCtAcc1 = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - acc_tmem_ptr1 = cute.recast_ptr( - acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc1), - dtype=cutlass.Float32, - ) - tCtAcc2 = cute.make_tensor(acc_tmem_ptr1, tCtAcc_fake.layout) - - # - # Make SFA/SFB1/SFB2 tmem tensor - # - # SFA tmem layout: (MMA, MMA_M, MMA_K) - tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), - ) - # Get SFA tmem ptr - sfa_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2), - dtype=sf_dtype, - ) - tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) - - # SFB1, SFB2 tmem layout: (MMA, MMA_N, MMA_K) - tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), - ) - # Get SFB1 tmem ptr - sfb_tmem_ptr1 = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), - dtype=sf_dtype, - ) - tCtSFB1 = cute.make_tensor(sfb_tmem_ptr1, tCtSFB_layout) - # Get SFB2 tmem ptr - sfb_tmem_ptr2 = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA) - + tcgen05.find_tmem_tensor_col_offset(tCtSFB1), - dtype=sf_dtype, - ) - tCtSFB2 = cute.make_tensor(sfb_tmem_ptr2, tCtSFB_layout) - - # - # Partition for S2T copy of SFA/SFB1/SFB2 - # - # Make S2T CopyAtom - copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), - sf_dtype, - ) - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFA_compact = cute.filter_zeros(sSFA) - # (MMA, MMA_MN, MMA_K) - tCtSFA_compact = cute.filter_zeros(tCtSFA) - tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) - thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) - - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact = cute.filter_zeros(sSFB1) - # (MMA, MMA_MN, MMA_K) - tCtSFB1_compact = cute.filter_zeros(tCtSFB1) - tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB1_compact) - thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB1_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB1_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB1_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB1_compact) - - # SFB2 S2T copy and partition - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact = cute.filter_zeros(sSFB2) - # (MMA, MMA_MN, MMA_K) - tCtSFB2_compact = cute.filter_zeros(tCtSFB2) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB2_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB2_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB2_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB2_compact) - - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), RestK) - tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB1 = tBgB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB2 = tBgB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB1 = tBgSFB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB2 = tBgSFB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - - # - # Execute Data copy and Math computation in the k_tile loop - # - if warp_idx == 0: - # Wait for accumulator buffer empty - acc_empty = acc_producer.acquire_and_advance() - # Set ACCUMULATE field to False for the first k_tile iteration - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # Execute k_tile loop - for k_tile in range(k_tile_cnt): - # Wait for AB buffer empty - ab_empty = ab_producer.acquire_and_advance() - - # TMA load A/B1/B2/SFA/SFB1/SFB2 to shared memory - cute.copy( - tma_atom_a, - tAgA[(None, ab_empty.count)], - tAsA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_b1, - tBgB1[(None, ab_empty.count)], - tBsB1[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_b2, - tBgB2[(None, ab_empty.count)], - tBsB2[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfa, - tAgSFA[(None, ab_empty.count)], - tAsSFA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfb1, - tBgSFB1[(None, ab_empty.count)], - tBsSFB1[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfb2, - tBgSFB2[(None, ab_empty.count)], - tBsSFB2[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - - # Wait for AB buffer full - ab_full = ab_consumer.wait_and_advance() - - # Copy SFA/SFB1/SFB2 to tmem - s2t_stage_coord = (None, None, None, None, ab_full.index) - tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] - tCsSFB1_compact_s2t_staged = tCsSFB1_compact_s2t[s2t_stage_coord] - tCsSFB2_compact_s2t_staged = tCsSFB2_compact_s2t[s2t_stage_coord] - cute.copy( - tiled_copy_s2t_sfa, - tCsSFA_compact_s2t_staged, - tCtSFA_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB1_compact_s2t_staged, - tCtSFB1_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB2_compact_s2t_staged, - tCtSFB2_compact_s2t, - ) - - # tCtAcc1 += tCrA * tCrSFA * tCrB1 * tCrSFB1 - # tCtAcc2 += tCrA * tCrSFA * tCrB2 * tCrSFB2 - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = ( - None, - None, - kblock_idx, - ab_full.index, - ) - - # Set SFA/SFB tensor to tiled_mma - sf_kblock_coord = (None, None, kblock_idx) - tiled_mma.set( - tcgen05.Field.SFA, - tCtSFA[sf_kblock_coord].iterator, - ) - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB1[sf_kblock_coord].iterator, - ) - cute.gemm( - tiled_mma, - tCtAcc1, - tCrA[kblock_coord], - tCrB1[kblock_coord], - tCtAcc1, - ) - - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB2[sf_kblock_coord].iterator, - ) - cute.gemm( - tiled_mma, - tCtAcc2, - tCrA[kblock_coord], - tCrB2[kblock_coord], - tCtAcc2, - ) - - # Enable accumulate on tCtAcc1/tCtAcc2 after first kblock - tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - - # Async arrive AB buffer empty - ab_full.release() - acc_empty.commit() - - # - # Epilogue - # Partition for epilogue - # - op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) - copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) - tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc1) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - # (T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc1 = thr_copy_t2r.partition_S(tCtAcc1) - # (T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc2 = thr_copy_t2r.partition_S(tCtAcc2) - # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - tTR_gC = thr_copy_t2r.partition_D(tCgC) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc1 = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 - ) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc2 = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 - ) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rC = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype - ) - # STG Atom - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) - tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] - - # Wait for accumulator buffer full - acc_full = acc_consumer.wait_and_advance() - - # Copy accumulator to register - cute.copy(tiled_copy_t2r, tTR_tAcc1, tTR_rAcc1) - cute.copy(tiled_copy_t2r, tTR_tAcc2, tTR_rAcc2) - - # Silu activation on acc1 and multiply with acc2 - acc_vec1 = epilogue_op(tTR_rAcc1.load()) - acc_vec2 = tTR_rAcc2.load() - acc_vec = acc_vec1 * acc_vec2 - - tTR_rC.store(acc_vec.to(c_dtype)) - # Store C to global memory - cute.copy(simt_atom, tTR_rC, tTR_gC) - - acc_full.release() - # Deallocate TMEM - cute.arch.barrier() - tmem.free(acc_tmem_ptr) - return - - -@cute.jit -def my_kernel( - a_ptr: cute.Pointer, - b1_ptr: cute.Pointer, - b2_ptr: cute.Pointer, - sfa_ptr: cute.Pointer, - sfb1_ptr: cute.Pointer, - sfb2_ptr: cute.Pointer, - c_ptr: cute.Pointer, - problem_size: tuple, - epilogue_op: cutlass.Constexpr = lambda x: x - * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), -): - """ - Host-side JIT function to prepare tensors and launch GPU kernel. - """ - m, n, k, l = problem_size - - # Setup attributes that depend on gemm inputs - a_tensor = cute.make_tensor( - a_ptr, - cute.make_layout( - (m, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), - ), - ) - b_tensor1 = cute.make_tensor( - b1_ptr, - cute.make_layout( - (n, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), - ), - ) - b_tensor2 = cute.make_tensor( - b2_ptr, - cute.make_layout( - (n, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), - ), - ) - c_tensor = cute.make_tensor( - c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) - ) - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, sf_vec_size - ) - sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) - - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor1.shape, sf_vec_size - ) - sfb_tensor1 = cute.make_tensor(sfb1_ptr, sfb_layout) - sfb_tensor2 = cute.make_tensor(sfb2_ptr, sfb_layout) - - mma_op = tcgen05.MmaMXF4NVF4Op( - sf_dtype, - (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), - tcgen05.CtaGroup.ONE, - tcgen05.OperandSource.SMEM, - ) - tiled_mma = cute.make_tiled_mma(mma_op) - - cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((1, 1, 1)), - (tiled_mma.thr_id.shape,), - ) - - # Compute A/B/SFA/SFB/C shared memory layout - a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - # B1 and B2 have the same size thus share the same smem layout - b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - # SFB1 and SFB2 have the same size thus share the same smem layout - sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - # Setup TMA for A - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - a_tensor, - a_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for B1 - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b1, tma_tensor_b1 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - b_tensor1, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for B2 - tma_atom_b2, tma_tensor_b2 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - b_tensor2, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for SFA - sfa_smem_layout = cute.slice_( - sfa_smem_layout_staged , (None, None, None, 0) - ) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfa_tensor, - sfa_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB1 - sfb_smem_layout = cute.slice_( - sfb_smem_layout_staged , (None, None, None, 0) - ) - tma_atom_sfb1, tma_tensor_sfb1 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfb_tensor1, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB2 - tma_atom_sfb2, tma_tensor_sfb2 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfb_tensor2, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - - # Compute TMA load bytes - a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) - b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) - sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) - sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) - num_tma_load_bytes = ( - a_copy_size + b_copy_size * 2 + sfa_copy_size + sfb_copy_size * 2 - ) * atom_thr_size - - # Compute grid size - grid = ( - cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), - cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), - c_tensor.shape[2], - ) - - # Launch the kernel. - kernel( - # MMA (Matrix Multiply-Accumulate) configuration - tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern - - # TMA (Tensor Memory Accelerator) atoms and tensors for shared input matrix A - tma_atom_a, # TMA copy atom defining how to load A from global memory - tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) - shared by both GEMMs - - # TMA atoms and tensors for first B matrix (B1) - tma_atom_b1, # TMA copy atom defining how to load B1 from global memory - tma_tensor_b1, # Tensor descriptor for B1 matrix (n, k, l) - first GEMM - - # TMA atoms and tensors for second B matrix (B2) - tma_atom_b2, # TMA copy atom defining how to load B2 from global memory - tma_tensor_b2, # Tensor descriptor for B2 matrix (n, k, l) - second GEMM - - # TMA atoms and tensors for scale factor A (shared) - tma_atom_sfa, # TMA copy atom for loading scale factors for A - tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - shared - - # TMA atoms and tensors for scale factor B1 - tma_atom_sfb1, # TMA copy atom for loading scale factors for B1 - tma_tensor_sfb1, # Tensor descriptor for SFB1 (block scale factors for B1) - - # TMA atoms and tensors for scale factor B2 - tma_atom_sfb2, # TMA copy atom for loading scale factors for B2 - tma_tensor_sfb2, # Tensor descriptor for SFB2 (block scale factors for B2) - - # Output tensor C (stores both C1 and C2 results) - c_tensor, # Output tensor where both GEMM results will be stored (m, n, l) - - # Shared memory layouts with staging for pipelined execution - a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) - b_smem_layout_staged, # Staged shared memory layout for B1/B2 (includes stage dimension) - sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) - sfb_smem_layout_staged, # Staged shared memory layout for SFB1/SFB2 (includes stage dimension) - - # Pipeline synchronization parameter - num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) - - # Epilogue operation - epilogue_op, # Epilogue operation to apply to output (e.g., element-wise ops) - ).launch( - grid=grid, - block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1), - ) - return - - -# Global cache for compiled kernel -_compiled_kernel_cache = None -# This function is used to compile the kernel once and cache it and then allow users to -# run the kernel multiple times to get more accurate timing results. -def compile_kernel(): - """ - Compile the kernel once and cache it. - This should be called before any timing measurements. - - Returns: - The compiled kernel function - """ - global _compiled_kernel_cache - - if _compiled_kernel_cache is not None: - return _compiled_kernel_cache - - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - b1_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - b2_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - sfb1_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - sfb2_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - - # Compile the kernel - _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (0, 0, 0, 0)) - - return _compiled_kernel_cache - - -def custom_kernel(data: input_t) -> output_t: - """ - Execute the block-scaled dual GEMM kernel with silu activation, - C = silu(A @ B1) * (A @ B2). - - This is the main entry point called by the evaluation framework. - It converts PyTorch tensors to CuTe tensors, launches the kernel, - and returns the result. - - Args: - data: Tuple of (a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c) PyTorch tensors - a: [m, k, l] - Input matrix in float4e2m1fn - b1: [n, k, l] - Input matrix in float4e2m1fn - b2: [n, k, l] - Input matrix in float4e2m1fn - sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb1_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb2_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - c: [m, n, l] - Output vector in float16 - - Returns: - Output tensor c with computed results - """ - a, b1, b2, _, _, _, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data - - # Ensure kernel is compiled (will use cached version if available) - # To avoid the compilation overhead, we compile the kernel once and cache it. - compiled_func = compile_kernel() - - # Get dimensions from MxKxL layout - _, k, _ = a.shape - m, n, l = c.shape - # Torch use e2m1_x2 data type, thus k is halved - k = k * 2 - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - b1_ptr = make_ptr( - ab_dtype, b1.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - b2_ptr = make_ptr( - ab_dtype, b2.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb1_ptr = make_ptr( - sf_dtype, sfb1_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb2_ptr = make_ptr( - sf_dtype, sfb2_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - - # Execute the compiled kernel - compiled_func(a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (m, n, k, l)) - - return c diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/task.py b/problems/nvidia/modal_nvfp4_dual_gemm/task.py deleted file mode 100644 index 8facfb07..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/task.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -from typing import TypedDict, TypeVar - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=torch.Tensor) -class TestSpec(TypedDict): - m: int - n: int - k: int - l: int - seed: int \ No newline at end of file diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/task.yml b/problems/nvidia/modal_nvfp4_dual_gemm/task.yml deleted file mode 100644 index 4d36bde6..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/task.yml +++ /dev/null @@ -1,64 +0,0 @@ -# name: nvfp4-dual-gemm - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval_better_bench.py"} - -lang: "py" - -description: | - - You will implement a block scaled dual matrix-matrix multiplication kernel with silu activation optimized for NVIDIA B200. - To be explicit, you will be given a tuple of tensors: - ``` - (a, b1, b2, sfa, sfb1, sfb2, c) - ``` - where: - * `a` is M x K x L in K-major order in nvfp4(e2m1) - * `b1` is N x K x L in K-major order in nvfp4(e2m1) - * `b2` is N x K x L in K-major order in nvfp4(e2m1) - * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `sfb1` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `sfb2` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `c` is M x N x L in fp16 - - Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. - The ranking criteria is the geometric mean of the benchmark results. - For the grand price, your kernel will be evaluated against the speed of light analysis - and the solution closest to the speed of light will be awarded the grand price. - ``` - The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: - M N K L time[us] - 256 4096 7168 1 4.708 - 512 4096 7168 1 8.714 - 256 3072 4096 1 2.125 - 512 3072 7168 1 6.535 - ``` -config: - main: "eval.py" - -templates: - Python: "template.py" - -tests: - - {"m": 1536, "n": 512, "k": 7168, "l": 1, "seed": 1111} - - {"m": 256, "n": 512, "k": 256, "l": 1, "seed": 1111} - - {"m": 1536, "n": 512, "k": 7168, "l": 1, "seed": 1111} - - {"m": 3072, "n": 1024, "k": 1536, "l": 1, "seed": 1111} - - {"m": 7168, "n": 1024, "k": 256, "l": 1, "seed": 1111} - - {"m": 7168, "n": 2304, "k": 2048, "l": 1, "seed": 1111} - - {"m": 4608, "n": 384, "k": 7168, "l": 1, "seed": 1111} - - {"m": 7168, "n": 384, "k": 2304, "l": 1, "seed": 1111} - - {"m": 512, "n": 768, "k": 7168, "l": 1, "seed": 1111} - - {"m": 4096, "n": 768, "k": 512, "l": 1, "seed": 1111} - -benchmarks: - - {"m": 256, "n": 4096, "k": 7168, "l": 1, "seed": 1111} - - {"m": 512, "n": 4096, "k": 7168, "l": 1, "seed": 1111} - - {"m": 256, "n": 3072, "k": 4096, "l": 1, "seed": 1111} - - {"m": 512, "n": 3072, "k": 7168, "l": 1, "seed": 1111} - -ranking_by: "geom" diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/template.py b/problems/nvidia/modal_nvfp4_dual_gemm/template.py deleted file mode 100644 index d8985df5..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/template.py +++ /dev/null @@ -1,28 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference implementation of block-scale fp4 dual gemm with silu activation - Args: - data: Tuple that expands to: - a: torch.Tensor[float4e2m1fn] of shape [m, k, l], - b1: torch.Tensor[float4e2m1fn] of shape [n, k, l], - b2: torch.Tensor[float4e2m1fn] of shape [n, k, l], - sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation - sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation - sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation - sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], - sfb1_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], - sfb2_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], - c: torch.Tensor[float16] of shape [m, n, l] - Returns: - Tensor containing output in float16 - c: torch.Tensor[float16] of shape [m, n, l] - """ - # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. - a, b1, b2, sfa, sfb1, sfb2, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data - - # Your implementation here - - return c \ No newline at end of file diff --git a/problems/nvidia/modal_nvfp4_dual_gemm/utils.py b/problems/nvidia/modal_nvfp4_dual_gemm/utils.py deleted file mode 100644 index d9b3a69e..00000000 --- a/problems/nvidia/modal_nvfp4_dual_gemm/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py deleted file mode 100644 index 95c6aacd..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch -from task import input_t, output_t -from utils import make_match_reference - -# Scaling factor vector size -sf_vec_size = 16 - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - -# Helper function to convert scale factor tensor to blocked format -def to_blocked(input_matrix): - rows, cols = input_matrix.shape - - # Please ensure rows and cols are multiples of 128 and 4 respectively - n_row_blocks = ceil_div(rows, 128) - n_col_blocks = ceil_div(cols, 4) - - padded = input_matrix - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - - return rearranged.flatten() - - -def ref_kernel( - data: input_t, -) -> output_t: - """ - PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation, - C = silu(A @ B1) * (A @ B2). - """ - a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data - - # Get dimensions from MxNxL layout - m, n, l = c_ref.shape - - # Call torch._scaled_mm to compute the GEMV result - ref1 = torch.empty( - (l, m, n), - dtype=torch.float32, - device="cuda", - ).permute(1, 2, 0) - ref2 = torch.empty( - (l, m, n), - dtype=torch.float32, - device="cuda", - ).permute(1, 2, 0) - for l_idx in range(l): - # Convert the scale factor tensor to blocked format - scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) - scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx]) - scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx]) - # (m, k) @ (n, k).T -> (m, n) - res1 = torch._scaled_mm( - a_ref[:, :, l_idx], - b1_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b1.cuda(), - bias=None, - out_dtype=torch.float32, - ) - ref1[:, :, l_idx] = res1 - - res2 = torch._scaled_mm( - a_ref[:, :, l_idx], - b2_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b2.cuda(), - bias=None, - out_dtype=torch.float32, - ) - ref2[:, :, l_idx] = res2 - # Do silu on the first GEMM result and multiply with the second GEMM result - c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16) - return c_ref - - -def generate_input( - m: int, - n: int, - k: int, - l: int, - seed: int, -): - """ - Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation, - C = silu(A @ B1) * (A @ B2). - - Args: - m: Number of rows in matrix A - n: Number of columns in matrix B1 and B2 - k: Number of columns in A and rows of B1 and B2 - l: Batch size - seed: Random seed for reproducibility - - Returns: - Tuple of (a, b, scale_a, scale_b, c) where: - a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b1: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b2: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - c: [m, n, l] - Output matrix in torch.float16 data type - """ - torch.manual_seed(seed) - - def create_fp4_tensors(l, mn, k): - # generate uint8 tensor, then convert to float4e2m1fn_x2 data type - # generate all bit patterns - ref_i8 = torch.randint(255, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda") - - # for each nibble, only keep the sign bit and 2 LSBs - # the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5] - ref_i8 = ref_i8 & 0b1011_1011 - - return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2) - - # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type - a_ref = create_fp4_tensors(l, m, k) - b1_ref = create_fp4_tensors(l, n, k) - b2_ref = create_fp4_tensors(l, n, k) - a_ref = a_ref.view(torch.float4_e2m1fn_x2) - b1_ref = b1_ref.view(torch.float4_e2m1fn_x2) - b2_ref = b2_ref.view(torch.float4_e2m1fn_x2) - - # Create float16 output tensor - c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( - 1, 2, 0 - ) - - # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. The customized data layout can be found in: - # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout - def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, sf_k, l) on CPU. - ref_shape = (l, mn, sf_k) - ref_permute_order = (1, 2, 0) - # Init with fp32 tensor in [0,1), then convert to float8_e4m3fn - ref_f8_random_fp32 = torch.rand(ref_shape, dtype=torch.float32, device='cuda') - ref_f8_torch_tensor = ref_f8_random_fp32.to(dtype=torch.float8_e4m3fn) - # permute to match ref_permute_order - ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) - - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - - # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout - # Which is needed by the CuTe customized kernel - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.empty(mma_shape, dtype=torch.int8, device='cuda') - reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) - - # GPU-side vectorized reordering (replaces slow CPU nested loops) - # Create index grids for all dimensions - i_idx = torch.arange(mn, device='cuda') - j_idx = torch.arange(sf_k, device='cuda') - b_idx = torch.arange(l, device='cuda') - - # Create meshgrid for all combinations of (i, j, b) - i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') - - # Calculate target indices in vectorized manner - mm = i_grid // (atom_m[0] * atom_m[1]) - mm32 = i_grid % atom_m[0] - mm4 = (i_grid % 128) // atom_m[0] - kk = j_grid // atom_k - kk4 = j_grid % atom_k - - # Perform the reordering with advanced indexing (all on GPU) - reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] - - return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor - - sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) - sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - - return (a_ref, b1_ref, b2_ref, sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"), sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_dual_gemm/submission.py b/problems/nvidia/nvfp4_dual_gemm/submission.py deleted file mode 100644 index 739cc5a0..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/submission.py +++ /dev/null @@ -1,957 +0,0 @@ -from torch._higher_order_ops.torchbind import call_torchbind_fake -import cuda.bindings.driver as cuda - -import torch -from task import input_t, output_t - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -import cutlass.utils.blockscaled_layout as blockscaled_utils -from cutlass.cute.runtime import make_ptr - -# Kernel configuration parameters -# Tile sizes for M, N, K dimensions -mma_tiler_mnk= (128, 128, 256) -# Shape of the K dimension for the MMA instruction -mma_inst_shape_k = 64 -# FP4 data type for A and B -ab_dtype = cutlass.Float4E2M1FN -# FP8 data type for scale factors -sf_dtype = cutlass.Float8E4M3FN -# FP16 output type -c_dtype = cutlass.Float16 -# Scale factor block size (16 elements share one scale) -sf_vec_size = 16 -# Number of threads per CUDA thread block -threads_per_cta = 128 -# Stage numbers of shared memory and tmem -num_acc_stage = 1 -num_ab_stage = 1 -# Total number of columns in tmem -num_tmem_alloc_cols = 512 - - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# GPU device kernel -@cute.kernel -def kernel( - tiled_mma: cute.TiledMma, - tma_atom_a: cute.CopyAtom, - mA_mkl: cute.Tensor, - tma_atom_b1: cute.CopyAtom, - mB_nkl1: cute.Tensor, - tma_atom_b2: cute.CopyAtom, - mB_nkl2: cute.Tensor, - tma_atom_sfa: cute.CopyAtom, - mSFA_mkl: cute.Tensor, - tma_atom_sfb1: cute.CopyAtom, - mSFB_nkl1: cute.Tensor, - tma_atom_sfb2: cute.CopyAtom, - mSFB_nkl2: cute.Tensor, - mC_mnl: cute.Tensor, - a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, - sfa_smem_layout_staged: cute.Layout, - sfb_smem_layout_staged: cute.Layout, - num_tma_load_bytes: cutlass.Constexpr[int], - epilogue_op: cutlass.Constexpr = lambda x: x - * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), -): - """ - GPU device kernel performing the batched GEMM computation. - """ - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - tidx = cute.arch.thread_idx() - - # - # Setup cta/thread coordinates - # - # Coords inside cluster - bidx, bidy, bidz = cute.arch.block_idx() - mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) - - # Coords outside cluster - cta_coord = (bidx, bidy, bidz) - mma_tile_coord_mnl = ( - cta_coord[0] // cute.size(tiled_mma.thr_id.shape), - cta_coord[1], - cta_coord[2], - ) - # Coord inside cta - tidx, _, _ = cute.arch.thread_idx() - - # - # Define shared storage for kernel - # - @cute.struct - class SharedStorage: - ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] - acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] - tmem_holding_buf: cutlass.Int32 - - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - # (MMA, MMA_M, MMA_K, STAGE) - sA = smem.allocate_tensor( - element_type=ab_dtype, - layout=a_smem_layout_staged.outer, - byte_alignment=128, - swizzle=a_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB1 = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB2 = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_M, MMA_K, STAGE) - sSFA = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfa_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB1 = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB2 = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - - # - # Initialize mainloop ab_pipeline, acc_pipeline and their states - # - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.ab_mbar_ptr.data_ptr(), - num_stages=num_ab_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=ab_pipeline_consumer_group, - tx_count=num_tma_load_bytes, - ).make_participants() - acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( - barrier_storage=storage.acc_mbar_ptr.data_ptr(), - num_stages=num_acc_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, - threads_per_cta, - ), - ).make_participants() - - # - # Local_tile partition global tensors - # - # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl1 = cute.local_tile( - mB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl2 = cute.local_tile( - mB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - gSFB_nkl1 = cute.local_tile( - mSFB_nkl1, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl2 = cute.local_tile( - mSFB_nkl2, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) - ) - k_tile_cnt = cute.size(gA_mkl, mode=[3]) - - # - # Partition global tensor for TiledMMA_A/B/SFA/SFB/C - # - # (MMA, MMA_M, MMA_K, RestK) - thr_mma = tiled_mma.get_slice(mma_tile_coord_v) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB1 = thr_mma.partition_B(gB_nkl1) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB2 = thr_mma.partition_B(gB_nkl2) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB1 = thr_mma.partition_B(gSFB_nkl1) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB2 = thr_mma.partition_B(gSFB_nkl2) - # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) - tCgC = thr_mma.partition_C(gC_mnl) - - # - # Partition global/shared tensor for TMA load A/B/SFA/SFB - # - # TMA Partition_S/D for A - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsA, tAgA = cpasync.tma_partition( - tma_atom_a, - 0, - cute.make_layout(1), - cute.group_modes(sA, 0, 3), - cute.group_modes(tCgA, 0, 3), - ) - # TMA Partition_S/D for B1 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB1, tBgB1 = cpasync.tma_partition( - tma_atom_b1, - 0, - cute.make_layout(1), - cute.group_modes(sB1, 0, 3), - cute.group_modes(tCgB1, 0, 3), - ) - # TMA Partition_S/D for B2 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB2, tBgB2 = cpasync.tma_partition( - tma_atom_b2, - 0, - cute.make_layout(1), - cute.group_modes(sB2, 0, 3), - cute.group_modes(tCgB2, 0, 3), - ) - # TMA Partition_S/D for SFA - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsSFA, tAgSFA = cpasync.tma_partition( - tma_atom_sfa, - 0, - cute.make_layout(1), - cute.group_modes(sSFA, 0, 3), - cute.group_modes(tCgSFA, 0, 3), - ) - tAsSFA = cute.filter_zeros(tAsSFA) - tAgSFA = cute.filter_zeros(tAgSFA) - # TMA Partition_S/D for SFB1 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB1, tBgSFB1 = cpasync.tma_partition( - tma_atom_sfb1, - 0, - cute.make_layout(1), - cute.group_modes(sSFB1, 0, 3), - cute.group_modes(tCgSFB1, 0, 3), - ) - tBsSFB1 = cute.filter_zeros(tBsSFB1) - tBgSFB1 = cute.filter_zeros(tBgSFB1) - # TMA Partition_S/D for SFB2 - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB2, tBgSFB2 = cpasync.tma_partition( - tma_atom_sfb2, - 0, - cute.make_layout(1), - cute.group_modes(sSFB2, 0, 3), - cute.group_modes(tCgSFB2, 0, 3), - ) - tBsSFB2 = cute.filter_zeros(tBsSFB2) - tBgSFB2 = cute.filter_zeros(tBgSFB2) - - # - # Partition shared/tensor memory tensor for TiledMMA_A/B/C - # - # (MMA, MMA_M, MMA_K, STAGE) - tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB1 = tiled_mma.make_fragment_B(sB1) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB2 = tiled_mma.make_fragment_B(sB2) - # (MMA, MMA_M, MMA_N) - acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) - # (MMA, MMA_M, MMA_N) - tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) - - # - # Alloc tensor memory buffer - # Make ACC1 and ACC2 tmem tensor - # ACC1 += A @ B1 - # ACC2 += A @ B2 - # - tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=1, - num_threads=threads_per_cta, - ) - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=tmem_alloc_barrier, - ) - tmem.allocate(num_tmem_alloc_cols) - tmem.wait_for_alloc() - acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) - tCtAcc1 = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - acc_tmem_ptr1 = cute.recast_ptr( - acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc1), - dtype=cutlass.Float32, - ) - tCtAcc2 = cute.make_tensor(acc_tmem_ptr1, tCtAcc_fake.layout) - - # - # Make SFA/SFB1/SFB2 tmem tensor - # - # SFA tmem layout: (MMA, MMA_M, MMA_K) - tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), - ) - # Get SFA tmem ptr - sfa_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2), - dtype=sf_dtype, - ) - tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) - - # SFB1, SFB2 tmem layout: (MMA, MMA_N, MMA_K) - tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), - ) - # Get SFB1 tmem ptr - sfb_tmem_ptr1 = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), - dtype=sf_dtype, - ) - tCtSFB1 = cute.make_tensor(sfb_tmem_ptr1, tCtSFB_layout) - # Get SFB2 tmem ptr - sfb_tmem_ptr2 = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc1) - + tcgen05.find_tmem_tensor_col_offset(tCtAcc2) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA) - + tcgen05.find_tmem_tensor_col_offset(tCtSFB1), - dtype=sf_dtype, - ) - tCtSFB2 = cute.make_tensor(sfb_tmem_ptr2, tCtSFB_layout) - - # - # Partition for S2T copy of SFA/SFB1/SFB2 - # - # Make S2T CopyAtom - copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), - sf_dtype, - ) - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFA_compact = cute.filter_zeros(sSFA) - # (MMA, MMA_MN, MMA_K) - tCtSFA_compact = cute.filter_zeros(tCtSFA) - tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) - thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) - - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact = cute.filter_zeros(sSFB1) - # (MMA, MMA_MN, MMA_K) - tCtSFB1_compact = cute.filter_zeros(tCtSFB1) - tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB1_compact) - thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB1_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB1_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB1_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB1_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB1_compact) - - # SFB2 S2T copy and partition - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact = cute.filter_zeros(sSFB2) - # (MMA, MMA_MN, MMA_K) - tCtSFB2_compact = cute.filter_zeros(tCtSFB2) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB2_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB2_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB2_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB2_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB2_compact) - - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), RestK) - tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB1 = tBgB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB2 = tBgB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB1 = tBgSFB1[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB2 = tBgSFB2[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - - # - # Execute Data copy and Math computation in the k_tile loop - # - if warp_idx == 0: - # Wait for accumulator buffer empty - acc_empty = acc_producer.acquire_and_advance() - # Set ACCUMULATE field to False for the first k_tile iteration - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # Execute k_tile loop - for k_tile in range(k_tile_cnt): - # Wait for AB buffer empty - ab_empty = ab_producer.acquire_and_advance() - - # TMA load A/B1/B2/SFA/SFB1/SFB2 to shared memory - cute.copy( - tma_atom_a, - tAgA[(None, ab_empty.count)], - tAsA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_b1, - tBgB1[(None, ab_empty.count)], - tBsB1[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_b2, - tBgB2[(None, ab_empty.count)], - tBsB2[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfa, - tAgSFA[(None, ab_empty.count)], - tAsSFA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfb1, - tBgSFB1[(None, ab_empty.count)], - tBsSFB1[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfb2, - tBgSFB2[(None, ab_empty.count)], - tBsSFB2[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - - # Wait for AB buffer full - ab_full = ab_consumer.wait_and_advance() - - # Copy SFA/SFB1/SFB2 to tmem - s2t_stage_coord = (None, None, None, None, ab_full.index) - tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] - tCsSFB1_compact_s2t_staged = tCsSFB1_compact_s2t[s2t_stage_coord] - tCsSFB2_compact_s2t_staged = tCsSFB2_compact_s2t[s2t_stage_coord] - cute.copy( - tiled_copy_s2t_sfa, - tCsSFA_compact_s2t_staged, - tCtSFA_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB1_compact_s2t_staged, - tCtSFB1_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB2_compact_s2t_staged, - tCtSFB2_compact_s2t, - ) - - # tCtAcc1 += tCrA * tCrSFA * tCrB1 * tCrSFB1 - # tCtAcc2 += tCrA * tCrSFA * tCrB2 * tCrSFB2 - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = ( - None, - None, - kblock_idx, - ab_full.index, - ) - - # Set SFA/SFB tensor to tiled_mma - sf_kblock_coord = (None, None, kblock_idx) - tiled_mma.set( - tcgen05.Field.SFA, - tCtSFA[sf_kblock_coord].iterator, - ) - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB1[sf_kblock_coord].iterator, - ) - cute.gemm( - tiled_mma, - tCtAcc1, - tCrA[kblock_coord], - tCrB1[kblock_coord], - tCtAcc1, - ) - - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB2[sf_kblock_coord].iterator, - ) - cute.gemm( - tiled_mma, - tCtAcc2, - tCrA[kblock_coord], - tCrB2[kblock_coord], - tCtAcc2, - ) - - # Enable accumulate on tCtAcc1/tCtAcc2 after first kblock - tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - - # Async arrive AB buffer empty - ab_full.release() - acc_empty.commit() - - # - # Epilogue - # Partition for epilogue - # - op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) - copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) - tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc1) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - # (T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc1 = thr_copy_t2r.partition_S(tCtAcc1) - # (T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc2 = thr_copy_t2r.partition_S(tCtAcc2) - # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - tTR_gC = thr_copy_t2r.partition_D(tCgC) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc1 = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 - ) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc2 = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 - ) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rC = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype - ) - # STG Atom - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) - tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] - - # Wait for accumulator buffer full - acc_full = acc_consumer.wait_and_advance() - - # Copy accumulator to register - cute.copy(tiled_copy_t2r, tTR_tAcc1, tTR_rAcc1) - cute.copy(tiled_copy_t2r, tTR_tAcc2, tTR_rAcc2) - - # Silu activation on acc1 and multiply with acc2 - acc_vec1 = epilogue_op(tTR_rAcc1.load()) - acc_vec2 = tTR_rAcc2.load() - acc_vec = acc_vec1 * acc_vec2 - - tTR_rC.store(acc_vec.to(c_dtype)) - # Store C to global memory - cute.copy(simt_atom, tTR_rC, tTR_gC) - - acc_full.release() - # Deallocate TMEM - cute.arch.barrier() - tmem.free(acc_tmem_ptr) - return - - -@cute.jit -def my_kernel( - a_ptr: cute.Pointer, - b1_ptr: cute.Pointer, - b2_ptr: cute.Pointer, - sfa_ptr: cute.Pointer, - sfb1_ptr: cute.Pointer, - sfb2_ptr: cute.Pointer, - c_ptr: cute.Pointer, - problem_size: tuple, - epilogue_op: cutlass.Constexpr = lambda x: x - * (1.0 / (1.0 + cute.math.exp(-x, fastmath=True))), -): - """ - Host-side JIT function to prepare tensors and launch GPU kernel. - """ - m, n, k, l = problem_size - - # Setup attributes that depend on gemm inputs - a_tensor = cute.make_tensor( - a_ptr, - cute.make_layout( - (m, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), - ), - ) - b_tensor1 = cute.make_tensor( - b1_ptr, - cute.make_layout( - (n, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), - ), - ) - b_tensor2 = cute.make_tensor( - b2_ptr, - cute.make_layout( - (n, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), - ), - ) - c_tensor = cute.make_tensor( - c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) - ) - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, sf_vec_size - ) - sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) - - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor1.shape, sf_vec_size - ) - sfb_tensor1 = cute.make_tensor(sfb1_ptr, sfb_layout) - sfb_tensor2 = cute.make_tensor(sfb2_ptr, sfb_layout) - - mma_op = tcgen05.MmaMXF4NVF4Op( - sf_dtype, - (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), - tcgen05.CtaGroup.ONE, - tcgen05.OperandSource.SMEM, - ) - tiled_mma = cute.make_tiled_mma(mma_op) - - cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((1, 1, 1)), - (tiled_mma.thr_id.shape,), - ) - - # Compute A/B/SFA/SFB/C shared memory layout - a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - # B1 and B2 have the same size thus share the same smem layout - b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - # SFB1 and SFB2 have the same size thus share the same smem layout - sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - # Setup TMA for A - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - a_tensor, - a_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for B1 - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b1, tma_tensor_b1 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - b_tensor1, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for B2 - tma_atom_b2, tma_tensor_b2 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - b_tensor2, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - ) - # Setup TMA for SFA - sfa_smem_layout = cute.slice_( - sfa_smem_layout_staged , (None, None, None, 0) - ) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfa_tensor, - sfa_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB1 - sfb_smem_layout = cute.slice_( - sfb_smem_layout_staged , (None, None, None, 0) - ) - tma_atom_sfb1, tma_tensor_sfb1 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfb_tensor1, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB2 - tma_atom_sfb2, tma_tensor_sfb2 = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfb_tensor2, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk .shape, - internal_type=cutlass.Int16, - ) - - # Compute TMA load bytes - a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) - b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) - sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) - sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) - num_tma_load_bytes = ( - a_copy_size + b_copy_size * 2 + sfa_copy_size + sfb_copy_size * 2 - ) * atom_thr_size - - # Compute grid size - grid = ( - cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), - cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), - c_tensor.shape[2], - ) - - # Launch the kernel. - kernel( - # MMA (Matrix Multiply-Accumulate) configuration - tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern - - # TMA (Tensor Memory Accelerator) atoms and tensors for shared input matrix A - tma_atom_a, # TMA copy atom defining how to load A from global memory - tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) - shared by both GEMMs - - # TMA atoms and tensors for first B matrix (B1) - tma_atom_b1, # TMA copy atom defining how to load B1 from global memory - tma_tensor_b1, # Tensor descriptor for B1 matrix (n, k, l) - first GEMM - - # TMA atoms and tensors for second B matrix (B2) - tma_atom_b2, # TMA copy atom defining how to load B2 from global memory - tma_tensor_b2, # Tensor descriptor for B2 matrix (n, k, l) - second GEMM - - # TMA atoms and tensors for scale factor A (shared) - tma_atom_sfa, # TMA copy atom for loading scale factors for A - tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - shared - - # TMA atoms and tensors for scale factor B1 - tma_atom_sfb1, # TMA copy atom for loading scale factors for B1 - tma_tensor_sfb1, # Tensor descriptor for SFB1 (block scale factors for B1) - - # TMA atoms and tensors for scale factor B2 - tma_atom_sfb2, # TMA copy atom for loading scale factors for B2 - tma_tensor_sfb2, # Tensor descriptor for SFB2 (block scale factors for B2) - - # Output tensor C (stores both C1 and C2 results) - c_tensor, # Output tensor where both GEMM results will be stored (m, n, l) - - # Shared memory layouts with staging for pipelined execution - a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) - b_smem_layout_staged, # Staged shared memory layout for B1/B2 (includes stage dimension) - sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) - sfb_smem_layout_staged, # Staged shared memory layout for SFB1/SFB2 (includes stage dimension) - - # Pipeline synchronization parameter - num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) - - # Epilogue operation - epilogue_op, # Epilogue operation to apply to output (e.g., element-wise ops) - ).launch( - grid=grid, - block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1), - ) - return - - -# Global cache for compiled kernel -_compiled_kernel_cache = None -# This function is used to compile the kernel once and cache it and then allow users to -# run the kernel multiple times to get more accurate timing results. -def compile_kernel(): - """ - Compile the kernel once and cache it. - This should be called before any timing measurements. - - Returns: - The compiled kernel function - """ - global _compiled_kernel_cache - - if _compiled_kernel_cache is not None: - return _compiled_kernel_cache - - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - b1_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - b2_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - sfb1_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - sfb2_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - - # Compile the kernel - _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (0, 0, 0, 0)) - - return _compiled_kernel_cache - - -def custom_kernel(data: input_t) -> output_t: - """ - Execute the block-scaled dual GEMM kernel with silu activation, - C = silu(A @ B1) * (A @ B2). - - This is the main entry point called by the evaluation framework. - It converts PyTorch tensors to CuTe tensors, launches the kernel, - and returns the result. - - Args: - data: Tuple of (a, b1, b2, sfa_cpu, sfb1_cpu, sfb2_cpu, c) PyTorch tensors - a: [m, k, l] - Input matrix in float4e2m1fn - b1: [n, k, l] - Input matrix in float4e2m1fn - b2: [n, k, l] - Input matrix in float4e2m1fn - sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfb1_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfb2_cpu: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb1_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb2_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - c: [m, n, l] - Output vector in float16 - - Returns: - Output tensor c with computed results - """ - a, b1, b2, _, _, _, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data - - # Ensure kernel is compiled (will use cached version if available) - # To avoid the compilation overhead, we compile the kernel once and cache it. - compiled_func = compile_kernel() - - # Get dimensions from MxKxL layout - _, k, _ = a.shape - m, n, l = c.shape - # Torch use e2m1_x2 data type, thus k is halved - k = k * 2 - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - b1_ptr = make_ptr( - ab_dtype, b1.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - b2_ptr = make_ptr( - ab_dtype, b2.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb1_ptr = make_ptr( - sf_dtype, sfb1_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb2_ptr = make_ptr( - sf_dtype, sfb2_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - - # Execute the compiled kernel - compiled_func(a_ptr, b1_ptr, b2_ptr, sfa_ptr, sfb1_ptr, sfb2_ptr, c_ptr, (m, n, k, l)) - - return c diff --git a/problems/nvidia/nvfp4_dual_gemm/task.py b/problems/nvidia/nvfp4_dual_gemm/task.py deleted file mode 100644 index 8facfb07..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/task.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -from typing import TypedDict, TypeVar - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=torch.Tensor) -class TestSpec(TypedDict): - m: int - n: int - k: int - l: int - seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/task.yml b/problems/nvidia/nvfp4_dual_gemm/task.yml deleted file mode 100644 index 4d36bde6..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/task.yml +++ /dev/null @@ -1,64 +0,0 @@ -# name: nvfp4-dual-gemm - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval_better_bench.py"} - -lang: "py" - -description: | - - You will implement a block scaled dual matrix-matrix multiplication kernel with silu activation optimized for NVIDIA B200. - To be explicit, you will be given a tuple of tensors: - ``` - (a, b1, b2, sfa, sfb1, sfb2, c) - ``` - where: - * `a` is M x K x L in K-major order in nvfp4(e2m1) - * `b1` is N x K x L in K-major order in nvfp4(e2m1) - * `b2` is N x K x L in K-major order in nvfp4(e2m1) - * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `sfb1` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `sfb2` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `c` is M x N x L in fp16 - - Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. - The ranking criteria is the geometric mean of the benchmark results. - For the grand price, your kernel will be evaluated against the speed of light analysis - and the solution closest to the speed of light will be awarded the grand price. - ``` - The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: - M N K L time[us] - 256 4096 7168 1 4.708 - 512 4096 7168 1 8.714 - 256 3072 4096 1 2.125 - 512 3072 7168 1 6.535 - ``` -config: - main: "eval.py" - -templates: - Python: "template.py" - -tests: - - {"m": 1536, "n": 512, "k": 7168, "l": 1, "seed": 1111} - - {"m": 256, "n": 512, "k": 256, "l": 1, "seed": 1111} - - {"m": 1536, "n": 512, "k": 7168, "l": 1, "seed": 1111} - - {"m": 3072, "n": 1024, "k": 1536, "l": 1, "seed": 1111} - - {"m": 7168, "n": 1024, "k": 256, "l": 1, "seed": 1111} - - {"m": 7168, "n": 2304, "k": 2048, "l": 1, "seed": 1111} - - {"m": 4608, "n": 384, "k": 7168, "l": 1, "seed": 1111} - - {"m": 7168, "n": 384, "k": 2304, "l": 1, "seed": 1111} - - {"m": 512, "n": 768, "k": 7168, "l": 1, "seed": 1111} - - {"m": 4096, "n": 768, "k": 512, "l": 1, "seed": 1111} - -benchmarks: - - {"m": 256, "n": 4096, "k": 7168, "l": 1, "seed": 1111} - - {"m": 512, "n": 4096, "k": 7168, "l": 1, "seed": 1111} - - {"m": 256, "n": 3072, "k": 4096, "l": 1, "seed": 1111} - - {"m": 512, "n": 3072, "k": 7168, "l": 1, "seed": 1111} - -ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_dual_gemm/template.py b/problems/nvidia/nvfp4_dual_gemm/template.py deleted file mode 100644 index d8985df5..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/template.py +++ /dev/null @@ -1,28 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference implementation of block-scale fp4 dual gemm with silu activation - Args: - data: Tuple that expands to: - a: torch.Tensor[float4e2m1fn] of shape [m, k, l], - b1: torch.Tensor[float4e2m1fn] of shape [n, k, l], - b2: torch.Tensor[float4e2m1fn] of shape [n, k, l], - sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], used by reference implementation - sfb1: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation - sfb2: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], used by reference implementation - sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], - sfb1_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], - sfb2_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], - c: torch.Tensor[float16] of shape [m, n, l] - Returns: - Tensor containing output in float16 - c: torch.Tensor[float16] of shape [m, n, l] - """ - # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. - a, b1, b2, sfa, sfb1, sfb2, sfa_permuted, sfb1_permuted, sfb2_permuted, c = data - - # Your implementation here - - return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_dual_gemm/utils.py b/problems/nvidia/nvfp4_dual_gemm/utils.py deleted file mode 100644 index d9b3a69e..00000000 --- a/problems/nvidia/nvfp4_dual_gemm/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy diff --git a/problems/nvidia/nvfp4_gemm/reference.py b/problems/nvidia/nvfp4_gemm/reference.py deleted file mode 100644 index 421db5ef..00000000 --- a/problems/nvidia/nvfp4_gemm/reference.py +++ /dev/null @@ -1,161 +0,0 @@ -import torch -from task import input_t, output_t -from utils import make_match_reference - -# Scaling factor vector size -sf_vec_size = 16 - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - -# Helper function to convert scale factor tensor to blocked format -def to_blocked(input_matrix): - rows, cols = input_matrix.shape - - # Please ensure rows and cols are multiples of 128 and 4 respectively - n_row_blocks = ceil_div(rows, 128) - n_col_blocks = ceil_div(cols, 4) - - padded = input_matrix - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - - return rearranged.flatten() - - -def ref_kernel( - data: input_t, -) -> output_t: - """ - PyTorch reference implementation of NVFP4 block-scaled GEMM. - """ - a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data - - # Get dimensions from MxNxL layout - _, _, l = c_ref.shape - - # Call torch._scaled_mm to compute the GEMM result - for l_idx in range(l): - # Convert the scale factor tensor to blocked format - scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) - scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) - # (m, k) @ (n, k).T -> (m, n) - res = torch._scaled_mm( - a_ref[:, :, l_idx], - b_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b.cuda(), - bias=None, - out_dtype=torch.float16, - ) - c_ref[:, :, l_idx] = res - return c_ref - - -def generate_input( - m: int, - n: int, - k: int, - l: int, - seed: int, -): - """ - Generate input tensors for NVFP4 block-scaled GEMM. - - Args: - m: Number of rows in matrix A - n: Number of columns in matrix B - k: Number of columns in A and rows of B - l: Batch size - seed: Random seed for reproducibility - - Returns: - Tuple of (a, b, scale_a, scale_b, c) where: - a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b: [n, k, l] - Input scale factors in torch.float8e4m3fn data type - scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - scale_b_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - c: [m, n, l] - Output matrix in torch.float16 data type - """ - torch.manual_seed(seed) - - # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type - a_ref = torch.randint( - -128, 128, (l, m, k // 2), dtype=torch.int8, device="cuda" - ).permute(1, 2, 0) - b_ref = torch.randint( - -128, 128, (l, n, k // 2), dtype=torch.int8, device="cuda" - ).permute(1, 2, 0) - a_ref = a_ref.view(torch.float4_e2m1fn_x2) - b_ref = b_ref.view(torch.float4_e2m1fn_x2) - - # Create float16 output tensor - c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute( - 1, 2, 0 - ) - - # Helper function to prepare the scale factor tensors for both reference - # kernel and customize kernel. The customized data layout can be found in: - # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout - def create_scale_factor_tensors(l, mn, sf_k): - # Create the reference scale factor tensor (mn, sf_k, l) on CPU. - ref_shape = (l, mn, sf_k) - ref_permute_order = (1, 2, 0) - # Init with uint8 tensor, then convert to float8_e4m3fn - ref_f8_random_int = torch.randint(0, 4, ref_shape, dtype=torch.int8, device='cuda') - ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) - # permute to match ref_permute_order - ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order) - - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - - # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout - # Which is needed by the CuTe customized kernel - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(0, 4, mma_shape, dtype=torch.int8, device='cuda') - reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order) - - # GPU-side vectorized reordering (replaces slow CPU nested loops) - # Create index grids for all dimensions - i_idx = torch.arange(mn, device='cuda') - j_idx = torch.arange(sf_k, device='cuda') - b_idx = torch.arange(l, device='cuda') - - # Create meshgrid for all combinations of (i, j, b) - i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') - - # Calculate target indices in vectorized manner - mm = i_grid // (atom_m[0] * atom_m[1]) - mm32 = i_grid % atom_m[0] - mm4 = (i_grid % 128) // atom_m[0] - kk = j_grid // atom_k - kk4 = j_grid % atom_k - - # Perform the reordering with advanced indexing (all on GPU) - reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid] - - return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor - - sf_k = ceil_div(k, sf_vec_size) - sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k) - sfb_ref_cpu, sfb_ref_permuted = create_scale_factor_tensors(l, n, sf_k) - - return (a_ref, b_ref, sfa_ref_cpu.to("cuda"), sfb_ref_cpu.to("cuda"), sfa_ref_permuted, sfb_ref_permuted, c_ref) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_gemm/submission.py b/problems/nvidia/nvfp4_gemm/submission.py deleted file mode 100644 index c2f37d92..00000000 --- a/problems/nvidia/nvfp4_gemm/submission.py +++ /dev/null @@ -1,761 +0,0 @@ -from torch._higher_order_ops.torchbind import call_torchbind_fake -import cuda.bindings.driver as cuda - -import torch -from task import input_t, output_t - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -import cutlass.utils.blockscaled_layout as blockscaled_utils -from cutlass.cute.runtime import make_ptr - -# Kernel configuration parameters -# Tile sizes for M, N, K dimensions -mma_tiler_mnk = (128, 128, 256) -# Shape of the K dimension for the MMA instruction -mma_inst_shape_k = 64 -# FP4 data type for A and B -ab_dtype = cutlass.Float4E2M1FN -# FP8 data type for scale factors -sf_dtype = cutlass.Float8E4M3FN -# FP16 output type -c_dtype = cutlass.Float16 -# Scale factor block size (16 elements share one scale) -sf_vec_size = 16 -# Number of threads per CUDA thread block -threads_per_cta = 128 -# Stage numbers of shared memory and tmem -num_acc_stage = 1 -num_ab_stage = 1 -# Total number of columns in tmem -num_tmem_alloc_cols = 512 - - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# The CuTe reference implementation for NVFP4 block-scaled GEMM -@cute.kernel -def kernel( - tiled_mma: cute.TiledMma, - tma_atom_a: cute.CopyAtom, - mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, - tma_atom_sfa: cute.CopyAtom, - mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, - mC_mnl: cute.Tensor, - a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, - sfa_smem_layout_staged: cute.Layout, - sfb_smem_layout_staged: cute.Layout, - num_tma_load_bytes: cutlass.Constexpr[int], -): - """ - GPU device kernel performing the batched GEMM computation. - """ - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - tidx = cute.arch.thread_idx() - - # - # Setup cta/thread coordinates - # - # Coords inside cluster - bidx, bidy, bidz = cute.arch.block_idx() - - # Coords outside cluster - cta_coord = (bidx, bidy, bidz) - mma_tile_coord_mnl = ( - cta_coord[0] // cute.size(tiled_mma.thr_id.shape), - cta_coord[1], - cta_coord[2], - ) - # Coord inside cta - tidx, _, _ = cute.arch.thread_idx() - - # - # Define shared storage for kernel - # - @cute.struct - class SharedStorage: - ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] - acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] - tmem_holding_buf: cutlass.Int32 - - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - # (MMA, MMA_M, MMA_K, STAGE) - sA = smem.allocate_tensor( - element_type=ab_dtype, - layout=a_smem_layout_staged.outer, - byte_alignment=128, - swizzle=a_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_M, MMA_K, STAGE) - sSFA = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfa_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - - # - # Initialize mainloop ab_pipeline, acc_pipeline and their states - # - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.ab_mbar_ptr.data_ptr(), - num_stages=num_ab_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=ab_pipeline_consumer_group, - tx_count=num_tma_load_bytes, - ).make_participants() - acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( - barrier_storage=storage.acc_mbar_ptr.data_ptr(), - num_stages=num_acc_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, - threads_per_cta, - ), - ).make_participants() - - # - # Local_tile partition global tensors - # - # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) - ) - k_tile_cnt = cute.size(gA_mkl, mode=[3]) - - # - # Partition global tensor for TiledMMA_A/B/SFA/SFB/C - # - # (MMA, MMA_M, MMA_K, RestK) - thr_mma = tiled_mma.get_slice(0) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma.partition_B(gSFB_nkl) - # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) - tCgC = thr_mma.partition_C(gC_mnl) - - # - # Partition global/shared tensor for TMA load A/B/SFA/SFB - # - # TMA Partition_S/D for A - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsA, tAgA = cpasync.tma_partition( - tma_atom_a, - 0, - cute.make_layout(1), - cute.group_modes(sA, 0, 3), - cute.group_modes(tCgA, 0, 3), - ) - # TMA Partition_S/D for B - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, - 0, - cute.make_layout(1), - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), - ) - # TMA Partition_S/D for SFA - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsSFA, tAgSFA = cpasync.tma_partition( - tma_atom_sfa, - 0, - cute.make_layout(1), - cute.group_modes(sSFA, 0, 3), - cute.group_modes(tCgSFA, 0, 3), - ) - tAsSFA = cute.filter_zeros(tAsSFA) - tAgSFA = cute.filter_zeros(tAgSFA) - # TMA Partition_S/D for SFB - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cpasync.tma_partition( - tma_atom_sfb, - 0, - cute.make_layout(1), - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) - - # - # Partition shared/tensor memory tensor for TiledMMA_A/B/C - # - # (MMA, MMA_M, MMA_K, STAGE) - tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB = tiled_mma.make_fragment_B(sB) - # (MMA, MMA_M, MMA_N) - acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) - # (MMA, MMA_M, MMA_N) - tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) - - # - # Alloc tensor memory buffer - # - tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=1, - num_threads=threads_per_cta, - ) - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=tmem_alloc_barrier, - ) - tmem.allocate(num_tmem_alloc_cols) - tmem.wait_for_alloc() - acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) - tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - - # - # Make SFA/SFB tmem tensor - # - # Get SFA tmem ptr - sfa_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), - dtype=sf_dtype, - ) - # (MMA, MMA_M, MMA_K) - tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), - ) - tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) - # Get SFB tmem ptr - sfb_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), - dtype=sf_dtype, - ) - # (MMA, MMA_N, MMA_K) - tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), - ) - tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) - - # - # Partition for S2T copy of SFA/SFB - # - # Make S2T CopyAtom - copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), - sf_dtype, - ) - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFA_compact = cute.filter_zeros(sSFA) - # (MMA, MMA_MN, MMA_K) - tCtSFA_compact = cute.filter_zeros(tCtSFA) - tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) - thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) - - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB_compact = cute.filter_zeros(sSFB) - # (MMA, MMA_MN, MMA_K) - tCtSFB_compact = cute.filter_zeros(tCtSFB) - tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) - thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) - - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), RestK) - tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - - # - # Execute Data copy and Math computation in the k_tile loop - # - if warp_idx == 0: - # Wait for accumulator buffer empty - acc_empty = acc_producer.acquire_and_advance() - # Set ACCUMULATE field to False for the first k_tile iteration - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # Execute k_tile loop - for k_tile in range(k_tile_cnt): - # Wait for AB buffer empty - ab_empty = ab_producer.acquire_and_advance() - - # TMA load A/B/SFA/SFB to shared memory - cute.copy( - tma_atom_a, - tAgA[(None, k_tile)], - tAsA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_b, - tBgB[(None, k_tile)], - tBsB[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfa, - tAgSFA[(None, k_tile)], - tAsSFA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - cute.copy( - tma_atom_sfb, - tBgSFB[(None, k_tile)], - tBsSFB[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - ) - - # Wait for AB buffer full - ab_full = ab_consumer.wait_and_advance() - - # Copy SFA/SFB from shared memory to TMEM - s2t_stage_coord = (None, None, None, None, ab_full.index) - tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] - tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] - cute.copy( - tiled_copy_s2t_sfa, - tCsSFA_compact_s2t_staged, - tCtSFA_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB_compact_s2t_staged, - tCtSFB_compact_s2t, - ) - - # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = ( - None, - None, - kblock_idx, - ab_full.index, - ) - - # Set SFA/SFB tensor to tiled_mma - sf_kblock_coord = (None, None, kblock_idx) - tiled_mma.set( - tcgen05.Field.SFA, - tCtSFA[sf_kblock_coord].iterator, - ) - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB[sf_kblock_coord].iterator, - ) - - cute.gemm( - tiled_mma, - tCtAcc, - tCrA[kblock_coord], - tCrB[kblock_coord], - tCtAcc, - ) - # Enable accumulate on tCtAcc after first kblock - tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - - # Async arrive AB buffer empty - ab_full.release() - acc_empty.commit() - - # - # Epilogue - # Partition for epilogue - # - op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) - copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) - tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - # (T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc) - # (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - tTR_gC = thr_copy_t2r.partition_D(tCgC) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rAcc = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32 - ) - # (T2R_M, T2R_N, EPI_M, EPI_N) - tTR_rC = cute.make_rmem_tensor( - tTR_gC[None, None, None, None, 0, 0, 0].shape, c_dtype - ) - # STG Atom - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), c_dtype) - tTR_gC = tTR_gC[(None, None, None, None, *mma_tile_coord_mnl)] - - # Wait for accumulator buffer full - acc_full = acc_consumer.wait_and_advance() - - # Copy accumulator to register - cute.copy(tiled_copy_t2r, tTR_tAcc, tTR_rAcc) - acc_vec = tTR_rAcc.load().to(c_dtype) - tTR_rC.store(acc_vec) - # Store C to global memory - cute.copy(simt_atom, tTR_rC, tTR_gC) - - acc_full.release() - - # Deallocate TMEM - cute.arch.barrier() - tmem.free(acc_tmem_ptr) - - return - - -@cute.jit -def my_kernel( - a_ptr: cute.Pointer, - b_ptr: cute.Pointer, - sfa_ptr: cute.Pointer, - sfb_ptr: cute.Pointer, - c_ptr: cute.Pointer, - problem_size: tuple, -): - """ - Host-side JIT function to prepare tensors and launch GPU kernel. - """ - m, n, k, l = problem_size - - # Setup attributes that depend on gemm inputs - a_tensor = cute.make_tensor( - a_ptr, - cute.make_layout( - (m, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), - ), - ) - b_tensor = cute.make_tensor( - b_ptr, - cute.make_layout( - (n, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32)), - ), - ) - c_tensor = cute.make_tensor( - c_ptr, cute.make_layout((cute.assume(m, 32), n, l), stride=(n, 1, m * n)) - ) - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, sf_vec_size - ) - sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) - - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, sf_vec_size - ) - sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) - - mma_op = tcgen05.MmaMXF4NVF4Op( - sf_dtype, - (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), - tcgen05.CtaGroup.ONE, - tcgen05.OperandSource.SMEM, - ) - tiled_mma = cute.make_tiled_mma(mma_op) - - cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((1, 1, 1)), - (tiled_mma.thr_id.shape,), - ) - - # Compute A/B/SFA/SFB/C shared memory layout - a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - # Setup TMA for A - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - a_tensor, - a_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - ) - # Setup TMA for B - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - b_tensor, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - ) - # Setup TMA for SFA - sfa_smem_layout = cute.slice_( - sfa_smem_layout_staged, (None, None, None, 0) - ) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfa_tensor, - sfa_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB - sfb_smem_layout = cute.slice_( - sfb_smem_layout_staged, (None, None, None, 0) - ) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - sfb_tensor, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - internal_type=cutlass.Int16, - ) - - # Compute TMA load bytes - a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) - b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) - sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) - sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) - num_tma_load_bytes = ( - a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size - ) * atom_thr_size - - # Compute grid size - grid = ( - cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]), - cute.ceil_div(c_tensor.shape[1], mma_tiler_mnk[1]), - c_tensor.shape[2], - ) - - # Launch the kernel - kernel( - # MMA (Matrix Multiply-Accumulate) configuration - tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern - - # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A - tma_atom_a, # TMA copy atom defining how to load A from global memory - tma_tensor_a, # Tensor descriptor for A matrix (m, k, l) - - # TMA atoms and tensors for input matrix B - tma_atom_b, # TMA copy atom defining how to load B from global memory - tma_tensor_b, # Tensor descriptor for B matrix (n, k, l) - - # TMA atoms and tensors for scale factor A - tma_atom_sfa, # TMA copy atom for loading scale factors for A - tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - - # TMA atoms and tensors for scale factor B - tma_atom_sfb, # TMA copy atom for loading scale factors for B - tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) - - # Output tensor C - c_tensor, # Output tensor C where result will be stored (m, n, l) - - # Shared memory layouts with staging for pipelined execution - a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) - b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) - sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) - sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) - - # Pipeline synchronization parameter - num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) - ).launch( - grid=grid, - block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1), - ) - return - - -# Global cache for compiled kernel -_compiled_kernel_cache = None -# This function is used to compile the kernel once and cache it and then allow users to -# run the kernel multiple times to get more accurate timing results. -def compile_kernel(): - """ - Compile the kernel once and cache it. - This should be called before any timing measurements. - - Returns: - The compiled kernel function - """ - global _compiled_kernel_cache - - if _compiled_kernel_cache is not None: - return _compiled_kernel_cache - - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - b_ptr = make_ptr( - ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - sfb_ptr = make_ptr( - sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 - ) - - # Compile the kernel - _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) - - return _compiled_kernel_cache - - -def custom_kernel(data: input_t) -> output_t: - """ - Execute the block-scaled GEMM kernel. - - This is the main entry point called by the evaluation framework. - It converts PyTorch tensors to CuTe tensors, launches the kernel, - and returns the result. - - Args: - data: Tuple of (a, b, sfa_ref, sfb_ref, sfa_permuted, sfb_permuted, c) PyTorch tensors - a: [m, k, l] - Input matrix in float4e2m1fn - b: [n, k, l] - Input vector in float4e2m1fn - sfa_ref: [m, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfb_ref: [n, k, l] - Scale factors in float8_e4m3fn, used by reference implementation - sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - c: [m, n, l] - Output vector in float16 - - Returns: - Output tensor c with computed results - """ - a, b, _, _, sfa_permuted, sfb_permuted, c = data - - # Ensure kernel is compiled (will use cached version if available) - # To avoid the compilation overhead, we compile the kernel once and cache it. - compiled_func = compile_kernel() - - # Get dimensions from MxKxL layout - m, k, l = a.shape - n, _, _ = b.shape - # Torch use e2m1_x2 data type, thus k is halved - k = k * 2 - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr( - ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - b_ptr = make_ptr( - ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - c_ptr = make_ptr( - c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) - sfa_ptr = make_ptr( - sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb_ptr = make_ptr( - sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - - # Execute the compiled kernel - compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) - - return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.py b/problems/nvidia/nvfp4_gemm/task.py deleted file mode 100644 index 66db7351..00000000 --- a/problems/nvidia/nvfp4_gemm/task.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -from typing import TypedDict, TypeVar - -input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -output_t = TypeVar("output_t", bound=torch.Tensor) -class TestSpec(TypedDict): - m: int - n: int - k: int - l: int - seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/task.yml b/problems/nvidia/nvfp4_gemm/task.yml deleted file mode 100644 index aca40fd6..00000000 --- a/problems/nvidia/nvfp4_gemm/task.yml +++ /dev/null @@ -1,58 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "../utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval_better_bench.py"} - -lang: "py" - -description: | - - You will implement a block scaled matrix-matrix multiplication kernel optimized for NVIDIA B200. - To be explicit, you will be given a tuple of tensors: - ``` - (a, b, sfa, sfb, c) - ``` - where: - * `a` is M x K x L in K-major order in nvfp4(e2m1) - * `b` is N x K x L in K-major order in nvfp4(e2m1) - * `sfa` is M x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `sfb` is N x (K // 16) x L in K-major order in fp8(e4m3fnuz) - * `c` is M x N x L in fp16 - - Matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. - The ranking criteria is the geometric mean of the benchmark results. - For the grand price, your kernel will be evaluated against the speed of light analysis - and the solution closest to the speed of light will be awarded the grand price. - ``` - The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock: - M N K L time[us] - 128 7168 16384 1 8.994 - 128 4096 7168 1 2.354 - 128 7168 2048 1 1.333 - ``` -config: - main: "eval.py" - -templates: - Python: "template.py" - -tests: - - {"m": 128, "n": 256, "k": 256, "l": 1, "seed": 1111} - - {"m": 128, "n": 1536, "k": 7168, "l": 1, "seed": 1111} - - {"m": 128, "n": 3072, "k": 1536, "l": 1, "seed": 1111} - - {"m": 256, "n": 7168, "k": 256, "l": 1, "seed": 1111} - - {"m": 256, "n": 7168, "k": 2048, "l": 1, "seed": 1111} - - {"m": 2304, "n": 4608, "k": 7168, "l": 1, "seed": 1111} - - {"m": 384, "n": 7168, "k": 2304, "l": 1, "seed": 1111} - - {"m": 512, "n": 512, "k": 7168, "l": 1, "seed": 1111} - - {"m": 512, "n": 4096, "k": 512, "l": 1, "seed": 1111} - - {"m": 512, "n": 1536, "k": 7168, "l": 1, "seed": 1111} - -benchmarks: - - {"m": 128, "n": 7168, "k": 16384, "l": 1, "seed": 1111} - - {"m": 128, "n": 4096, "k": 7168, "l": 1, "seed": 1111} - - {"m": 128, "n": 7168, "k": 2048, "l": 1, "seed": 1111} - -ranking_by: "geom" diff --git a/problems/nvidia/nvfp4_gemm/template.py b/problems/nvidia/nvfp4_gemm/template.py deleted file mode 100644 index 3855d694..00000000 --- a/problems/nvidia/nvfp4_gemm/template.py +++ /dev/null @@ -1,25 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference implementation of block-scale fp4 gemm - Args: - data: Tuple that expands to: - a: torch.Tensor[float4e2m1fn] of shape [m, k, l], - b: torch.Tensor[float4e2m1fn] of shape [n, k, l], - sfa: torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l], - sfb: torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l], - sfa_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l], - sfb_permuted: torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l], - c: torch.Tensor[float16] of shape [m, n, l] - Returns: - Tensor containing output in float16 - c: torch.Tensor[float16] of shape [m, n, l] - """ - # c: [m, n, l] is pre-allocated memory to avoid timing allocation overhead. - a, b, sfa, sfb, sfa_permuted, sfb_permuted, c = data - - # Your implementation here - - return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemm/utils.py b/problems/nvidia/nvfp4_gemm/utils.py deleted file mode 100644 index d9b3a69e..00000000 --- a/problems/nvidia/nvfp4_gemm/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, - expected: torch.Tensor, - rtol=1e-05, - atol=1e-08, - max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") - return mismatch_details - - return [] - - -def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) - - if len(reasons) > 0: - return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - - return True, '' - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy diff --git a/problems/nvidia/nvfp4_group_gemm/eval.py b/problems/nvidia/nvfp4_gemv/eval.py similarity index 77% rename from problems/nvidia/nvfp4_group_gemm/eval.py rename to problems/nvidia/nvfp4_gemv/eval.py index 2f00f53d..ca325354 100644 --- a/problems/nvidia/nvfp4_group_gemm/eval.py +++ b/problems/nvidia/nvfp4_gemv/eval.py @@ -67,22 +67,13 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: tests = [] lines = content.splitlines() - # Match key: value pairs where value can be: - # - a list like [1, 2, 3] - # - a tuple like (1, 2, 3) - # - an integer - # - an alphabetic string - match = r"\s*([a-zA-Z_]+)\s*:\s*(\[[^\]]*\]|\([^)]*\)|[a-zA-Z_]+|[+-]?[0-9]+)\s*" + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" for line in lines: - if not line.strip(): - continue parts = line.split(";") case = {} for part in parts: - if not part.strip(): - continue - matched = re.fullmatch(match, part) - if not matched: + matched = re.match(match, part) + if not re.fullmatch(match, part): print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) exit(113) key = matched[1] @@ -90,16 +81,7 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: try: val = int(val) except ValueError: - # Try parsing as tuple/list - if (val.startswith('(') and val.endswith(')')) or (val.startswith('[') and val.endswith(']')): - try: - inner = val[1:-1].strip() - if inner: - val = tuple(int(x.strip()) for x in inner.split(',')) - else: - val = tuple() - except ValueError: - pass + pass case[key] = val tests.append(TestCase(spec=line, args=case)) @@ -193,6 +175,16 @@ def run_testing( @param tests: A list of TestCase objects representing the test cases to be executed. @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. """ + # Step 1: Compile kernel once before running tests + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Run all tests with compiled kernel passed = True logger.log("test-count", len(tests)) for idx, test in enumerate(tests): @@ -215,18 +207,46 @@ def run_testing( return 112 +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + try: + # Trigger compilation (will be cached) + compile_kernel() + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + def _run_single_benchmark( test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float ) -> Stats | Any: """ Runs one benchmark. Do not call directly. """ - from submission import custom_kernel + from submission import custom_kernel, compile_kernel durations = [] # generate input data once data = generate_input(**test.args) check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + try: + compile_kernel() + torch.cuda.synchronize() + except OpError as E: + return f"Compilation failed: {E}" + except Exception as E: + return f"Compilation failed: {E}" + # first, one obligatory correctness check try: output = custom_kernel(_clone_data(data)) @@ -237,7 +257,7 @@ def _run_single_benchmark( return message # now, do multiple timing runs without further correctness testing - # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # there is an upper bound of 200 runs, and a lower bound of 3 runs; # otherwise, we repeat until we either measure at least 10 full seconds, # or the relative error of the mean is below 1%. @@ -315,14 +335,24 @@ def run_benchmarking( @param tests: A list of TestCase objects representing the test cases to be benchmarked. @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. """ - # warm up - run_single_benchmark(pool, tests[0], False, 100, 10e7) + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + # Step 3: Run benchmarks (compilation time excluded) passed = True logger.log("benchmark-count", len(tests)) for idx, test in enumerate(tests): logger.log(f"benchmark.{idx}.spec", test.spec) - result = run_single_benchmark(pool, test, False, 100, 10e9) + result = run_single_benchmark(pool, test, False, 200, 10e9) if isinstance(result, Stats): for field in dataclasses.fields(Stats): logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) @@ -382,8 +412,34 @@ def main(): seed = int(seed) if seed else None set_seed(seed or 42) - # Parse test cases from temp file (text format from kernelbot) - tests = get_test_cases(sys.argv[2], seed) + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) with PopcornOutput(int(fd)) as logger: import multiprocessing @@ -396,12 +452,23 @@ def main(): return run_benchmarking(logger, pool, tests) if mode == "leaderboard": - # warmup - run_single_benchmark(pool, tests[0], False, 100, 1e7) + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) logger.log("benchmark-count", len(tests)) passed = True for i in range(len(tests)): - result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) logger.log(f"benchmark.{i}.spec", tests[i].spec) if isinstance(result, Stats): for field in dataclasses.fields(Stats): @@ -426,6 +493,4 @@ def main(): if __name__ == "__main__": - print("main") - main() - print("main end") \ No newline at end of file + sys.exit(main()) \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/reference.py b/problems/nvidia/nvfp4_gemv/reference.py index 8aeb41cf..b277d28b 100644 --- a/problems/nvidia/nvfp4_gemv/reference.py +++ b/problems/nvidia/nvfp4_gemv/reference.py @@ -160,10 +160,7 @@ def create_scale_factor_tensors(l, mn, sf_k): sfa_ref_cpu, sfa_permuted = create_scale_factor_tensors(l, m, sf_k) sfb_ref_cpu, sfb_permuted = create_scale_factor_tensors(l, n_padded_128, sf_k) - sfa_ref = sfa_ref_cpu.to("cuda") - sfb_ref = sfb_ref_cpu.to("cuda") - - return (a_ref, b_ref, sfa_ref, sfb_ref, sfa_permuted, sfb_permuted, c_ref) + return (a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, sfa_permuted, sfb_permuted, c_ref) check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_gemv/submission.py b/problems/nvidia/nvfp4_gemv/submission.py index 2db65cdd..56b815fc 100644 --- a/problems/nvidia/nvfp4_gemv/submission.py +++ b/problems/nvidia/nvfp4_gemv/submission.py @@ -1,54 +1,263 @@ import torch from task import input_t, output_t -# Kernel configuration parameters -sf_vec_size = 16 +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_ptr +import cutlass.utils.blockscaled_layout as blockscaled_utils +# Kernel configuration parameters +mma_tiler_mnk = (128, 1, 64) # Tile sizes for M, N, K dimensions +ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B +sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors +c_dtype = cutlass.Float16 # FP16 output type +sf_vec_size = 16 # Scale factor block size (16 elements share one scale) +threads_per_cta = 128 # Number of threads per CUDA thread block # Helper function for ceiling division def ceil_div(a, b): return (a + b - 1) // b -# Helper function to convert scale factor tensor to blocked format -def to_blocked(input_matrix): - rows, cols = input_matrix.shape +# The CuTe reference implementation for NVFP4 block-scaled GEMV +@cute.kernel +def kernel( + mA_mkl: cute.Tensor, + mB_nkl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, +): + # Get CUDA block and thread indices + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L]) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # Extract the local tile for scale factor tensor for A (same shape as gA_mkl) + # Here, block_M = (32, 4); block_K = (16, 4) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) + ) + # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L]) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # Extract the local tile for scale factor tensor for B (same shape as gB_nkl) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) + ) + # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L]) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) + ) + + # Select output element corresponding to this thread and block indices + tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] + tCgC = cute.make_tensor(tCgC.iterator, 1) + res = cute.zeros_like(tCgC, cutlass.Float32) + + # Get the number of k tiles (depth dimension) for the reduction loop + k_tile_cnt = gA_mkl.layout[3].shape + for k_tile in range(k_tile_cnt): + tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] + tBgB = gB_nkl[0, None, bidy, k_tile, bidz] + tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] + tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz] + + tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float32) + tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float32) + tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32) + tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32) + + # Load NVFP4 or FP8 values from global memory + a_val_nvfp4 = tAgA.load() + b_val_nvfp4 = tBgB.load() + sfa_val_fp8 = tAgSFA.load() + sfb_val_fp8 = tBgSFB.load() + + # Convert loaded values to float32 for computation (FFMA) + a_val = a_val_nvfp4.to(cutlass.Float32) + b_val = b_val_nvfp4.to(cutlass.Float32) + sfa_val = sfa_val_fp8.to(cutlass.Float32) + sfb_val = sfb_val_fp8.to(cutlass.Float32) + + # Store the converted values to RMEM CuTe tensors + tArA.store(a_val) + tBrB.store(b_val) + tArSFA.store(sfa_val) + tBrSFB.store(sfb_val) + + # Iterate over SF vector tiles and compute the scale&matmul accumulation + for i in cutlass.range_constexpr(mma_tiler_mnk[2]): + res += tArA[i] * tArSFA[i] * tBrB[i] * tBrSFB[i] + + # Store the final float16 result back to global memory + tCgC.store(res.to(cutlass.Float16)) + return + - # Please ensure rows and cols are multiples of 128 and 4 respectively - n_row_blocks = ceil_div(rows, 128) - n_col_blocks = ceil_div(cols, 4) +@cute.jit +def my_kernel( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + problem_size: tuple, +): + """ + Host-side JIT function to prepare tensors and launch GPU kernel. + """ + m, _, k, l = problem_size + # Create CuTe Tensor via pointer and problem size. + a_tensor = cute.make_tensor( + a_ptr, + cute.make_layout( + (m, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), + ), + ) + # We use n=128 to create the torch tensor to do fp4 computation via torch._scaled_mm + # then copy torch tensor to cute tensor for cute customize kernel computation + # therefore we need to ensure b_tensor has the right stride with this 128 padded size on n. + n_padded_128 = 128 + b_tensor = cute.make_tensor( + b_ptr, + cute.make_layout( + (n_padded_128, cute.assume(k, 32), l), + stride=(cute.assume(k, 32), 1, cute.assume(n_padded_128 * k, 32)), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, cute.make_layout((cute.assume(m, 32), 1, l), stride=(1, 1, m)) + ) + # Convert scale factor tensors to MMA layout + # The layout matches Tensor Core requirements: (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + # Compute grid dimensions + # Grid is (M_blocks, 1, L) where: + # - M_blocks = ceil(M / 128) to cover all output rows + # - L = batch size + grid = ( + cute.ceil_div(c_tensor.shape[0], 128), + 1, + c_tensor.shape[2], + ) + + # Launch the CUDA kernel + kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( + grid=grid, + block=[threads_per_cta, 1, 1], + cluster=(1, 1, 1), + ) + return + + +# Global cache for compiled kernel +_compiled_kernel_cache = None +# This function is used to compile the kernel once and cache it and then allow users to +# run the kernel multiple times to get more accurate timing results. +def compile_kernel(): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + - padded = input_matrix - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32 + ) - return rearranged.flatten() + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0)) + + return _compiled_kernel_cache -def custom_kernel( - data: input_t, -) -> output_t: +def custom_kernel(data: input_t) -> output_t: """ - PyTorch reference implementation of NVFP4 block-scaled GEMV. + Execute the block-scaled GEMV kernel. + + This is the main entry point called by the evaluation framework. + It converts PyTorch tensors to CuTe tensors, launches the kernel, + and returns the result. + + Args: + data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors + a: [m, k, l] - Input matrix in float4e2m1fn + b: [1, k, l] - Input vector in float4e2m1fn + sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn + sfb_cpu: [1, k, l] - Scale factors in float8_e4m3fn + sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn + sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn + c: [m, 1, l] - Output vector in float16 + + Returns: + Output tensor c with computed GEMV results """ - a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data - - # Get dimensions from MxNxL layout - _, _, l = c_ref.shape - - # Call torch._scaled_mm to compute the GEMV result - for l_idx in range(l): - # Convert the scale factor tensor to blocked format - scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]) - scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]) - # (m, k) @ (n, k).T -> (m, n) - res = torch._scaled_mm( - a_ref[:, :, l_idx], - b_ref[:, :, l_idx].transpose(0, 1), - scale_a.cuda(), - scale_b.cuda(), - bias=None, - out_dtype=torch.float16, - ) - c_ref[:, 0, l_idx] = res[:, 0] - return c_ref + a, b, _, _, sfa_permuted, sfb_permuted, c = data + + # Ensure kernel is compiled (will use cached version if available) + # To avoid the compilation overhead, we compile the kernel once and cache it. + compiled_func = compile_kernel() + + # Get dimensions from MxKxL layout + m, k, l = a.shape + # Torch use e2m1_x2 data type, thus k is halved + k = k * 2 + # GEMV N dimension is always 1 + n = 1 + + # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer + a_ptr = make_ptr( + ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + b_ptr = make_ptr( + ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + c_ptr = make_ptr( + c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + sfa_ptr = make_ptr( + sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + + # Execute the compiled kernel + compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) + + return c \ No newline at end of file diff --git a/problems/nvidia/nvfp4_gemv/task.yml b/problems/nvidia/nvfp4_gemv/task.yml index 173914bd..756fe80d 100644 --- a/problems/nvidia/nvfp4_gemv/task.yml +++ b/problems/nvidia/nvfp4_gemv/task.yml @@ -37,7 +37,6 @@ config: templates: Python: "template.py" - CuteDSL: "template_cute.py" tests: - {"m": 128, "k": 256, "l": 1, "seed": 1111} diff --git a/problems/nvidia/nvfp4_gemv/template_cute.py b/problems/nvidia/nvfp4_gemv/template_cute.py deleted file mode 100644 index 1eebc8f7..00000000 --- a/problems/nvidia/nvfp4_gemv/template_cute.py +++ /dev/null @@ -1,247 +0,0 @@ -import torch -from task import input_t, output_t - -import cutlass -import cutlass.cute as cute -from cutlass.cute.runtime import make_ptr -import cutlass.utils.blockscaled_layout as blockscaled_utils - -# Kernel configuration parameters -mma_tiler_mnk = (128, 1, 64) # Tile sizes for M, N, K dimensions -ab_dtype = cutlass.Float4E2M1FN # FP4 data type for A and B -sf_dtype = cutlass.Float8E4M3FN # FP8 data type for scale factors -c_dtype = cutlass.Float16 # FP16 output type -sf_vec_size = 16 # Scale factor block size (16 elements share one scale) -threads_per_cta = 128 # Number of threads per CUDA thread block - - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# The CuTe reference implementation for NVFP4 block-scaled GEMV -@cute.kernel -def kernel( - mA_mkl: cute.Tensor, - mB_nkl: cute.Tensor, - mSFA_mkl: cute.Tensor, - mSFB_nkl: cute.Tensor, - mC_mnl: cute.Tensor, -): - # Get CUDA block and thread indices - bidx, bidy, bidz = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - - # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L]) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # Extract the local tile for scale factor tensor for A (same shape as gA_mkl) - # Here, block_M = (32, 4); block_K = (16, 4) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L]) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # Extract the local tile for scale factor tensor for B (same shape as gB_nkl) - gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L]) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None) - ) - - # Select output element corresponding to this thread and block indices - tCgC = gC_mnl[tidx, None, bidx, bidy, bidz] - tCgC = cute.make_tensor(tCgC.iterator, 1) - res = cute.zeros_like(tCgC, cutlass.Float32) - - # Get the number of k tiles (depth dimension) for the reduction loop - k_tile_cnt = gA_mkl.layout[3].shape - for k_tile in range(k_tile_cnt): - tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz] - tBgB = gB_nkl[0, None, bidy, k_tile, bidz] - tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz] - tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz] - - tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float32) - tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float32) - tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32) - tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32) - - # Load NVFP4 or FP8 values from global memory - a_val_nvfp4 = tAgA.load() - b_val_nvfp4 = tBgB.load() - sfa_val_fp8 = tAgSFA.load() - sfb_val_fp8 = tBgSFB.load() - - # Convert loaded values to float32 for computation (FFMA) - a_val = a_val_nvfp4.to(cutlass.Float32) - b_val = b_val_nvfp4.to(cutlass.Float32) - sfa_val = sfa_val_fp8.to(cutlass.Float32) - sfb_val = sfb_val_fp8.to(cutlass.Float32) - - # Store the converted values to RMEM CuTe tensors - tArA.store(a_val) - tBrB.store(b_val) - tArSFA.store(sfa_val) - tBrSFB.store(sfb_val) - - # Iterate over SF vector tiles and compute the scale&matmul accumulation - for i in cutlass.range_constexpr(mma_tiler_mnk[2]): - res += tArA[i] * tArSFA[i] * tBrB[i] * tBrSFB[i] - - # Store the final float16 result back to global memory - tCgC.store(res.to(cutlass.Float16)) - return - - -@cute.jit -def my_kernel( - a_ptr: cute.Pointer, - b_ptr: cute.Pointer, - sfa_ptr: cute.Pointer, - sfb_ptr: cute.Pointer, - c_ptr: cute.Pointer, - problem_size: tuple, -): - """ - Host-side JIT function to prepare tensors and launch GPU kernel. - """ - m, _, k, l = problem_size - # Create CuTe Tensor via pointer and problem size. - a_tensor = cute.make_tensor( - a_ptr, - cute.make_layout( - (m, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32)), - ), - ) - # We use n=128 to create the torch tensor to do fp4 computation via torch._scaled_mm - # then copy torch tensor to cute tensor for cute customize kernel computation - # therefore we need to ensure b_tensor has the right stride with this 128 padded size on n. - n_padded_128 = 128 - b_tensor = cute.make_tensor( - b_ptr, - cute.make_layout( - (n_padded_128, cute.assume(k, 32), l), - stride=(cute.assume(k, 32), 1, cute.assume(n_padded_128 * k, 32)), - ), - ) - c_tensor = cute.make_tensor( - c_ptr, cute.make_layout((cute.assume(m, 32), 1, l), stride=(1, 1, m)) - ) - # Convert scale factor tensors to MMA layout - # The layout matches Tensor Core requirements: (((32, 4), REST_M), ((SF_K, 4), REST_K), (1, REST_L)) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, sf_vec_size) - sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) - - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, sf_vec_size) - sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) - - # Compute grid dimensions - # Grid is (M_blocks, 1, L) where: - # - M_blocks = ceil(M / 128) to cover all output rows - # - L = batch size - grid = ( - cute.ceil_div(c_tensor.shape[0], 128), - 1, - c_tensor.shape[2], - ) - - # Launch the CUDA kernel - kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch( - grid=grid, - block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1), - ) - return - - -# Global cache for compiled kernel -_compiled_kernel_cache = None - - -# This function is used to compile the kernel once and cache it and then allow users to -# run the kernel multiple times to get more accurate timing results. -def compile_kernel(): - """ - Compile the kernel once and cache it. - This should be called before any timing measurements. - - Returns: - The compiled kernel function - """ - global _compiled_kernel_cache - - if _compiled_kernel_cache is not None: - return _compiled_kernel_cache - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) - b_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) - c_ptr = make_ptr(c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) - sfa_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32) - sfb_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32) - - # Compile the kernel - _compiled_kernel_cache = cute.compile( - my_kernel, a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (0, 0, 0, 0) - ) - - return _compiled_kernel_cache - - -def custom_kernel(data: input_t) -> output_t: - """ - Execute the block-scaled GEMV kernel. - - This is the main entry point called by the evaluation framework. - It converts PyTorch tensors to CuTe tensors, launches the kernel, - and returns the result. - - Args: - data: Tuple of (a, b, sfa_cpu, sfb_cpu, c) PyTorch tensors - a: [m, k, l] - Input matrix in float4e2m1fn - b: [1, k, l] - Input vector in float4e2m1fn - sfa_cpu: [m, k, l] - Scale factors in float8_e4m3fn - sfb_cpu: [1, k, l] - Scale factors in float8_e4m3fn - sfa_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in float8_e4m3fn - sfb_permuted: [32, 4, rest_n, 4, rest_k, l] - Scale factors in float8_e4m3fn - c: [m, 1, l] - Output vector in float16 - - Returns: - Output tensor c with computed GEMV results - """ - a, b, _, _, sfa_permuted, sfb_permuted, c = data - - # Ensure kernel is compiled (will use cached version if available) - # To avoid the compilation overhead, we compile the kernel once and cache it. - compiled_func = compile_kernel() - - # Get dimensions from MxKxL layout - m, k, l = a.shape - # Torch use e2m1_x2 data type, thus k is halved - k = k * 2 - # GEMV N dimension is always 1 - n = 1 - - # Create CuTe pointers for A/B/C/SFA/SFB via torch tensor data pointer - a_ptr = make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - b_ptr = make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - c_ptr = make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) - sfa_ptr = make_ptr( - sf_dtype, sfa_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - sfb_ptr = make_ptr( - sf_dtype, sfb_permuted.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) - - # Execute the compiled kernel - compiled_func(a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l)) - - return c diff --git a/problems/helion/utils.py b/problems/nvidia/nvfp4_gemv/utils.py similarity index 100% rename from problems/helion/utils.py rename to problems/nvidia/nvfp4_gemv/utils.py diff --git a/problems/nvidia/nvfp4_group_gemm/reference.py b/problems/nvidia/nvfp4_group_gemm/reference.py deleted file mode 100644 index 7ce6bc09..00000000 --- a/problems/nvidia/nvfp4_group_gemm/reference.py +++ /dev/null @@ -1,204 +0,0 @@ -import torch -from task import input_t, output_t -from utils import make_match_reference - -# Scaling factor vector size -sf_vec_size = 16 - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# Helper function to convert scale factor tensor to blocked format -def to_blocked(input_matrix): - rows, cols = input_matrix.shape - - # Please ensure rows and cols are multiples of 128 and 4 respectively - n_row_blocks = ceil_div(rows, 128) - n_col_blocks = ceil_div(cols, 4) - padded_rows = n_row_blocks * 128 - padded_cols = n_col_blocks * 4 - - # Pad the input matrix if necessary - if padded_rows != rows or padded_cols != cols: - padded = torch.nn.functional.pad( - input_matrix, - (0, padded_cols - cols, 0, padded_rows - rows), - mode="constant", - value=0, - ) - else: - padded = input_matrix - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) - rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - - return rearranged.flatten() - - -def ref_kernel( - data: input_t, -) -> output_t: - """ - PyTorch reference implementation of NVFP4 block-scaled group GEMM. - """ - abc_tensors, sfasfb_tensors, _, problem_sizes = data - - result_tensors = [] - for i, ( - (a_ref, b_ref, c_ref), - (sfa_ref, sfb_ref), - (m, n, k, l), - ) in enumerate( - zip( - abc_tensors, - sfasfb_tensors, - problem_sizes, - ) - ): - for l_idx in range(l): - # Convert the scale factor tensor to blocked format - scale_a = to_blocked(sfa_ref[:, :, l_idx]) - scale_b = to_blocked(sfb_ref[:, :, l_idx]) - # (m, k) @ (n, k).T -> (m, n) - res = torch._scaled_mm( - a_ref[:, :, l_idx].view(torch.float4_e2m1fn_x2), - b_ref[:, :, l_idx].transpose(0, 1).view(torch.float4_e2m1fn_x2), - scale_a.cuda(), - scale_b.cuda(), - bias=None, - out_dtype=torch.float16, - ) - c_ref[:, :, l_idx] = res - result_tensors.append((c_ref)) - return result_tensors - - -# Helper function to prepare the scale factor tensors for both reference -# kernel and customize kernel. The customized data layout can be found in: -# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout -def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor): - sf_k = ceil_div(k, sf_vec_size) - atom_m = (32, 4) - atom_k = 4 - mma_shape = ( - l, # batch size - ceil_div(mn, atom_m[0] * atom_m[1]), - ceil_div(sf_k, atom_k), - atom_m[0], - atom_m[1], - atom_k, - ) - # Create the reordered scale factor tensor (32, 4, rest_m, 4, rest_k, l) on GPU. - mma_permute_order = (3, 4, 1, 5, 2, 0) - # Generate a random int8 tensor, then convert to float8_e4m3fn - rand_int_tensor = torch.randint(1, 3, mma_shape, dtype=torch.int8, device='cuda') - reordered_f8_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn) - # Permute according to mma_permute_order - reordered_f8_tensor = reordered_f8_tensor.permute(*mma_permute_order) - - # Move ref_f8_tensor to GPU if not already there - if ref_f8_tensor.device.type == 'cpu': - ref_f8_tensor = ref_f8_tensor.cuda() - - # GPU-side vectorized reordering (replaces slow CPU nested loops) - # Create index grids for all dimensions - i_idx = torch.arange(mn, device='cuda') - j_idx = torch.arange(sf_k, device='cuda') - b_idx = torch.arange(l, device='cuda') - - # Create meshgrid for all combinations of (i, j, b) - i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij') - - # Calculate target indices in vectorized manner - mm = i_grid // (atom_m[0] * atom_m[1]) - mm32 = i_grid % atom_m[0] - mm4 = (i_grid % 128) // atom_m[0] - kk = j_grid // atom_k - kk4 = j_grid % atom_k - - # Perform the reordering with advanced indexing (all on GPU) - reordered_f8_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_tensor[i_grid, j_grid, b_grid] - - return reordered_f8_tensor - - -def _create_fp4_tensors(l, mn, k): - # generate uint8 tensor, then convert to float4e2m1fn_x2 data type - # generate all bit patterns - ref_i8 = torch.randint(255, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda") - - # for each nibble, only keep the sign bit and 2 LSBs - # the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5] - ref_i8 = ref_i8 & 0b1011_1011 - return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2) - - -def generate_input( - m: tuple, - n: tuple, - k: tuple, - g: int, - seed: int, -): - """ - Generate input tensors for NVFP4 block-scaled group GEMM. - Each group can have different m, n, k, l. - - Args: - problem_sizes: List of tuples (m, n, k, l) for each problem - m: Number of rows in matrix A - n: Number of columns in matrix B - k: Number of columns in A and rows of B - l: Batch size, always is 1 - groups: Number of groups - seed: Random seed for reproducibility - - Returns: - Tuple of (list(tuple(a, b, c)), list(tuple(sfa, sfb)), list(tuple(sfa_reordered, sfb_reordered)), list(tuple(m, n, k, l))) where each group has its own a, b, c, sfa, sfb. - a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - b: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type - sfa: [m, k // 16, l] - Input scale factors in torch.float8e4m3fn data type - sfb: [n, k // 16, l] - Input scale factors in torch.float8e4m3fn data type - sfa_reordered: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - sfb_reordered: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type - c: [m, n, l] - Output matrix in torch.float16 data type - """ - torch.manual_seed(seed) - - abc_tensors = [] - sfasfb_tensors = [] - sfasfb_reordered_tensors = [] - problem_sizes = [] - l = 1 - # Generate a, b, c, sfa, sfb tensors for all groups - for group_idx in range(g): - mi = m[group_idx] - ni = n[group_idx] - ki = k[group_idx] - a_ref = _create_fp4_tensors(l, mi, ki) - b_ref = _create_fp4_tensors(l, ni, ki) - - c_ref = torch.randn((l, mi, ni), dtype=torch.float16, device="cuda").permute( - 1, 2, 0 - ) - - sf_k = ceil_div(ki, sf_vec_size) - sfa_ref_cpu = torch.randint( - 1, 3, (l, mi, sf_k), dtype=torch.int8 - ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) - sfb_ref_cpu = torch.randint( - 1, 3, (l, ni, sf_k), dtype=torch.int8 - ).to(dtype=torch.float8_e4m3fn).permute(1, 2, 0) - - sfa_reordered = create_reordered_scale_factor_tensor(l, mi, ki, sfa_ref_cpu) - sfb_reordered = create_reordered_scale_factor_tensor(l, ni, ki, sfb_ref_cpu) - - abc_tensors.append((a_ref, b_ref, c_ref)) - sfasfb_tensors.append((sfa_ref_cpu, sfb_ref_cpu)) - sfasfb_reordered_tensors.append((sfa_reordered, sfb_reordered)) - problem_sizes.append((mi, ni, ki, l)) - return (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) - - -check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03) diff --git a/problems/nvidia/nvfp4_group_gemm/submission.py b/problems/nvidia/nvfp4_group_gemm/submission.py deleted file mode 100644 index 439a0ea2..00000000 --- a/problems/nvidia/nvfp4_group_gemm/submission.py +++ /dev/null @@ -1,1059 +0,0 @@ -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -import cutlass.utils.blackwell_helpers as sm100_utils -import cutlass.utils.blockscaled_layout as blockscaled_utils -from cutlass.cute.runtime import make_ptr - -import functools -from typing import Tuple, List - -import torch -from task import input_t, output_t - -# Kernel configuration parameters -# Size of tma descriptor in bytes -bytes_per_tensormap = 128 -# Number of tensormaps: a, b, sfa, sfb -num_tensormaps = 4 -# Tile sizes for M, N, K dimensions -mma_tiler_mnk = (128, 128, 256) -# Shape of the K dimension for the MMA instruction -mma_inst_shape_k = 64 -# FP4 data type for A and B -ab_dtype = cutlass.Float4E2M1FN -# FP8 data type for scale factors -sf_dtype = cutlass.Float8E4M3FN -# FP16 output type -c_dtype = cutlass.Float16 -# Scale factor block size (16 elements share one scale) -sf_vec_size = 16 -# Number of threads per CUDA thread block -threads_per_cta = 128 -# Stage numbers of shared memory and tmem -num_acc_stage = 1 -num_ab_stage = 1 -# Total number of columns in tmem -num_tmem_alloc_cols = 512 - - -# Helper function for ceiling division -def ceil_div(a, b): - return (a + b - 1) // b - - -# The CuTe reference implementation for NVFP4 block-scaled GEMM -@cute.kernel -def kernel( - tiled_mma: cute.TiledMma, - tma_atom_a: cute.CopyAtom, - mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, - tma_atom_sfa: cute.CopyAtom, - mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, - tensor_of_abc_ptrs: cute.Tensor, - tensor_of_sfasfb_ptrs: cute.Tensor, - tensormaps: cute.Tensor, - tensor_of_problem_sizes: cute.Tensor, - a_smem_layout_staged: cute.ComposedLayout, - b_smem_layout_staged: cute.ComposedLayout, - sfa_smem_layout_staged: cute.Layout, - sfb_smem_layout_staged: cute.Layout, - cta_mn_list: List[Tuple[int, int]], - num_tma_load_bytes: cutlass.Constexpr[int], -): - """ - GPU device kernel performing the Group GEMM computation. - """ - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - tidx, _, _ = cute.arch.thread_idx() - - # - # Delinearize bidz to coord_x, coord_y and group_idx for each CTA - # - bidx, bidy, bidz = cute.arch.block_idx() - group_idx = 0 - find = False - coord_x = 0 - coord_y = 0 - cta_rest = bidz - for _, (cta_m, cta_n) in enumerate(cta_mn_list): - if cta_rest >= (cta_m * cta_n): - group_idx += 1 - cta_rest -= cta_m * cta_n - else: - if not find: - coord_y = cta_rest // cta_m - coord_x = cta_rest % cta_m - cta_rest -= cta_m * cta_n - find = True - - # - # Construct C Tensor for each CTA - # - mC_mnl_iter = cute.make_ptr( - c_dtype, tensor_of_abc_ptrs[group_idx, 2], cute.AddressSpace.gmem - ).align(32) - m = tensor_of_problem_sizes[group_idx, 0] - n = tensor_of_problem_sizes[group_idx, 1] - k = tensor_of_problem_sizes[group_idx, 2] - l = tensor_of_problem_sizes[group_idx, 3] - - mC_mnl_layout = cute.make_layout( - (m, n, l), - stride=(cute.assume(n, 32), 1, cute.assume(m * n, 32),)) - mC_mnl = cute.make_tensor(mC_mnl_iter, mC_mnl_layout) - # Local partition for global C Tensor - # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (coord_x, coord_y, 0) - ) - - # - # Define shared storage for kernel - # - size_tensormap_in_i64 = ( - num_tensormaps * bytes_per_tensormap // 8 - ) - @cute.struct - class SharedStorage: - tensormap_buffer: cute.struct.MemRange[ - cutlass.Int64, size_tensormap_in_i64 - ] - ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_ab_stage * 2] - acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, num_acc_stage * 2] - tmem_holding_buf: cutlass.Int32 - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() - tensormap_a_smem_ptr = tensormap_smem_ptr - tensormap_b_smem_ptr = ( - tensormap_a_smem_ptr - + bytes_per_tensormap // 8 - ) - tensormap_sfa_smem_ptr = ( - tensormap_b_smem_ptr - + bytes_per_tensormap // 8 - ) - tensormap_sfb_smem_ptr = ( - tensormap_sfa_smem_ptr - + bytes_per_tensormap // 8 - ) - # Setup smem tensor for A, B, SFA, SFB - # (MMA, MMA_M, MMA_K, STAGE) - sA = smem.allocate_tensor( - element_type=ab_dtype, - layout=a_smem_layout_staged.outer, - byte_alignment=128, - swizzle=a_smem_layout_staged.inner, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sB = smem.allocate_tensor( - element_type=ab_dtype, - layout=b_smem_layout_staged.outer, - byte_alignment=128, - swizzle=b_smem_layout_staged.inner, - ) - # (MMA, MMA_M, MMA_K, STAGE) - sSFA = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfa_smem_layout_staged, - byte_alignment=128, - ) - # (MMA, MMA_N, MMA_K, STAGE) - sSFB = smem.allocate_tensor( - element_type=sf_dtype, - layout=sfb_smem_layout_staged, - byte_alignment=128, - ) - - # Initialize mainloop ab_pipeline, acc_pipeline and their states - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.ab_mbar_ptr.data_ptr(), - num_stages=num_ab_stage, - producer_group=ab_pipeline_producer_group, - consumer_group=ab_pipeline_consumer_group, - tx_count=num_tma_load_bytes, - ).make_participants() - acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( - barrier_storage=storage.acc_mbar_ptr.data_ptr(), - num_stages=num_acc_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), - consumer_group=pipeline.CooperativeGroup( - pipeline.Agent.Thread, - threads_per_cta, - ), - ).make_participants() - - # - # Local_tile partition global tensors - # - # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # (bM, bK, RestM, RestK, RestL) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None) - ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None) - ) - # - # Partition global tensor for TiledMMA_A/B/C - # - thr_mma = tiled_mma.get_slice(tidx) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) - tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma.partition_B(gSFB_nkl) - # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) - tCgC = thr_mma.partition_C(gC_mnl) - - # Update tma descriptor with the correct shapes and strides - tensormap_manager = utils.TensorMapManager( - utils.TensorMapUpdateMode.SMEM, - 128, - ) - tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( - tensormaps[(bidz, 0, None)].iterator - ) - tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( - tensormaps[(bidz, 1, None)].iterator - ) - tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( - tensormaps[(bidz, 2, None)].iterator - ) - tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( - tensormaps[(bidz, 3, None)].iterator - ) - - mA_mkl_iter = cute.make_ptr( - ab_dtype, tensor_of_abc_ptrs[group_idx, 0], cute.AddressSpace.gmem - ).align(32) - mB_nkl_iter = cute.make_ptr( - ab_dtype, tensor_of_abc_ptrs[group_idx, 1], cute.AddressSpace.gmem - ).align(32) - sfa_mkl_iter = cute.make_ptr( - sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 0], cute.AddressSpace.gmem - ).align(32) - sfb_nkl_iter = cute.make_ptr( - sf_dtype, tensor_of_sfasfb_ptrs[group_idx, 1], cute.AddressSpace.gmem - ).align(32) - mA_mkl_layout = cute.make_layout( - (m, k, l), stride=(cute.assume(k, 32), 1, cute.assume(m * k, 32),)) - mB_nkl_layout = cute.make_layout( - (n, k, l), stride=(cute.assume(k, 32), 1, cute.assume(n * k, 32),)) - - # SFA, SFB follows specialized layout defined in the following link: - # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout - atom_shape = ((32, 4), (sf_vec_size, 4)) - atom_stride = ((16, 4), (0, 1)) - sfa_layout = cute.tile_to_shape( - cute.make_layout(atom_shape, stride=atom_stride), - mA_mkl_layout.shape, - (2, 1, 3), - ) - sfb_layout = cute.tile_to_shape( - cute.make_layout(atom_shape, stride=atom_stride), - mB_nkl_layout.shape, - (2, 1, 3), - ) - real_tensor_a = cute.make_tensor(mA_mkl_iter, mA_mkl_layout) - real_tensor_b = cute.make_tensor(mB_nkl_iter, mB_nkl_layout) - real_tensor_sfa = cute.make_tensor(sfa_mkl_iter, sfa_layout) - real_tensor_sfb = cute.make_tensor(sfb_nkl_iter, sfb_layout) - - # Let warp 0 initialize tensormap - if warp_idx == 0: - tensormap_manager.init_tensormap_from_atom( - tma_atom_a, tensormap_a_smem_ptr, 0 - ) - tensormap_manager.init_tensormap_from_atom( - tma_atom_b, tensormap_b_smem_ptr, 0 - ) - tensormap_manager.init_tensormap_from_atom( - tma_atom_sfa, tensormap_sfa_smem_ptr, 0 - ) - tensormap_manager.init_tensormap_from_atom( - tma_atom_sfb, tensormap_sfb_smem_ptr, 0 - ) - tensormap_manager.update_tensormap( - ( - real_tensor_a, - real_tensor_b, - real_tensor_sfa, - real_tensor_sfb, - ), - (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), - ( - tensormap_a_gmem_ptr, - tensormap_b_gmem_ptr, - tensormap_sfa_gmem_ptr, - tensormap_sfb_gmem_ptr, - ), - 0, # tma warp id - ( - tensormap_a_smem_ptr, - tensormap_b_smem_ptr, - tensormap_sfa_smem_ptr, - tensormap_sfb_smem_ptr, - ), - ) - - tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) - tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) - tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) - tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) - - cute.arch.barrier() - - # - # Partition global/shared tensor for TMA load A/B/SFA/SFB - # - # TMA Partition_S/D for A - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsA, tAgA = cpasync.tma_partition( - tma_atom_a, - 0, - cute.make_layout(1), - cute.group_modes(sA, 0, 3), - cute.group_modes(tCgA, 0, 3), - ) - # TMA Partition_S/D for B - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, - 0, - cute.make_layout(1), - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), - ) - # TMA Partition_S/D for SFA - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestM, RestK, RestL) - tAsSFA, tAgSFA = cpasync.tma_partition( - tma_atom_sfa, - 0, - cute.make_layout(1), - cute.group_modes(sSFA, 0, 3), - cute.group_modes(tCgSFA, 0, 3), - ) - tAsSFA = cute.filter_zeros(tAsSFA) - tAgSFA = cute.filter_zeros(tAgSFA) - # TMA Partition_S/D for SFB - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cpasync.tma_partition( - tma_atom_sfb, - 0, - cute.make_layout(1), - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) - - # - # Partition shared/tensor memory tensor for TiledMMA_A/B/C - # - # (MMA, MMA_M, MMA_K, STAGE) - tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) - tCrB = tiled_mma.make_fragment_B(sB) - # (MMA, MMA_M, MMA_N) - acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) - # (MMA, MMA_M, MMA_N) - tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) - # - # Alloc tensor memory buffer - # - tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=1, - num_threads=threads_per_cta, - ) - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=tmem_alloc_barrier, - ) - tmem.allocate(num_tmem_alloc_cols) - tmem.wait_for_alloc() - acc_tmem_ptr = tmem.retrieve_ptr(cutlass.Float32) - tCtAcc = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) - - # - # Make SFA/SFB tmem tensor - # - # Get SFA tmem ptr - sfa_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc), - dtype=sf_dtype, - ) - # (MMA, MMA_M, MMA_K) - tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), - ) - tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) - # Get SFB tmem ptr - sfb_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), - dtype=sf_dtype, - ) - # (MMA, MMA_N, MMA_K) - tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), - ) - tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) - - # - # Partition for S2T copy of SFA/SFB - # - # Make S2T CopyAtom - copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), - sf_dtype, - ) - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFA_compact = cute.filter_zeros(sSFA) - tCtSFA_compact = cute.filter_zeros(tCtSFA) - tiled_copy_s2t_sfa = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFA_compact) - thr_copy_s2t_sfa = tiled_copy_s2t_sfa.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t_ = thr_copy_s2t_sfa.partition_S(tCsSFA_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFA_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfa, tCsSFA_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFA_compact_s2t = thr_copy_s2t_sfa.partition_D(tCtSFA_compact) - - # (MMA, MMA_MN, MMA_K, STAGE) - tCsSFB_compact = cute.filter_zeros(sSFB) - # (MMA, MMA_MN, MMA_K) - tCtSFB_compact = cute.filter_zeros(tCtSFB) - tiled_copy_s2t_sfb = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSFB_compact) - thr_copy_s2t_sfb = tiled_copy_s2t_sfb.get_slice(0) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB_compact_s2t_ = thr_copy_s2t_sfb.partition_S(tCsSFB_compact) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSFB_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t_sfb, tCsSFB_compact_s2t_ - ) - # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) - tCtSFB_compact_s2t = thr_copy_s2t_sfb.partition_D(tCtSFB_compact) - - # Number of K loops - k_tile_cnt = cute.ceil_div(real_tensor_a.shape[1], mma_tiler_mnk[2]) - - # - # Slice to per mma tile index - # - mma_tile_coord_mnl = (coord_x, coord_y, 0) - # ((atom_v, rest_v), RestK) - tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tAgSFA = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] - # ((atom_v, rest_v), RestK) - tBgSFB = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] - - # - # Main loop - # - if warp_idx == 0: - # Wait for accumulator buffer empty - acc_empty = acc_producer.acquire_and_advance() - # Set ACCUMULATE field to False for the first k_tile iteration - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - # Execute k_tile loop - for k_tile in range(k_tile_cnt): - # Wait for AB buffer empty - ab_empty = ab_producer.acquire_and_advance() - - # TMA load A/B/SFA/SFB to shared memory - cute.copy( - tma_atom_a, - tAgA[(None, k_tile)], - tAsA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - tma_desc_ptr=tensormap_manager.get_tensormap_ptr( - tensormap_a_gmem_ptr, - cute.AddressSpace.generic, - ), - ) - cute.copy( - tma_atom_b, - tBgB[(None, k_tile)], - tBsB[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - tma_desc_ptr=tensormap_manager.get_tensormap_ptr( - tensormap_b_gmem_ptr, - cute.AddressSpace.generic, - ), - ) - cute.copy( - tma_atom_sfa, - tAgSFA[(None, k_tile)], - tAsSFA[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - tma_desc_ptr=tensormap_manager.get_tensormap_ptr( - tensormap_sfa_gmem_ptr, - cute.AddressSpace.generic, - ), - ) - cute.copy( - tma_atom_sfb, - tBgSFB[(None, k_tile)], - tBsSFB[(None, ab_empty.index)], - tma_bar_ptr=ab_empty.barrier, - tma_desc_ptr=tensormap_manager.get_tensormap_ptr( - tensormap_sfb_gmem_ptr, - cute.AddressSpace.generic, - ), - ) - - # Wait for AB buffer full - ab_full = ab_consumer.wait_and_advance() - - # Copy SFA/SFB from shared memory to TMEM - s2t_stage_coord = (None, None, None, None, ab_full.index) - tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] - tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] - cute.copy( - tiled_copy_s2t_sfa, - tCsSFA_compact_s2t_staged, - tCtSFA_compact_s2t, - ) - cute.copy( - tiled_copy_s2t_sfb, - tCsSFB_compact_s2t_staged, - tCtSFB_compact_s2t, - ) - - # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = ( - None, - None, - kblock_idx, - ab_full.index, - ) - - # Set SFA/SFB tensor to tiled_mma - sf_kblock_coord = (None, None, kblock_idx) - tiled_mma.set( - tcgen05.Field.SFA, - tCtSFA[sf_kblock_coord].iterator, - ) - tiled_mma.set( - tcgen05.Field.SFB, - tCtSFB[sf_kblock_coord].iterator, - ) - - cute.gemm( - tiled_mma, - tCtAcc, - tCrA[kblock_coord], - tCrB[kblock_coord], - tCtAcc, - ) - # Enable accumulate on tCtAcc after first kblock - tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - - # Async arrive AB buffer empty - ab_full.release() - acc_empty.commit() - - # - # Epilogue - # Partition for epilogue - # - op = tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE) - copy_atom_t2r = cute.make_copy_atom(op, cutlass.Float32) - tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tCtAcc[None,0,0]) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - # (TmemCpy, NumTmemCpy) - tDtAcc = thr_copy_t2r.partition_S(tCtAcc[None,0,0]) - # (TmemCpy, NumTmemCpy) - tDgC = thr_copy_t2r.partition_D(tCgC[None,0,0]) - - # (TmemCpy, NumTmemCpy) - tDrAcc = cute.make_rmem_tensor(tDgC.shape, cutlass.Float32) - # (TmemCpy, NumTmemCpy) - tDrC = cute.make_rmem_tensor(tDgC.shape, c_dtype) - - # Release TMEM allocation lock - tmem.relinquish_alloc_permit() - # Wait for accumulator buffer full - acc_full = acc_consumer.wait_and_advance() - - # Copy accumulator to register - cute.copy(tiled_copy_t2r, tDtAcc, tDrAcc) - acc_vec = tDrAcc.load() - tDrC.store(acc_vec.to(c_dtype)) - - # STG Atom, just to ensure functionality - # For performance optimization, better to use Tma store operation to - # reduce address calculation and predicate calulation instructions - simt_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), c_dtype, num_bits_per_copy=16 - ) - thread_layout = cute.make_layout( - (1, threads_per_cta), stride=(threads_per_cta, 1)) - value_layout = cute.make_layout((1, 1)) - tiled_copy_r2g = cute.make_tiled_copy_tv( - simt_atom, thread_layout, value_layout - ) - thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) - cC = cute.make_identity_tensor(gC_mnl.shape) - # ((atom_v, rest_v), NumGmemCpy) - tDcC = thr_copy_r2g.partition_D(cC) - - # ((atom_v, rest_v), NumGmemCpy) - tDpC = cute.make_rmem_tensor(tDrC.shape, cutlass.Boolean) - residue_m = mC_mnl.shape[0] - cutlass.Int32(coord_x) * mma_tiler_mnk[0] - residue_n = mC_mnl.shape[1] - cutlass.Int32(coord_y) * mma_tiler_mnk[1] - for i in range(cute.size(tDrC.shape)): - # Swap residue_m and residue_n to match the order of tDcC - tDpC[i] = cute.elem_less(tDcC[i], (residue_n, residue_m)) - cute.copy(simt_atom, cute.flatten(tDrC), cute.flatten(tDgC), pred=cute.flatten(tDpC)) - - acc_full.release() - # Deallocate TMEM - cute.arch.barrier() - tmem.free(acc_tmem_ptr) - pass - - -# Host-side JIT function to prepare tensors and launch GPU kernel. -@cute.jit -def my_kernel( - ptr_of_tensor_of_problem_sizes: cute.Pointer, - ptr_of_tensor_of_abc_ptrs: cute.Pointer, - ptr_of_tensor_of_sfasfb_ptrs: cute.Pointer, - ptr_of_tensor_of_tensormap: cute.Pointer, - total_num_clusters: cutlass.Int32, - problem_sizes: List[ - Tuple[int, int, int, int] - ], # Problem sizes for each group - num_groups: cutlass.Int32, -): - - tensor_of_abc_ptrs = cute.make_tensor( - ptr_of_tensor_of_abc_ptrs, cute.make_layout((num_groups, 3), stride=(3, 1)) - ) - tensor_of_sfasfb_ptrs = cute.make_tensor( - ptr_of_tensor_of_sfasfb_ptrs, cute.make_layout((num_groups, 2), stride=(2, 1)) - ) - tensor_of_problem_sizes = cute.make_tensor( - ptr_of_tensor_of_problem_sizes, cute.make_layout((num_groups, 4), stride=(4, 1)) - ) - tensor_of_tensormap = cute.make_tensor( - ptr_of_tensor_of_tensormap, cute.make_layout((total_num_clusters, 4, 16), stride=(64, 16, 1)) - ) - - # Use fake shape for initial Tma descriptor and atom setup - # The real Tma desc and atom will be updated during kernel execution. - min_a_shape = (cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(1)) - min_b_shape = (cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(64), cutlass.Int32(1)) - initial_a = cute.make_tensor( - cute.make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16,), - cute.make_layout( - (min_a_shape[0], cute.assume(min_a_shape[2], 32), min_a_shape[3]), - stride=( - cute.assume(min_a_shape[2], 32), - 1, - cute.assume(min_a_shape[0] * min_a_shape[2], 32), - ), - ), - ) - initial_b = cute.make_tensor( - cute.make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16,), - cute.make_layout( - (min_b_shape[1], cute.assume(min_b_shape[2], 32), min_b_shape[3]), - stride=( - cute.assume(min_b_shape[2], 32), - 1, - cute.assume(min_b_shape[1] * min_b_shape[2], 32), - ), - ), - ) - - # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout - # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - initial_a.shape, sf_vec_size - ) - # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - initial_b.shape, sf_vec_size - ) - # Create initial SFA and SFB tensors with fake shape and null pointer. - initial_sfa = cute.make_tensor( - cute.make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=16,), sfa_layout) - initial_sfb = cute.make_tensor( - cute.make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=16,), sfb_layout) - - # Select MMA operation - mma_op = tcgen05.MmaMXF4NVF4Op( - sf_dtype, - (mma_tiler_mnk[0], mma_tiler_mnk[1], mma_inst_shape_k), - tcgen05.CtaGroup.ONE, - tcgen05.OperandSource.SMEM, - ) - tiled_mma = cute.make_tiled_mma(mma_op) - - cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((1, 1, 1)), - (tiled_mma.thr_id.shape,), - ) - - # Compute A/B/SFA/SFB/C shared memory layout - a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, - mma_tiler_mnk, - ab_dtype, - num_ab_stage, - ) - sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, - mma_tiler_mnk, - sf_vec_size, - num_ab_stage, - ) - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - # Setup TMA for A - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - initial_a, - a_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - ) - # Setup TMA for B - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - initial_b, - b_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - ) - # Setup TMA for SFA - sfa_smem_layout = cute.slice_( - sfa_smem_layout_staged, (None, None, None, 0) - ) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - initial_sfa, - sfa_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - internal_type=cutlass.Int16, - ) - # Setup TMA for SFB - sfb_smem_layout = cute.slice_( - sfb_smem_layout_staged, (None, None, None, 0) - ) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE), - initial_sfb, - sfb_smem_layout, - mma_tiler_mnk, - tiled_mma, - cluster_layout_vmnk.shape, - internal_type=cutlass.Int16, - ) - - # Compute TMA load bytes - a_copy_size = cute.size_in_bytes(ab_dtype, a_smem_layout) - b_copy_size = cute.size_in_bytes(ab_dtype, b_smem_layout) - sfa_copy_size = cute.size_in_bytes(sf_dtype, sfa_smem_layout) - sfb_copy_size = cute.size_in_bytes(sf_dtype, sfb_smem_layout) - num_tma_load_bytes = ( - a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size - ) * atom_thr_size - - # Store CTA shape information for each Group in a List - cta_mn_list = [] - for group_idx, (m, n, k, l) in enumerate(problem_sizes): - x, y = cute.ceil_div(problem_sizes[group_idx][:2], mma_tiler_mnk[0:2]) - cta_mn_list.append((x, y)) - - # Compute grid size - grid = (1, 1, total_num_clusters) - - # Launch the kernel - kernel( - # MMA (Matrix Multiply-Accumulate) configuration - tiled_mma, # Tiled MMA object defining NVFP4 GEMM compute pattern - - # TMA (Tensor Memory Accelerator) atoms and tensors for input matrix A - tma_atom_a, # TMA copy atom defining how to load A from global memory - tma_tensor_a, # Tensor descriptor for A (created from smallest A tensor) - - # TMA atoms and tensors for input matrix B - tma_atom_b, # TMA copy atom defining how to load B from global memory - tma_tensor_b, # Tensor descriptor for B (created from smallest B tensor) - - # TMA atoms and tensors for scale factor A - tma_atom_sfa, # TMA copy atom for loading scale factors for A - tma_tensor_sfa, # Tensor descriptor for SFA (block scale factors for A) - - # TMA atoms and tensors for scale factor B - tma_atom_sfb, # TMA copy atom for loading scale factors for B - tma_tensor_sfb, # Tensor descriptor for SFB (block scale factors for B) - - # Runtime tensor metadata for dynamic group access - tensor_of_abc_ptrs, # Device tensor containing pointers to A, B, C for all groups - tensor_of_sfasfb_ptrs, # Device tensor containing pointers to SFA, SFB for all groups - tensor_of_tensormap, # Pre-allocated buffer for tensormap descriptors per CTA - tensor_of_problem_sizes, # Device tensor containing (m, n, k, l) for each group - - # Shared memory layouts with staging for pipelined execution - a_smem_layout_staged, # Staged shared memory layout for A (includes stage dimension) - b_smem_layout_staged, # Staged shared memory layout for B (includes stage dimension) - sfa_smem_layout_staged, # Staged shared memory layout for SFA (includes stage dimension) - sfb_smem_layout_staged, # Staged shared memory layout for SFB (includes stage dimension) - - # CTA grid configuration per group - cta_mn_list, # List of (M_tiles, N_tiles) for each group - - # Pipeline synchronization parameter - num_tma_load_bytes, # Total bytes to load per TMA transaction (for barrier setup) - ).launch( - grid=grid, - block=[threads_per_cta, 1, 1], - cluster=(1, 1, 1), - ) - return - - -# Global cache for compiled kernels (keyed by group size) -_compiled_kernel_cache = {} -# This function is used to compile the kernel once and cache it and then allow users to -# run the kernel multiple times to get more accurate timing results. -def compile_kernel(problem_sizes): - """ - Compile the kernel once and cache it using problem_sizes as the key. - This should be called before any timing measurements. - - Returns: - The compiled kernel function - """ - global _compiled_kernel_cache - - # Convert problem_sizes list to a hashable tuple for use as dictionary key - cache_key = f"{len(problem_sizes)}" - - # Check if we already have a compiled kernel for these problem sizes - if cache_key in _compiled_kernel_cache: - return _compiled_kernel_cache[cache_key] - - cute_ptr_of_tensor_of_problem_sizes = make_ptr( - cutlass.Int32, 0, cute.AddressSpace.gmem, assumed_align=16, - ) - cute_ptr_of_tensor_of_abc_ptrs = make_ptr( - cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16, - ) - cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr( - cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16, - ) - # Fake cluster numbers for compile only. - total_num_clusters = cutlass.Int32(1) - num_groups = cutlass.Int32(len(problem_sizes)) - # Each cluster needs its own set of tensormaps (one for A, B, SFA, SFB) - # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16) - cute_ptr_of_tensor_of_tensormap = make_ptr( - cutlass.Int64, 0, cute.AddressSpace.gmem, assumed_align=16, - ) - compiled_func = cute.compile( - my_kernel, - cute_ptr_of_tensor_of_problem_sizes, - cute_ptr_of_tensor_of_abc_ptrs, - cute_ptr_of_tensor_of_sfasfb_ptrs, - cute_ptr_of_tensor_of_tensormap, - total_num_clusters, - problem_sizes, - num_groups - ) - # Store compiled kernel in cache with problem_sizes as key - _compiled_kernel_cache[cache_key] = compiled_func - return compiled_func - - -def custom_kernel(data: input_t) -> output_t: - """ - Execute the block-scaled group GEMM kernel. - - This is the main entry point called by the evaluation framework. - It converts PyTorch tensors to CuTe tensors, launches the kernel, - and returns the result. - - Args: - data: Tuple of (abc_tensors, sfasfb_tensors, problem_sizes) where: - abc_tensors: list of tuples (a, b, c) where - a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] - b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] - c is torch.Tensor[float16] of shape [m, n, l] - sfasfb_tensors: list of tuples (sfa, sfb) where - sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] - sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] - problem_sizes: list of tuples (m, n, k, l) - each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes - l should always be 1 for each group. - list size is the number of groups. - - Returns: - list of c tensors where c is torch.Tensor[float16] of shape [m, n, l] for each group - """ - abc_tensors, _, sfasfb_reordered_tensors, problem_sizes = data - - compiled_func = compile_kernel(problem_sizes) - - # Extract raw data pointers from all input tensors for each group - # These will be passed to the GPU kernel to access the actual tensor data - abc_ptrs = [] - sfasfb_ptrs = [] - for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): - # Store pointers to A, B, and C matrices for this group - abc_ptrs.append((a.data_ptr(), b.data_ptr(), c.data_ptr())) - # Store pointers to scale factor tensors for this group - sfasfb_ptrs.append((sfa_reordered.data_ptr(), sfb_reordered.data_ptr())) - - # Create torch tensor to store problem sizes for all groups - # Shape: (num_groups, 4) where each row contains (m, n, k, l) for that group - # Layout: (num_groups, 4):(4, 1) means row-major storage - tensor_of_problem_sizes = torch.tensor( - problem_sizes, dtype=torch.int32, device="cuda" - ) - - # Create torch tensors to store data pointers for all groups - # These allow the GPU kernel to dynamically access different tensors per group - # tensor_of_abc_ptrs: Shape (num_groups, 3) containing (a_ptr, b_ptr, c_ptr) per group - # tensor_of_sfasfb_ptrs: Shape (num_groups, 2) containing (sfa_ptr, sfb_ptr) per group - tensor_of_abc_ptrs = torch.tensor(abc_ptrs, dtype=torch.int64, device="cuda") - tensor_of_sfasfb_ptrs = torch.tensor(sfasfb_ptrs, dtype=torch.int64, device="cuda") - - # Compute the tile shape for each CUDA Thread Block (CTA) - # cta_tile_shape_mn: [M_tile, N_tile] = [128, 128] for this kernel - cta_tile_shape_mn = [128, mma_tiler_mnk[1]] - # cluster_tile_shape_mn: Total tile shape per cluster (same as CTA since cluster is 1x1) - cluster_tile_shape_mn = tuple( - x * y for x, y in zip(cta_tile_shape_mn, (1, 1)) - ) - - # Compute total number of cluster tiles needed across all groups - # Each group's (m, n) dimensions are divided into tiles of size cluster_tile_shape_mn - # This determines the total grid size (bidz dimension) for kernel launch - total_num_clusters = 0 - num_groups = len(problem_sizes) - for m, n, _, _ in problem_sizes: - # Calculate number of tiles needed in M and N dimensions for this group - num_clusters_mn = tuple( - (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) - ) - # Multiply M_tiles * N_tiles to get total tiles for this group - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - - # Allocate device memory for tensormap descriptors - # Each cluster needs its own set of tensormaps (one for A, B, SFA, SFB) - # Shape: (total_num_clusters, num_tensormaps=4, bytes_per_tensormap/8=16) - # Tensormaps are hardware descriptors used by TMA for efficient memory transfers - tensormap_shape = ( - total_num_clusters, - num_tensormaps, - bytes_per_tensormap // 8, - ) - tensor_of_tensormap = torch.empty(tensormap_shape, dtype=torch.int64, device="cuda") - - # Create CuTe pointers to the metadata tensors that will be passed to the kernel - # These allow the GPU kernel to read problem sizes and tensor pointers - cute_ptr_of_tensor_of_abc_ptrs = make_ptr( - cutlass.Int64, - tensor_of_abc_ptrs.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - cute_ptr_of_tensor_of_sfasfb_ptrs = make_ptr( - cutlass.Int64, - tensor_of_sfasfb_ptrs.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - cute_ptr_of_tensor_of_problem_sizes = make_ptr( - cutlass.Int32, - tensor_of_problem_sizes.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - cute_ptr_of_tensor_of_tensormap = make_ptr( - cutlass.Int64, - tensor_of_tensormap.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - - # Launch the JIT-compiled GPU kernel with all prepared data - # The kernel will perform block-scaled group GEMM: C = A * SFA * B * SFB for all groups - compiled_func( - cute_ptr_of_tensor_of_problem_sizes, # Pointer to problem sizes array - cute_ptr_of_tensor_of_abc_ptrs, # Pointer to ABC tensor pointers array - cute_ptr_of_tensor_of_sfasfb_ptrs, # Pointer to scale factor pointers array - cute_ptr_of_tensor_of_tensormap, # Pointer to tensormap buffer - total_num_clusters, # Total number of CTAs to launch - problem_sizes, # Problem sizes list (for host-side processing) - num_groups, # Number of groups in this batch - ) - - res = [] - for i in range(num_groups): - res.append(abc_tensors[i][2]) - return res \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.py b/problems/nvidia/nvfp4_group_gemm/task.py deleted file mode 100644 index 94c11435..00000000 --- a/problems/nvidia/nvfp4_group_gemm/task.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from typing import TypedDict, TypeVar - -input_t = TypeVar("input_t", bound=tuple[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[torch.Tensor, torch.Tensor]], list[tuple[int, int, int, int]]]) -output_t = TypeVar("output_t", bound=list[torch.Tensor]) -class TestSpec(TypedDict): - problem_sizes: list[tuple[int, int, int, int]] - seed: int \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/task.yml b/problems/nvidia/nvfp4_group_gemm/task.yml deleted file mode 100644 index a41302c0..00000000 --- a/problems/nvidia/nvfp4_group_gemm/task.yml +++ /dev/null @@ -1,67 +0,0 @@ -# name: nvfp4-block-scaled-gemm - -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "task.py", "source": "task.py"} - - {"name": "utils.py", "source": "utils.py"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "../eval_better_bench_grouped_gemm.py"} - -lang: "py" - -description: | - - You will implement a block scaled group matrix-matrix multiplication kernel optimized for NVIDIA B200. - To be explicit, you will be given a tuple of tensors: - ``` - (abc_tensors, sfasfb_tensors, problem_sizes) - ``` - where: - * `abc_tensors` is list of tuples (a, b, c) where - a is torch.Tensor[float4e2m1fn_x2] of shape [M, K // 2, L] - b is torch.Tensor[float4e2m1fn_x2] of shape [N, K // 2, L] - c is torch.Tensor[float16] of shape [M, N, L] - * `sfasfb_tensors` is list of tuples (sfa, sfb) where - sfa is torch.Tensor[float8_e4m3fnuz] of shape [M, K // 16, L] - sfb is torch.Tensor[float8_e4m3fnuz] of shape [N, K // 16, L] - * `problem_sizes` is list of tuples (M, N, K, L) - - Each group's matrix sizes `M` is divisible by mma_tiler_mn[0], `N` is divisible by mma_tiler_mn[1], `K` is divisible by 256. - The ranking criteria is the geometric mean of the benchmark results. - For the grand price, your kernel will be evaluated against the speed of light analysis - and the solution closest to the speed of light will be awarded the grand price. - ``` - The speed of light analysis based on the max(FP4 Tensor Core math throughput, DRAM memory throughput) of B200 and tested under 1.5Ghz clock with the average M, N, K values per group: - G M_values N_values K_values L time[us] - 8 [80, 176, 128, 72, 64, 248, 96, 160] [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096] [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168] 1 18.833 - 8 [40, 76, 168, 72, 164, 148, 196, 160] [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168] [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048] 1 10.667 - 2 [192, 320] [3072, 3072] [4096, 4096] 1 2.406 - 2 [128, 384] [4096, 4096] [1536, 1536] 1 1.525 - ``` -config: - main: "eval.py" - -templates: - Python: "template.py" - -tests: - - {"m": [96, 128], "n": [128, 256], "k": [256, 512], "g": 2, "seed": 1111} - - {"m": [256, 72], "n": [512, 384], "k": [256, 256], "g": 2, "seed": 1111} - - {"m": [128, 128], "n": [128, 256], "k": [512, 256], "g": 2, "seed": 1111} - - {"m": [80, 128, 256], "n": [384, 256, 128], "k": [256, 512, 256], "g": 3, "seed": 1111} - - {"m": [64, 72, 96], "n": [128, 384, 512], "k": [512, 512, 256], "g": 3, "seed": 1111} - - {"m": [64, 256, 128], "n": [768, 128, 256], "k": [512, 256, 512], "g": 3, "seed": 1111} - - {"m": [128, 128, 64], "n": [256, 512, 512], "k": [768, 256, 768], "g": 3, "seed": 1111} - - {"m": [128, 128, 128, 128], "n": [128, 128, 128, 128], "k": [512, 256, 512, 256], "g": 4, "seed": 1111} - - {"m": [40, 56, 384, 512], "n": [512, 384, 256, 128], "k": [256, 256, 256, 256], "g": 4, "seed": 1111} - - {"m": [512, 384, 256, 128], "n": [256, 256, 256, 256], "k": [512, 768, 512, 768], "g": 4, "seed": 1111} - -benchmarks: - - {"m": [80, 176, 128, 72, 64, 248, 96, 160], "n": [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096], "k": [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], "g": 8, "seed": 1111} - - {"m": [40, 76, 168, 72, 164, 148, 196, 160], "n": [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], "k": [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], "g": 8, "seed": 1111} - - {"m": [192, 320], "n": [3072, 3072], "k": [4096, 4096], "g": 2, "seed": 1111} - - {"m": [128, 384], "n": [4096, 4096], "k": [1536, 1536], "g": 2, "seed": 1111} - -ranking_by: "geom" - -ranked_timeout: 300 diff --git a/problems/nvidia/nvfp4_group_gemm/template.py b/problems/nvidia/nvfp4_group_gemm/template.py deleted file mode 100644 index b6005faa..00000000 --- a/problems/nvidia/nvfp4_group_gemm/template.py +++ /dev/null @@ -1,31 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - """ - Reference implementation of block-scale fp4 group gemm - Args: - data: list of tuples (abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes) where: - abc_tensors: list of tuples (a, b, c) where - a is torch.Tensor[float4e2m1fn_x2] of shape [m, k // 2, l] - b is torch.Tensor[float4e2m1fn_x2] of shape [n, k // 2, l] - c is torch.Tensor[float16] of shape [m, n, l] - sfasfb_tensors: list of tuples (sfa, sfb) where - sfa is torch.Tensor[float8_e4m3fnuz] of shape [m, k // 16, l] - sfb is torch.Tensor[float8_e4m3fnuz] of shape [n, k // 16, l] - sfasfb_reordered_tensors: list of tuples (sfa_reordered, sfb_reordered) where - sfa_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_m, 4, rest_k, l] - sfb_reordered is torch.Tensor[float8_e4m3fnuz] of shape [32, 4, rest_n, 4, rest_k, l] - problem_sizes: list of tuples (m, n, k, l) - each group has its own a, b, c, sfa, sfb with different m, n, k, l problem sizes - l should always be 1 for each group. - Returns: - list of tuples (c) where c is torch.Tensor[float16] of shape [m, n, l] - """ - abc_tensors, sfasfb_tensors, sfasfb_reordered_tensors, problem_sizes = data - result_tensors = [] - for i, ((a, b, c), (sfa_reordered, sfb_reordered), (m, n, k, l)) in enumerate(zip(abc_tensors, sfasfb_reordered_tensors, problem_sizes)): - # add you implementation here - result_tensors.append(c) - - return result_tensors \ No newline at end of file diff --git a/problems/nvidia/nvfp4_group_gemm/utils.py b/problems/nvidia/nvfp4_group_gemm/utils.py deleted file mode 100644 index f6c1a8b3..00000000 --- a/problems/nvidia/nvfp4_group_gemm/utils.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import random -import numpy as np -import torch - - -def set_seed(seed=42): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_device(use_cuda: bool = True) -> torch.device: - """Get the appropriate device (GPU or CPU).""" - if use_cuda: - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - print("No compatible GPU found. Falling back to CPU.") - return torch.device("cpu") - - -# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py -@torch.no_grad() -def verbose_allclose( - received: torch.Tensor, expected: torch.Tensor, rtol=1e-05, atol=1e-08, max_print=5 -) -> list[str]: - """ - Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - rtol (float): Relative tolerance; relative to expected - atol (float): Absolute tolerance. - max_print (int): Maximum number of mismatched elements to print. - - Raises: - AssertionError: If the tensors are not all close within the given tolerance. - """ - # Check if the shapes of the tensors match - if received.shape != expected.shape: - return ["SIZE MISMATCH"] - - # Calculate the difference between the tensors - diff = torch.abs(received - expected) - - # Determine the tolerance - tolerance = atol + rtol * torch.abs(expected) - - # Find tolerance mismatched elements - tol_mismatched = diff > tolerance - - # Find nan mismatched elements - nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) - - # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor( - torch.isposinf(received), torch.isposinf(expected) - ) - # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor( - torch.isneginf(received), torch.isneginf(expected) - ) - - # Find all mismatched elements - mismatched = torch.logical_or( - torch.logical_or(tol_mismatched, nan_mismatched), - torch.logical_or(posinf_mismatched, neginf_mismatched), - ) - - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append( - f"... and {num_mismatched - max_print} more mismatched elements." - ) - return mismatch_details - - return [] - - -@torch.no_grad() -def verbose_allequal( - received: torch.Tensor, expected: torch.Tensor, max_print: int = 5 -): - """ - Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. - - Parameters: - received (torch.Tensor): Tensor we actually got. - expected (torch.Tensor): Tensor we expected to receive. - max_print (int): Maximum number of mismatched elements to print. - - Returns: - Empty string if tensors are equal, otherwise detailed error information - """ - mismatched = torch.not_equal(received, expected) - mismatched_indices = torch.nonzero(mismatched) - - # Count the number of mismatched elements - num_mismatched = mismatched.count_nonzero().item() - - # Generate detailed information if there are mismatches - if num_mismatched >= 1: - mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] - - for index in mismatched_indices[:max_print]: - i = tuple(index.tolist()) - mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") - if num_mismatched > max_print: - mismatch_details.append( - f"... and {num_mismatched - max_print} more mismatched elements." - ) - return mismatch_details - - return [] - - -def match_reference( - data, output, reference: callable, rtol=1e-05, atol=1e-08 -) -> tuple[bool, str]: - """ - Convenient "default" implementation for tasks' `check_implementation` function. - """ - expected = reference(data) - - if len(output) != len(expected): - return ( - False, - f"output length mismatch: got {len(output)}, expected {len(expected)}", - ) - - for i, (output_i, expected_i) in enumerate(zip(output, expected)): - reasons = verbose_allclose(output_i, expected_i, rtol=rtol, atol=atol) - if len(reasons) > 0: - return ( - False, - f"mismatch found! custom implementation doesn't match reference: {i} {reasons}", - ) - - return True, "" - - -def make_match_reference(reference: callable, **kwargs): - def wrapped(data, output): - return match_reference(data, output, reference=reference, **kwargs) - - return wrapped - - -class DeterministicContext: - def __init__(self): - self.allow_tf32 = None - self.deterministic = None - self.cublas = None - - def __enter__(self): - self.cublas = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") - self.allow_tf32 = torch.backends.cudnn.allow_tf32 - self.deterministic = torch.backends.cudnn.deterministic - torch.backends.cudnn.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.use_deterministic_algorithms(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - torch.backends.cudnn.allow_tf32 = self.allow_tf32 - torch.backends.cudnn.deterministic = self.deterministic - torch.use_deterministic_algorithms(False) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.cublas - - -def clear_l2_cache(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") - # write stuff to - dummy.fill_(42) - del dummy - - -def clear_l2_cache_large(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.randn((16000, 1024, 1024), device="cuda") - del dummy diff --git a/problems/nvidia/utils.py b/problems/nvidia/utils.py index b2859f09..7997d3db 100644 --- a/problems/nvidia/utils.py +++ b/problems/nvidia/utils.py @@ -28,7 +28,11 @@ def get_device(use_cuda: bool = True) -> torch.device: # Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py @torch.no_grad() def verbose_allclose( - received: torch.Tensor, expected: torch.Tensor, rtol=1e-05, atol=1e-08, max_print=5 + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 ) -> list[str]: """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. @@ -60,13 +64,9 @@ def verbose_allclose( nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) # Find +inf mismatched elements - posinf_mismatched = torch.logical_xor( - torch.isposinf(received), torch.isposinf(expected) - ) + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) # Find -inf mismatched elements - neginf_mismatched = torch.logical_xor( - torch.isneginf(received), torch.isneginf(expected) - ) + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) # Find all mismatched elements mismatched = torch.logical_or( @@ -87,18 +87,14 @@ def verbose_allclose( i = tuple(index.tolist()) mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") if num_mismatched > max_print: - mismatch_details.append( - f"... and {num_mismatched - max_print} more mismatched elements." - ) + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") return mismatch_details return [] @torch.no_grad() -def verbose_allequal( - received: torch.Tensor, expected: torch.Tensor, max_print: int = 5 -): +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): """ Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. @@ -124,17 +120,13 @@ def verbose_allequal( i = tuple(index.tolist()) mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") if num_mismatched > max_print: - mismatch_details.append( - f"... and {num_mismatched - max_print} more mismatched elements." - ) + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") return mismatch_details return [] -def match_reference( - data, output, reference: callable, rtol=1e-05, atol=1e-08 -) -> tuple[bool, str]: +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: """ Convenient "default" implementation for tasks' `check_implementation` function. """ @@ -142,19 +134,14 @@ def match_reference( reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) if len(reasons) > 0: - return ( - False, - "mismatch found! custom implementation doesn't match reference: " - + " ".join(reasons), - ) + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) - return True, "" + return True, '' def make_match_reference(reference: callable, **kwargs): def wrapped(data, output): return match_reference(data, output, reference=reference, **kwargs) - return wrapped @@ -165,7 +152,7 @@ def __init__(self): self.cublas = None def __enter__(self): - self.cublas = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') self.allow_tf32 = torch.backends.cudnn.allow_tf32 self.deterministic = torch.backends.cudnn.deterministic torch.backends.cudnn.allow_tf32 = False @@ -177,8 +164,7 @@ def __exit__(self, exc_type, exc_value, traceback): torch.backends.cudnn.allow_tf32 = self.allow_tf32 torch.backends.cudnn.deterministic = self.deterministic torch.use_deterministic_algorithms(False) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.cublas - + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas def clear_l2_cache(): # import cupy as cp @@ -186,11 +172,3 @@ def clear_l2_cache(): # create a large dummy tensor dummy = torch.randn((1024, 1024, 1024), device="cuda") del dummy - - -def clear_l2_cache_large(): - # import cupy as cp - # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) - # create a large dummy tensor - dummy = torch.randn((16000, 1024, 1024), device="cuda") - del dummy diff --git a/problems/pmpp/vectoradd_py/submission.py b/problems/pmpp/vectoradd_py/submission.py deleted file mode 100644 index 0d2ad435..00000000 --- a/problems/pmpp/vectoradd_py/submission.py +++ /dev/null @@ -1,6 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - A, B = data - return A + B diff --git a/problems/pmpp_v2.yaml b/problems/pmpp_v2.yaml index 3bdf6777..e3c6915c 100644 --- a/problems/pmpp_v2.yaml +++ b/problems/pmpp_v2.yaml @@ -7,7 +7,7 @@ description: "" problems: - directory: pmpp_v2/conv2d_py name: conv2d_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -15,7 +15,7 @@ problems: - L4 - directory: pmpp_v2/grayscale_py name: grayscale_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -23,7 +23,7 @@ problems: - L4 - directory: pmpp_v2/histogram_py name: histogram_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -31,7 +31,7 @@ problems: - L4 - directory: pmpp_v2/matmul_py name: matmul_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -39,7 +39,7 @@ problems: - L4 - directory: pmpp_v2/prefixsum_py name: prefixsum_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -47,7 +47,7 @@ problems: - L4 - directory: pmpp_v2/sort_py name: sort_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -55,7 +55,7 @@ problems: - L4 - directory: pmpp_v2/vectoradd_py name: vectoradd_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 @@ -63,7 +63,7 @@ problems: - L4 - directory: pmpp_v2/vectorsum_py name: vectorsum_v2 - deadline: "2100-12-31" + deadline: "2025-12-30" gpus: - B200 - H100 diff --git a/problems/pmpp_v2/conv2d_py/task.yml b/problems/pmpp_v2/conv2d_py/task.yml index 9b44b2b7..55adc532 100644 --- a/problems/pmpp_v2/conv2d_py/task.yml +++ b/problems/pmpp_v2/conv2d_py/task.yml @@ -44,4 +44,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/grayscale_py/task.yml b/problems/pmpp_v2/grayscale_py/task.yml index d1cbb30a..cada0257 100644 --- a/problems/pmpp_v2/grayscale_py/task.yml +++ b/problems/pmpp_v2/grayscale_py/task.yml @@ -38,4 +38,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/histogram_py/task.yml b/problems/pmpp_v2/histogram_py/task.yml index 419529ab..489a98b6 100644 --- a/problems/pmpp_v2/histogram_py/task.yml +++ b/problems/pmpp_v2/histogram_py/task.yml @@ -40,4 +40,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/matmul_py/task.yml b/problems/pmpp_v2/matmul_py/task.yml index 864ba171..6924764b 100644 --- a/problems/pmpp_v2/matmul_py/task.yml +++ b/problems/pmpp_v2/matmul_py/task.yml @@ -41,4 +41,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/prefixsum_py/task.yml b/problems/pmpp_v2/prefixsum_py/task.yml index 734546d3..a91d1496 100644 --- a/problems/pmpp_v2/prefixsum_py/task.yml +++ b/problems/pmpp_v2/prefixsum_py/task.yml @@ -54,4 +54,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/sort_py/task.yml b/problems/pmpp_v2/sort_py/task.yml index 7e78a156..5c702e29 100644 --- a/problems/pmpp_v2/sort_py/task.yml +++ b/problems/pmpp_v2/sort_py/task.yml @@ -38,4 +38,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py index ecd070b4..d6f71050 100644 --- a/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py +++ b/problems/pmpp_v2/vectoradd_py/solutions/correct/submission_cuda_inline.py @@ -48,7 +48,7 @@ add_cpp_source = """ #include -torch::Tensor add_cuda(torch::Tensor A, torch::Tensor B, torch::Tensor C); +torch::Tensor add_cuda(torch::Tensor A, torch::Tensor B); """ add_module = load_inline( @@ -59,10 +59,10 @@ verbose=True, ) -def add(A, B, C): - if not A.is_cuda or not B.is_cuda or not C.is_cuda: - raise RuntimeError("All tensors must be on GPU") - return add_module.add_cuda(A, B, C) +def add(A, B): + if not A.is_cuda or not B.is_cuda: + raise RuntimeError("Both tensors must be on GPU") + return add_module.add_cuda(A, B) def custom_kernel(data: input_t) -> output_t: """ @@ -72,13 +72,12 @@ def custom_kernel(data: input_t) -> output_t: Returns: Tensor containing element-wise sum. """ - A, B, C = data + A, B = data - assert A.is_cuda and B.is_cuda and C.is_cuda, "Input/output tensors must be on GPU" + assert A.is_cuda and B.is_cuda, "Input tensors must be on GPU" assert A.shape == B.shape, "Input tensors must have the same shape" - assert C.shape == A.shape, "Output tensor and input tensors must have the same shape" - assert A.dtype == torch.float16 and B.dtype == torch.float16 and C.dtype == torch.float16, "Input/output tensors must be float16" + assert A.dtype == torch.float16 and B.dtype == torch.float16, "Input tensors must be float16" # Simply reuse the existing add function we already defined # This avoids the compilation issues with the inline kernel - return add(A, B, C) + return add(A, B) diff --git a/problems/pmpp_v2/vectoradd_py/submission.py b/problems/pmpp_v2/vectoradd_py/submission.py deleted file mode 100644 index 918a1eb8..00000000 --- a/problems/pmpp_v2/vectoradd_py/submission.py +++ /dev/null @@ -1,7 +0,0 @@ -from task import input_t, output_t - - -def custom_kernel(data: input_t) -> output_t: - A, B, output = data - output[...] = A + B - return output diff --git a/problems/pmpp_v2/vectoradd_py/task.yml b/problems/pmpp_v2/vectoradd_py/task.yml index f662f16a..6906a313 100644 --- a/problems/pmpp_v2/vectoradd_py/task.yml +++ b/problems/pmpp_v2/vectoradd_py/task.yml @@ -38,4 +38,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/pmpp_v2/vectorsum_py/task.yml b/problems/pmpp_v2/vectorsum_py/task.yml index fc752a83..8b3ddbb7 100644 --- a/problems/pmpp_v2/vectorsum_py/task.yml +++ b/problems/pmpp_v2/vectorsum_py/task.yml @@ -38,4 +38,4 @@ benchmarks: test_timeout: 180 benchmark_timeout: 180 -ranked_timeout: 420 +ranked_timeout: 180 diff --git a/problems/princeton/cross_entropy_py/eval.py b/problems/princeton/cross_entropy_py/eval.py deleted file mode 100644 index 65124d54..00000000 --- a/problems/princeton/cross_entropy_py/eval.py +++ /dev/null @@ -1,351 +0,0 @@ -import dataclasses -import math -import os -import random -import re -import statistics -import sys -from pathlib import Path - -import torch -import torch.nn.functional as F - -from reference import ( - ATOL, - DTYPE, - RTOL, - generate_inputs, - reference_backward, - reference_forward, -) - - -# Original eval parameters -B = 4_096 -WARMUP_ITERS = 20 -BENCH_ITERS = 100 - - -def make_seed_schedule(): - total = WARMUP_ITERS + 3 * BENCH_ITERS - seeds = random.SystemRandom().sample(range(1, 2**31 - 1), total) - warmup_end = WARMUP_ITERS - forward_end = warmup_end + BENCH_ITERS - backward_end = forward_end + BENCH_ITERS - return { - "warmup": seeds[:warmup_end], - "forward": seeds[warmup_end:forward_end], - "backward": seeds[forward_end:backward_end], - "combined": seeds[backward_end:], - } - - -class PopcornOutput: - def __init__(self, fd: int): - self.file = os.fdopen(fd, "w") - os.set_inheritable(fd, False) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def print(self, *args, **kwargs): - print(*args, **kwargs, file=self.file, flush=True) - - def log(self, key, value): - self.print(f"{key}: {value}") - - -@dataclasses.dataclass -class TestCase: - args: dict - spec: str - - -@dataclasses.dataclass -class Stats: - runs: int - mean: float - std: float - err: float - best: float - worst: float - fwd_bw: float - bwd_bw: float - combined_bw: float - - -def get_test_cases(file_name: str) -> list[TestCase]: - try: - content = Path(file_name).read_text() - except Exception as exc: - print(f"Could not open test file `{file_name}`: {exc}", file=sys.stderr) - sys.exit(113) - - tests = [] - lines = content.splitlines() - match = r"\s*([a-zA-Z_]+):\s*([a-zA-Z_]+|[+-]?[0-9]+)\s*" - for line in lines: - if not line.strip(): - continue - parts = line.split(";") - case = {} - for part in parts: - matched = re.match(match, part) - if not re.fullmatch(match, part): - print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) - sys.exit(113) - key = matched[1] - value = matched[2] - try: - value = int(value) - except ValueError: - pass - case[key] = value - tests.append(TestCase(spec=line, args=case)) - return tests - - -def load_submission(): - import submission - - for fn_name in ("cross_entropy_forward", "cross_entropy_backward"): - if not hasattr(submission, fn_name): - raise AttributeError(f"Submission is missing function '{fn_name}'.") - return submission - - -def check_correctness(mod, vocab_size): - logits, targets, grad_output = generate_inputs(B, vocab_size) - - ref_loss = reference_forward(logits, targets) - sub_loss = mod.cross_entropy_forward(logits, targets) - - assert sub_loss.shape == ref_loss.shape, ( - f"Forward shape mismatch: expected {ref_loss.shape}, got {sub_loss.shape}" - ) - assert sub_loss.dtype == torch.float32, ( - f"Forward dtype mismatch: expected float32, got {sub_loss.dtype}" - ) - - fwd_close = torch.allclose(sub_loss, ref_loss, atol=ATOL, rtol=RTOL) - max_fwd_err = (sub_loss - ref_loss).abs().max().item() - - ref_grad = reference_backward(logits, targets, grad_output) - sub_grad = mod.cross_entropy_backward(logits, targets, grad_output) - - assert sub_grad.shape == ref_grad.shape, ( - f"Backward shape mismatch: expected {ref_grad.shape}, got {sub_grad.shape}" - ) - assert sub_grad.dtype == DTYPE, ( - f"Backward dtype mismatch: expected {DTYPE}, got {sub_grad.dtype}" - ) - - bwd_close = torch.allclose(sub_grad, ref_grad, atol=ATOL, rtol=RTOL) - max_bwd_err = (sub_grad.float() - ref_grad.float()).abs().max().item() - - return fwd_close, bwd_close, max_fwd_err, max_bwd_err - - -def benchmark_one(mod, vocab_size, seed_schedule): - def phase_inputs(phase, idx): - seed = seed_schedule[phase][idx] - return generate_inputs(B, vocab_size, seed=seed) - - for idx in range(WARMUP_ITERS): - logits, targets, grad_output = phase_inputs("warmup", idx) - mod.cross_entropy_forward(logits, targets) - mod.cross_entropy_backward(logits, targets, grad_output) - torch.cuda.synchronize() - - fwd_times = [] - for idx in range(BENCH_ITERS): - logits, targets, _ = phase_inputs("forward", idx) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - mod.cross_entropy_forward(logits, targets) - end.record() - torch.cuda.synchronize() - fwd_times.append(start.elapsed_time(end)) - - bwd_times = [] - for idx in range(BENCH_ITERS): - logits, targets, grad_output = phase_inputs("backward", idx) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - mod.cross_entropy_backward(logits, targets, grad_output) - end.record() - torch.cuda.synchronize() - bwd_times.append(start.elapsed_time(end)) - - combined_times = [] - for idx in range(BENCH_ITERS): - logits, targets, grad_output = phase_inputs("combined", idx) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - mod.cross_entropy_forward(logits, targets) - mod.cross_entropy_backward(logits, targets, grad_output) - end.record() - torch.cuda.synchronize() - combined_times.append(start.elapsed_time(end)) - - fwd_ms = statistics.median(fwd_times) - bwd_ms = statistics.median(bwd_times) - combined_ms = statistics.median(combined_times) - - fwd_bytes = 2 * B * vocab_size + 12 * B - bwd_bytes = 4 * B * vocab_size + 12 * B - total_bytes = fwd_bytes + bwd_bytes - - fwd_bw = fwd_bytes / (fwd_ms * 1e-3) / 1e9 - bwd_bw = bwd_bytes / (bwd_ms * 1e-3) / 1e9 - combined_bw = total_bytes / (combined_ms * 1e-3) / 1e9 - - # Keep KernelBot scoring on the exact reported metric: median combined ms. - return Stats( - runs=BENCH_ITERS, - mean=combined_ms * 1e6, - std=statistics.pstdev(combined_times) * 1e6, - err=(statistics.pstdev(combined_times) / math.sqrt(len(combined_times))) * 1e6, - best=min(combined_times) * 1e6, - worst=max(combined_times) * 1e6, - fwd_bw=fwd_bw, - bwd_bw=bwd_bw, - combined_bw=combined_bw, - ) - - -def run_testing(logger: PopcornOutput, tests: list[TestCase]) -> int: - try: - mod = load_submission() - except Exception as exc: - logger.log("check", "fail") - logger.log("error", str(exc)) - return 112 - - passed = True - logger.log("test-count", len(tests)) - for idx, test in enumerate(tests): - vocab_size = int(test.args["vocab_size"]) - logger.log(f"test.{idx}.spec", test.spec) - try: - fwd_ok, bwd_ok, fwd_err, bwd_err = check_correctness(mod, vocab_size) - if fwd_ok and bwd_ok: - logger.log(f"test.{idx}.status", "pass") - logger.log( - f"test.{idx}.message", - f"forward max err={fwd_err:.3e}, backward max err={bwd_err:.3e}", - ) - else: - logger.log(f"test.{idx}.status", "fail") - logger.log( - f"test.{idx}.error", - f"forward max err={fwd_err:.3e} {'OK' if fwd_ok else 'FAIL'}; " - f"backward max err={bwd_err:.3e} {'OK' if bwd_ok else 'FAIL'}", - ) - passed = False - except Exception as exc: - logger.log(f"test.{idx}.status", "fail") - logger.log(f"test.{idx}.error", str(exc)) - passed = False - - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def run_benchmarking(logger: PopcornOutput, tests: list[TestCase]) -> int: - try: - mod = load_submission() - except Exception as exc: - logger.log("check", "fail") - logger.log("error", str(exc)) - return 112 - - baseline_mod = type(sys)("baseline") - baseline_mod.cross_entropy_forward = ( - lambda logits, targets: F.cross_entropy(logits.float(), targets, reduction="none") - ) - - def baseline_bwd(logits, targets, grad_output): - probs = torch.softmax(logits.float(), dim=-1) - probs[torch.arange(logits.shape[0], device=logits.device), targets] -= 1.0 - return (probs * grad_output.unsqueeze(1)).to(logits.dtype) - - baseline_mod.cross_entropy_backward = baseline_bwd - - passed = True - logger.log("benchmark-count", len(tests)) - for idx, test in enumerate(tests): - vocab_size = int(test.args["vocab_size"]) - logger.log(f"benchmark.{idx}.spec", test.spec) - try: - seed_schedule = make_seed_schedule() - baseline = benchmark_one(baseline_mod, vocab_size, seed_schedule) - result = benchmark_one(mod, vocab_size, seed_schedule) - speedup = baseline.mean / result.mean - except Exception as exc: - logger.log(f"benchmark.{idx}.status", "fail") - logger.log(f"benchmark.{idx}.error", str(exc)) - passed = False - continue - - logger.log(f"benchmark.{idx}.runs", result.runs) - logger.log(f"benchmark.{idx}.mean", result.mean) - logger.log(f"benchmark.{idx}.std", result.std) - logger.log(f"benchmark.{idx}.err", result.err) - logger.log(f"benchmark.{idx}.best", result.best) - logger.log(f"benchmark.{idx}.worst", result.worst) - logger.log(f"benchmark.{idx}.fwd_bw", result.fwd_bw) - logger.log(f"benchmark.{idx}.bwd_bw", result.bwd_bw) - logger.log(f"benchmark.{idx}.combined_bw", result.combined_bw) - logger.log(f"benchmark.{idx}.speedup", speedup) - logger.log( - f"benchmark.{idx}.message", - ( - f"fwd+bwd={result.mean / 1e6:.3f} ms, " - f"fwd_bw={result.fwd_bw:.1f} GB/s, " - f"bwd_bw={result.bwd_bw:.1f} GB/s, " - f"combined_bw={result.combined_bw:.1f} GB/s, " - f"speedup={speedup:.2f}x" - ), - ) - - logger.log("check", "pass" if passed else "fail") - return 0 if passed else 112 - - -def main(): - fd = os.getenv("POPCORN_FD") - if not fd: - return 111 - - if len(sys.argv) < 3: - return 2 - - if not torch.cuda.is_available(): - with PopcornOutput(int(fd)) as logger: - logger.log("check", "fail") - logger.log("error", "No CUDA GPU available. This script requires a GPU.") - return 112 - - mode = sys.argv[1] - tests = get_test_cases(sys.argv[2]) - - with PopcornOutput(int(fd)) as logger: - if mode == "test": - return run_testing(logger, tests) - if mode in {"benchmark", "leaderboard"}: - return run_benchmarking(logger, tests) - - logger.log("check", "fail") - logger.log("error", f"Unsupported mode: {mode}") - return 2 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/problems/princeton/cross_entropy_py/reference.py b/problems/princeton/cross_entropy_py/reference.py deleted file mode 100644 index 0317f3fa..00000000 --- a/problems/princeton/cross_entropy_py/reference.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn.functional as F - - -DTYPE = torch.bfloat16 -DEVICE = "cuda" -ATOL = 1e-3 -RTOL = 1e-2 - - -def reference_forward(logits, targets): - return F.cross_entropy(logits.float(), targets, reduction="none") - - -def reference_backward(logits, targets, grad_output): - probs = torch.softmax(logits.float(), dim=-1) - grad = probs - grad[torch.arange(logits.shape[0], device=logits.device), targets] -= 1.0 - grad = grad * grad_output.unsqueeze(1) - return grad.to(logits.dtype) - - -def generate_inputs(batch_size, vocab_size, seed=42): - torch.manual_seed(seed) - logits = torch.randn(batch_size, vocab_size, dtype=DTYPE, device=DEVICE) - targets = torch.randint(0, vocab_size, (batch_size,), device=DEVICE) - grad_output = torch.randn(batch_size, dtype=torch.float32, device=DEVICE) - return logits, targets, grad_output diff --git a/problems/princeton/cross_entropy_py/submission.py b/problems/princeton/cross_entropy_py/submission.py deleted file mode 100644 index e24b7ff2..00000000 --- a/problems/princeton/cross_entropy_py/submission.py +++ /dev/null @@ -1,52 +0,0 @@ -#!POPCORN leaderboard princeton_cross_entropy - -""" -Baseline submission for the cross-entropy problem. - -Replace these functions with a faster implementation. - -The evaluator uses: -- B = 4096 -- V in {32000, 50264, 128256} -- V % 8 == 0 -- finite real-valued logits (no masking with -inf) - -Example local bandwidth calculation for the three ranked shapes: - - def print_max_bw(batch_size, vocab_size, combined_ms): - total_bytes = (6 * batch_size * vocab_size) + (24 * batch_size) - combined_bw = total_bytes / (combined_ms * 1e-3) / 1e9 - print(f\"B={batch_size} V={vocab_size}: {combined_bw:.2f} GB/s\") - -This is only for local debugging. Do not add timing calls inside the hot path -if you care about leaderboard performance. -""" - -import torch -import torch.nn.functional as F - - -def cross_entropy_forward(logits, targets): - """ - Args: - logits: (B, V) torch.bfloat16 - targets: (B,) torch.int64 - Returns: - (B,) torch.float32 - """ - return F.cross_entropy(logits.float(), targets, reduction="none") - - -def cross_entropy_backward(logits, targets, grad_output): - """ - Args: - logits: (B, V) torch.bfloat16 - targets: (B,) torch.int64 - grad_output: (B,) torch.float32 - Returns: - (B, V) torch.bfloat16 - """ - probs = torch.softmax(logits.float(), dim=-1) - probs[torch.arange(logits.shape[0], device=logits.device), targets] -= 1.0 - grad_logits = probs * grad_output.unsqueeze(1) - return grad_logits.to(logits.dtype) diff --git a/problems/princeton/cross_entropy_py/task.yml b/problems/princeton/cross_entropy_py/task.yml deleted file mode 100644 index 1d91457e..00000000 --- a/problems/princeton/cross_entropy_py/task.yml +++ /dev/null @@ -1,48 +0,0 @@ -files: - - {"name": "submission.py", "source": "@SUBMISSION@"} - - {"name": "reference.py", "source": "reference.py"} - - {"name": "eval.py", "source": "eval.py"} - -lang: "py" - -description: | - Implement fused cross-entropy forward and backward kernels for logits of shape (B, V). - - Your submission must define: - - cross_entropy_forward(logits, targets) -> losses - - cross_entropy_backward(logits, targets, grad_output) -> grad_logits - - Inputs: - - logits: torch.bfloat16 tensor of real-valued, finite logits with shape (B, V) - - targets: torch.int64 tensor of shape (B,) - - grad_output: torch.float32 tensor of shape (B,) - - Outputs: - - forward output: torch.float32 tensor of shape (B,) - - backward output: torch.bfloat16 tensor of shape (B, V) - - Assumptions used by the evaluator and benchmark: - - batch size is fixed at B = 4096 - - vocab sizes are V in {32000, 50264, 128256} - - vocab size is guaranteed to be divisible by 8 - - logits are ordinary real numbers; masked values such as -inf are not used - -config: - main: "eval.py" - -tests: - - {"vocab_size": 32000} - - {"vocab_size": 50264} - - {"vocab_size": 128256} - -benchmarks: - - {"vocab_size": 32000} - - {"vocab_size": 50264} - - {"vocab_size": 128256} - -test_timeout: 300 -benchmark_timeout: 900 -ranked_timeout: 1200 -ranking_by: "geom" -gpus: - - A100 diff --git a/problems/princeton2026.yaml b/problems/princeton2026.yaml deleted file mode 100644 index c2f76179..00000000 --- a/problems/princeton2026.yaml +++ /dev/null @@ -1,9 +0,0 @@ -name: Princeton Problem Set -deadline: "2026-04-17 03:59" -description: "Princeton problem set" -problems: - - directory: princeton/cross_entropy_py - name: princeton_cross_entropy - deadline: "2026-04-17 03:59" - gpus: - - A100