|
| 1 | +From f95e54328b15315c3563792fccae7193439d1312 Mon Sep 17 00:00:00 2001 |
| 2 | +From: AllSpark <allspark@microsoft.com> |
| 3 | +Date: Mon, 29 Sep 2025 19:34:39 +0000 |
| 4 | +Subject: [PATCH] inductor: guard bitwise shifts with max_shift and add tests |
| 5 | + for corner inputs; fixes ghissues 143555 and 143566 |
| 6 | + |
| 7 | +Signed-off-by: Azure Linux Security Servicing Account <azurelinux-security@microsoft.com> |
| 8 | +Upstream-reference: AI Backport of https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/143635.patch |
| 9 | +--- |
| 10 | + test/inductor/test_cpu_repro.py | 17 ++++++++++++++++++ |
| 11 | + torch/_inductor/codegen/cpp.py | 34 ++++++++++++++++++-- |
| 12 | + 2 files changed, 49 insertions(+), 2 deletions(-) |
| 13 | + |
| 14 | +diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py |
| 15 | +index 925c0f62..9ed4b79d 100644 |
| 16 | +--- a/test/inductor/test_cpu_repro.py |
| 17 | ++++ b/test/inductor/test_cpu_repro.py |
| 18 | +@@ -2315,6 +2315,23 @@ class CPUReproTests(TestCase): |
| 19 | + self.common(fn2, (x,)) |
| 20 | + assert metrics.generated_cpp_vec_kernel_count == 1 |
| 21 | + |
| 22 | ++ def test_bitwise_shift_corner_inputs(self): |
| 23 | + # Regression test for https://github.com/pytorch/pytorch/issues/143555 |
| 24 | ++ # and https://github.com/pytorch/pytorch/issues/143566 |
| 25 | ++ bitwise_fns = ( |
| 26 | ++ torch.bitwise_left_shift, |
| 27 | ++ torch.bitwise_right_shift, |
| 28 | ++ ) |
| 29 | ++ for bitwise_fn in bitwise_fns: |
| 30 | ++ torch._dynamo.reset() |
| 31 | ++ metrics.reset() |
| 32 | ++ x = torch.tensor(1000, dtype=torch.int64) |
| 33 | ++ bit_num = torch.tensor(64, dtype=torch.int64) |
| 34 | ++ res_aten_eager = bitwise_fn(x, bit_num) |
| 35 | ++ cfn = torch.compile(bitwise_fn) |
| 36 | ++ res = cfn(x, bit_num) |
| 37 | ++ self.assertEqual(res_aten_eager, res) |
| 38 | ++ |
| 39 | + def test_transpose_vertical_sum_cpu_only(self): |
| 40 | + def fn(a, b): |
| 41 | + c = a * b |
| 42 | +diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py |
| 43 | +index b94ede02..af5ed42a 100644 |
| 44 | +--- a/torch/_inductor/codegen/cpp.py |
| 45 | ++++ b/torch/_inductor/codegen/cpp.py |
| 46 | +@@ -801,11 +801,41 @@ class CppOverrides(OpOverrides): |
| 47 | + |
| 48 | + @staticmethod |
| 49 | + def bitwise_left_shift(a, b): |
| 50 | +- return f"decltype({a})({a} << {b})" |
| 51 | ++ code = BracesBuffer() |
| 52 | ++ code.writeline("[&]()") |
| 53 | ++ with code.indent(): |
| 54 | ++ scalar_t = DTYPE_TO_CPP[a.dtype] |
| 55 | ++ code.writeline( |
| 56 | ++ f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT;" |
| 57 | ++ ) |
| 58 | ++ code.writeline( |
| 59 | ++ f"if ((static_cast<std::make_signed_t<{scalar_t}>>({b}) < 0) || ({b} >= max_shift))" |
| 60 | ++ ) |
| 61 | ++ with code.indent(): |
| 62 | ++ code.writeline(f"return decltype({a})(0);") |
| 63 | ++ code.writeline( |
| 64 | ++ f"return decltype({a})(static_cast<std::make_unsigned_t<{scalar_t}>>({a}) << {b});" |
| 65 | ++ ) |
| 66 | ++ code.writeline("()") |
| 67 | ++ return code |
| 68 | + |
| 69 | + @staticmethod |
| 70 | + def bitwise_right_shift(a, b): |
| 71 | +- return f"decltype({a})({a} >> {b})" |
| 72 | ++ code = BracesBuffer() |
| 73 | ++ code.writeline("[&]()") |
| 74 | ++ with code.indent(): |
| 75 | ++ scalar_t = DTYPE_TO_CPP[a.dtype] |
| 76 | ++ code.writeline( |
| 77 | ++ f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT - std::is_signed_v<{scalar_t}>;" |
| 78 | ++ ) |
| 79 | ++ code.writeline( |
| 80 | ++ f"if ((static_cast<std::make_signed_t<{scalar_t}>>({b}) < 0) || ({b} >= max_shift))" |
| 81 | ++ ) |
| 82 | ++ with code.indent(): |
| 83 | ++ code.writeline(f"return decltype({a})({a} >> max_shift);") |
| 84 | ++ code.writeline(f"return decltype({a})({a} >> {b});") |
| 85 | ++ code.writeline("()") |
| 86 | ++ return code |
| 87 | + |
| 88 | + @staticmethod |
| 89 | + def rand(seed: sympy.Expr, offset: sympy.Expr): |
| 90 | +-- |
| 91 | +2.45.4 |
| 92 | + |
0 commit comments