|
| 1 | +From c849ccbd342b6067d19d5805c6614a21a4f0b49f Mon Sep 17 00:00:00 2001 |
| 2 | +From: Sam Larsen <slarsen@meta.com> |
| 3 | +Date: Fri, 25 Jul 2025 09:31:15 -0700 |
| 4 | +Subject: [PATCH] Fix full_like decomposition to preserve strides (#158898) |
| 5 | + |
| 6 | +Summary: |
| 7 | + |
| 8 | +See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again. |
| 9 | + |
| 10 | +Upstream Patch Reference: https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159294.patch & https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/158898.patch |
| 11 | + |
| 12 | +--- |
| 13 | + test/inductor/test_torchinductor.py | 48 ++++++++++++++++++++++--- |
| 14 | + test/test_decomp.py | 11 +++++- |
| 15 | + torch/_inductor/decomposition.py | 55 +++++++++++++++++++++++++++++ |
| 16 | + torch/_inductor/lowering.py | 1 - |
| 17 | + 4 files changed, 109 insertions(+), 6 deletions(-) |
| 18 | + |
| 19 | +diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py |
| 20 | +index f0c152ad..850aa033 100644 |
| 21 | +--- a/test/inductor/test_torchinductor.py |
| 22 | ++++ b/test/inductor/test_torchinductor.py |
| 23 | +@@ -326,6 +326,10 @@ def check_model( |
| 24 | + reference_in_float=True, |
| 25 | + assert_equal=True, |
| 26 | + check_gradient=False, |
| 27 | ++ check_has_compiled=True, |
| 28 | ++ output_process_fn_grad=lambda x: x, |
| 29 | ++ # TODO: enable this for all tests |
| 30 | ++ exact_stride=False, |
| 31 | + ): |
| 32 | + kwargs = kwargs or {} |
| 33 | + torch._dynamo.reset() |
| 34 | +@@ -343,7 +347,12 @@ def check_model( |
| 35 | + x.dtype == torch.float16 or x.dtype == torch.bfloat16 |
| 36 | + ): |
| 37 | + has_lowp_args = True |
| 38 | +- return x.float() |
| 39 | ++ # Preserve strides when casting |
| 40 | ++ result = torch.empty_strided( |
| 41 | ++ x.size(), x.stride(), device=x.device, dtype=torch.float |
| 42 | ++ ) |
| 43 | ++ result.copy_(x) |
| 44 | ++ return result |
| 45 | + else: |
| 46 | + return x |
| 47 | + |
| 48 | +@@ -410,6 +419,7 @@ def check_model( |
| 49 | + rtol=rtol, |
| 50 | + equal_nan=True, |
| 51 | + exact_dtype=exact_dtype, |
| 52 | ++ exact_stride=exact_stride, |
| 53 | + ) |
| 54 | + # In case of input mutations, check that inputs are the same |
| 55 | + self.assertEqual( |
| 56 | +@@ -420,6 +430,7 @@ def check_model( |
| 57 | + equal_nan=True, |
| 58 | + # our testing sometimes uses higher precision inputs for the reference |
| 59 | + exact_dtype=False, |
| 60 | ++ exact_stride=exact_stride, |
| 61 | + ) |
| 62 | + else: |
| 63 | + for correct_val, actual_val in zip(correct_flat, actual_flat): |
| 64 | +@@ -430,6 +441,8 @@ def check_model( |
| 65 | + assert correct_val.layout == actual_val.layout |
| 66 | + if exact_dtype: |
| 67 | + assert correct_val.dtype == actual_val.dtype |
| 68 | ++ if exact_stride: |
| 69 | ++ assert correct_val.stride() == actual_val.stride() |
| 70 | + |
| 71 | + if check_gradient: |
| 72 | + |
| 73 | +@@ -452,6 +465,7 @@ def check_model( |
| 74 | + rtol=rtol, |
| 75 | + equal_nan=True, |
| 76 | + exact_dtype=exact_dtype, |
| 77 | ++ exact_stride=exact_stride, |
| 78 | + ) |
| 79 | + |
| 80 | + torch._dynamo.reset() |
| 81 | +@@ -501,6 +515,7 @@ def check_model_cuda( |
| 82 | + reference_in_float=reference_in_float, |
| 83 | + assert_equal=assert_equal, |
| 84 | + check_gradient=check_gradient, |
| 85 | ++ exact_stride=exact_stride, |
| 86 | + ) |
| 87 | + |
| 88 | + if check_lowp: |
| 89 | +@@ -529,6 +544,7 @@ def check_model_cuda( |
| 90 | + reference_in_float=reference_in_float, |
| 91 | + assert_equal=assert_equal, |
| 92 | + check_gradient=check_gradient, |
| 93 | ++ exact_stride=exact_stride, |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +@@ -3500,6 +3516,18 @@ class CommonTemplate: |
| 98 | + |
| 99 | + self.common(fn, (torch.randn(8),)) |
| 100 | + |
| 101 | ++ def test_full_like_transposed(self): |
| 102 | ++ def fn(a): |
| 103 | ++ return torch.full_like(a, 3) |
| 104 | ++ |
| 105 | ++ self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True) |
| 106 | ++ |
| 107 | ++ def test_full_like_sliced(self): |
| 108 | ++ def fn(a): |
| 109 | ++ return torch.full_like(a, 3) |
| 110 | ++ |
| 111 | ++ self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True) |
| 112 | ++ |
| 113 | + def test_full_truncation(self): |
| 114 | + def fn(a): |
| 115 | + return a + torch.full_like(a, 7.777) |
| 116 | +@@ -4767,14 +4795,26 @@ class CommonTemplate: |
| 117 | + model = Model() |
| 118 | + x = torch.rand(10, 3, 0) |
| 119 | + |
| 120 | +- self.common(model, (x,)) |
| 121 | ++ self.common(model, (x,), exact_stride=True) |
| 122 | + |
| 123 | + @config.patch(fallback_random=True) |
| 124 | + def test_like_rands(self): |
| 125 | + def fn(x): |
| 126 | +- return torch.rand_like(x), torch.randn_like(x) |
| 127 | ++ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11) |
| 128 | ++ |
| 129 | ++ self.common(fn, [torch.zeros([20, 20])], exact_stride=True) |
| 130 | ++ |
| 131 | ++ @config.patch(fallback_random=True) |
| 132 | ++ @xfail_if_mps # 100% are not close |
| 133 | ++ def test_like_rands_sliced(self): |
| 134 | ++ def fn(x): |
| 135 | ++ return ( |
| 136 | ++ torch.rand_like(x),
| 137 | ++ torch.randn_like(x), |
| 138 | ++ torch.randint_like(x, 1, 11), |
| 139 | ++ ) |
| 140 | + |
| 141 | +- self.common(fn, [torch.zeros([20, 20])]) |
| 142 | ++ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True) |
| 143 | + |
| 144 | + def test_max_pool2d_with_indices_backward(self): |
| 145 | + def fn(a, b, c): |
| 146 | +diff --git a/test/test_decomp.py b/test/test_decomp.py |
| 147 | +index c27ffadb..58cfa40c 100644 |
| 148 | +--- a/test/test_decomp.py |
| 149 | ++++ b/test/test_decomp.py |
| 150 | +@@ -524,7 +524,16 @@ class TestDecomp(TestCase): |
| 151 | + assert len(real_out) == len(decomp_out) |
| 152 | + |
| 153 | + if do_relative_check: |
| 154 | +- upcast = partial(upcast_tensor, dtype=torch.float64) |
| 155 | ++ device_arg = kwargs.get("device", None) |
| 156 | ++ |
| 157 | ++ def upcast(x): |
| 158 | ++ if (isinstance(x, Tensor) and x.device.type == "mps") or ( |
| 159 | ++ device_arg and torch.device(device_arg).type == "mps" |
| 160 | ++ ): |
| 161 | ++ return upcast_tensor(x, dtype=torch.float32) |
| 162 | ++ else: |
| 163 | ++ return upcast_tensor(x, dtype=torch.float64) |
| 164 | ++ |
| 165 | + real_out_double, _ = tree_flatten( |
| 166 | + func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) |
| 167 | + ) |
| 168 | +diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py |
| 169 | +index 3fa3640e..16efb345 100644 |
| 170 | +--- a/torch/_inductor/decomposition.py |
| 171 | ++++ b/torch/_inductor/decomposition.py |
| 172 | +@@ -218,6 +218,61 @@ def should_pad_bench(mat1, mat2, op, input=None): |
| 173 | + # TODO: Build a learned model which would be better than this heuristic |
| 174 | + return ori_time > pad_time * 1.1 |
| 175 | + |
| 176 | ++def _get_shape_permutation_like( |
| 177 | ++ self: torch.Tensor, |
| 178 | ++) -> tuple[utils.ShapeType, utils.StrideType]: |
| 179 | ++ physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self) |
| 180 | ++ shape = [self.shape[l] for l in physical_layout] |
| 181 | ++ |
| 182 | ++ permutation = [0] * len(shape) |
| 183 | ++ for p, l in enumerate(physical_layout): |
| 184 | ++ permutation[l] = p |
| 185 | ++ |
| 186 | ++ return (shape, permutation) |
| 187 | +
| 188 | +
| 189 | ++@register_decomposition(aten.full_like)
| 190 | ++def full_like(
| 191 | ++    self: torch.Tensor,
| 192 | ++    fill_value: Union[int, float],
| 193 | ++    *,
| 194 | ++    dtype: Optional[torch.dtype] = None,
| 195 | ++    layout: Optional[torch.layout] = None,
| 196 | ++    device: Optional[torch.device] = None,
| 197 | ++    pin_memory: bool = False,
| 198 | ++    requires_grad: bool = False,
| 199 | ++    memory_format: torch.memory_format = torch.preserve_format,
| 200 | ++) -> torch.Tensor:
| 201 | ++    dtype = self.dtype if dtype is None else dtype
| 202 | ++    layout = self.layout if layout is None else layout
| 203 | ++    device = self.device if device is None else device
| 204 | ++
| 205 | ++    if memory_format != torch.preserve_format:
| 190 | ++ result = torch.full( |
| 191 | ++ self.shape, |
| 192 | ++ fill_value, |
| 193 | ++ dtype=dtype, |
| 194 | ++ layout=layout, |
| 195 | ++ device=device, |
| 196 | ++ pin_memory=pin_memory, |
| 197 | ++ requires_grad=requires_grad, |
| 198 | ++ ) |
| 199 | ++ return result.to(memory_format=memory_format) |
| 200 | ++ |
| 201 | ++ else: |
| 202 | ++ assert layout == torch.strided |
| 203 | ++ shape, permutation = _get_shape_permutation_like(self) |
| 204 | ++ result = torch.full( |
| 205 | ++ shape, |
| 206 | ++ fill_value, |
| 207 | ++ dtype=dtype, |
| 208 | ++ layout=layout, |
| 209 | ++ device=device, |
| 210 | ++ pin_memory=pin_memory, |
| 211 | ++ requires_grad=requires_grad, |
| 212 | ++ ) |
| 213 | ++ if permutation == list(range(len(permutation))): |
| 214 | ++ return result |
| 215 | ++ return result.permute(permutation).clone() |
| 216 | ++
| 217 | ++def _rand_like(
| 218 | ++    rand_fn: Callable[..., torch.Tensor],
| 219 | ++    self: torch.Tensor,
| 220 | ++    *,
| 221 | ++    dtype: Optional[torch.dtype] = None,
| 222 | ++    device: Optional[torch.device] = None,
| 223 | ++    memory_format: torch.memory_format = torch.preserve_format,
| 224 | ++    **kwargs: Any,
| 225 | ++) -> torch.Tensor:
| 226 | ++    dtype = self.dtype if dtype is None else dtype
| 227 | ++    device = self.device if device is None else device
| 228 | ++
| 229 | ++    if memory_format != torch.preserve_format:
| 230 | ++        return rand_fn(
| 231 | ++            self.shape,
| 232 | ++            dtype=dtype,
| 233 | ++            device=device,
| 234 | ++            **kwargs,
| 235 | ++        ).to(memory_format=memory_format)
| 236 | ++
| 237 | ++    shape, permutation = _get_shape_permutation_like(self)
| 238 | ++    result = rand_fn(
| 239 | ++        shape,
| 240 | ++        dtype=dtype,
| 241 | ++        device=device,
| 242 | ++        **kwargs,
| 243 | ++    )
| 244 | ++    if permutation == list(range(len(permutation))):
| 245 | ++        return result
| 246 | ++    return result.permute(permutation).clone()
| 247 | ++
| 248 | ++
| 249 | ++@register_decomposition(aten.rand_like)
| 218 | ++def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 219 | ++ return _rand_like(torch.rand, self, **kwargs) |
| 220 | ++ |
| 221 | ++ |
| 222 | ++@register_decomposition(aten.randn_like) |
| 223 | ++def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 224 | ++ return _rand_like(torch.randn, self, **kwargs) |
| 225 | ++ |
| 226 | ++ |
| 227 | ++@register_decomposition(aten.randint_like.default) |
| 228 | ++def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: |
| 229 | ++ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) |
| 230 | ++ |
| 231 | + |
| 232 | + @register_decomposition([aten.mm]) |
| 233 | + def mm_decomp(mat1, mat2): |
| 234 | +diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py |
| 235 | +index 4c77ebdf..858c6866 100644 |
| 236 | +--- a/torch/_inductor/lowering.py |
| 237 | ++++ b/torch/_inductor/lowering.py |
| 238 | +@@ -1712,7 +1712,6 @@ def _full(fill_value, device, dtype, size): |
| 239 | + ) |
| 240 | + |
| 241 | + |
| 242 | +-@register_lowering(aten.full_like, type_promotion_kind=None) |
| 243 | + def full_like(x, fill_value, **kwargs): |
| 244 | + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) |
| 245 | + |
| 246 | +-- |
| 247 | +2.45.4 |
| 248 | + |
0 commit comments