Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 72 additions & 17 deletions problems/linalg/qr_py/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ def _band_mask(n: int, bandwidth: int, device: torch.device) -> torch.Tensor:
return (idx[:, None] - idx[None, :]).abs() <= bandwidth


def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t:
assert batch > 0, "batch must be positive"
assert n > 0, "n must be positive"
assert cond >= 0, "cond must be non-negative"

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(seed)

case = case.lower()
a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen)

# Per-matrix conditioning profiles drawn for the "mixed" case. "dense" is the
# well-conditioned majority; the rest are the ill-conditioned stress structures.
_MIXED_PROFILES = ("dense", "rankdef", "nearrank", "clustered", "band", "rowscale", "nearcollinear")
# Relative sampling weights (normalized by torch.multinomial); dense ~= 50%.
_MIXED_WEIGHTS = (6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)


def _apply_case(a: torch.Tensor, case: str, cond: int, gen: torch.Generator) -> torch.Tensor:
# Apply one conditioning profile to an already-drawn base batch `a` of shape
# (m, n, n), drawing any case-specific extra randomness from `gen`. Factored
# out of generate_input so the homogeneous cases and the per-matrix "mixed"
# case share a single implementation. The draw order (base first in the
# caller, then the case extras here) matches the original code, so every
# homogeneous case produces bit-for-bit identical data to before.
m, n = a.shape[0], a.shape[-1]
device = a.device
if case == "dense":
a = _apply_column_scaling(a, cond)
elif case == "upper":
Expand All @@ -40,7 +44,7 @@ def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense"
a.diagonal(dim1=-2, dim2=-1).add_(diag_boost)
a = _apply_column_scaling(a, cond)
elif case == "diagonal":
diag = torch.randn((batch, n), device=device, dtype=torch.float32, generator=gen)
diag = torch.randn((m, n), device=device, dtype=torch.float32, generator=gen)
diag = diag.sign().clamp(min=0.0).mul(2.0).sub(1.0) * torch.logspace(
0.0, -float(max(cond, 2)), n, device=device, dtype=torch.float32
)
Expand All @@ -54,7 +58,7 @@ def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense"
tail = n - rank
if tail > 0:
noise = torch.randn(
(batch, n, tail), device=device, dtype=torch.float32, generator=gen
(m, n, tail), device=device, dtype=torch.float32, generator=gen
)
a[:, :, rank:] = a[:, :, :tail] + 1.0e-5 * noise
a = _apply_column_scaling(a, cond)
Expand All @@ -73,16 +77,67 @@ def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense"
a.diagonal(dim1=-2, dim2=-1).add_(diag_boost)
a = _apply_column_scaling(a, cond)
elif case == "nearcollinear":
base = torch.randn((batch, n, 1), device=device, dtype=torch.float32, generator=gen)
noise = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen)
a = base.expand(batch, n, n) + 1.0e-4 * noise
base = torch.randn((m, n, 1), device=device, dtype=torch.float32, generator=gen)
noise = torch.randn((m, n, n), device=device, dtype=torch.float32, generator=gen)
a = base.expand(m, n, n) + 1.0e-4 * noise
a = _apply_column_scaling(a, cond)
elif case == "rowscale":
row_cond = max(cond, 4)
scales = torch.logspace(0.0, -float(row_cond), n, device=device, dtype=torch.float32)
a = scales.reshape(1, n, 1) * a
else:
raise ValueError(f"unknown QR test case: {case}")
return a


def _generate_mixed(a: torch.Tensor, cond: int, gen: torch.Generator) -> torch.Tensor:
# Heterogeneous batch: assign each matrix an independent conditioning profile
# at a RANDOM position in the batch (seeded, so still deterministic), so
# well- and ill-conditioned matrices are interleaved rather than uniform
# across the batch. This matches the real optimizer-statistics regime (the
# per-layer / per-block factors have wildly different conditioning) and it
# removes the loophole where a kernel samples a few matrices, concludes the
# whole batch is well-conditioned, and routes it all to a fast path that is
# only numerically valid for well-conditioned inputs. With a mix present,
# passing the correctness gate requires handling each matrix on its merits.
m = a.shape[0]
device = a.device
weights = torch.tensor(_MIXED_WEIGHTS, dtype=torch.float32, device=device)
labels = torch.multinomial(weights, m, replacement=True, generator=gen)
# Guarantee both a well-conditioned and an ill-conditioned matrix are present.
# (Only relevant for tiny batches; large batches get both with high prob.)
if m >= 2:
is_dense = labels == 0
if not bool(is_dense.any()):
labels[int(torch.randint(0, m, (1,), device=device, generator=gen))] = 0
elif bool(is_dense.all()):
pos = int(torch.randint(0, m, (1,), device=device, generator=gen))
labels[pos] = int(torch.randint(1, len(_MIXED_PROFILES), (1,), device=device, generator=gen))
# Process profiles in fixed order over the present labels so the RNG draws
# inside _apply_case are deterministic for a given seed.
for k, prof in enumerate(_MIXED_PROFILES):
mask = labels == k
if bool(mask.any()):
a[mask] = _apply_case(a[mask], prof, cond, gen)
return a


def generate_input(batch: int, n: int, cond: int, seed: int, case: str = "dense") -> input_t:
assert batch > 0, "batch must be positive"
assert n > 0, "n must be positive"
assert cond >= 0, "cond must be non-negative"

device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(seed)

case = case.lower()
a = torch.randn((batch, n, n), device=device, dtype=torch.float32, generator=gen)

if case == "mixed":
a = _generate_mixed(a, cond, gen)
else:
a = _apply_case(a, case, cond, gen)

return a.contiguous()

Expand Down
21 changes: 21 additions & 0 deletions problems/linalg/qr_py/task.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ description: |
structure, such as rank-deficient, near-rank-deficient, banded, row-scaled,
near-collinear, upper-triangular, or clustered-scale inputs.

The `mixed` case builds a heterogeneous batch: each matrix is independently
assigned a conditioning profile (a well-conditioned dense majority interleaved
with the ill-conditioned stress structures above) at a random position in the
batch. This mirrors the real optimizer-statistics regime, where the per-layer
or per-block factors batched into one call have widely varying conditioning,
rather than all sharing one structure. The benchmark set (not just the test
set) now includes both `mixed` batches and fully ill-conditioned homogeneous
batches, so conditioning robustness is ranked, not only gated: an
implementation cannot inspect a few matrices, decide the whole batch is
well-conditioned, and route it to a path that is only valid for well-conditioned
inputs, and the runtime cost of the accurate path on hard inputs is part of the
score. Each matrix must be factored correctly on its own merits.

Correctness is a hard gate against the original FP32 input and the FP32
`torch.geqrf` compact-factor contract. Low-bit FP16, FP8, or NVFP4 work is
allowed only as an internal implementation strategy: returned factors must
Expand Down Expand Up @@ -89,6 +102,9 @@ tests:
- {"batch": 2, "n": 2048, "cond": 2, "seed": 224466, "case": "dense"}
- {"batch": 2, "n": 2048, "cond": 0, "seed": 224467, "case": "rankdef"}
- {"batch": 1, "n": 4096, "cond": 0, "seed": 75343, "case": "upper"}
- {"batch": 16, "n": 512, "cond": 2, "seed": 32530, "case": "mixed"}
- {"batch": 4, "n": 1024, "cond": 2, "seed": 4332, "case": "mixed"}
- {"batch": 2, "n": 2048, "cond": 2, "seed": 224468, "case": "mixed"}

benchmarks:
- {"batch": 20, "n": 32, "cond": 1, "seed": 43214}
Expand All @@ -98,3 +114,8 @@ benchmarks:
- {"batch": 60, "n": 1024, "cond": 2, "seed": 75342}
- {"batch": 8, "n": 2048, "cond": 1, "seed": 224466}
- {"batch": 2, "n": 4096, "cond": 1, "seed": 32412}
- {"batch": 640, "n": 512, "cond": 2, "seed": 770001, "case": "mixed"}
- {"batch": 60, "n": 1024, "cond": 2, "seed": 770002, "case": "mixed"}
- {"batch": 640, "n": 512, "cond": 0, "seed": 770003, "case": "rankdef"}
- {"batch": 640, "n": 512, "cond": 0, "seed": 770004, "case": "clustered"}
- {"batch": 60, "n": 1024, "cond": 0, "seed": 770005, "case": "nearrank"}