|
| 1 | +From c27a81c2aee58626189631800841af0cc44e0873 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Kanishk-Bansal <kbkanishk975@gmail.com> |
| 3 | +Date: Wed, 23 Apr 2025 06:43:41 +0000 |
| 4 | +Subject: [PATCH] Address CVE-2025-32434 |
| 5 | + |
| 6 | +Upstream Patch Reference : https://github.com/pytorch/pytorch/commit/8d4b8a920a2172523deb95bf20e8e52d50649c04 |
| 7 | + |
| 8 | +Signed-off-by: Kanishk-Bansal <kbkanishk975@gmail.com> |
| 9 | +--- |
| 10 | + test/test_serialization.py | 6 +++++- |
| 11 | + torch/serialization.py | 17 ++++++++++++----- |
| 12 | + 2 files changed, 17 insertions(+), 6 deletions(-) |
| 13 | + |
| 14 | +diff --git a/test/test_serialization.py b/test/test_serialization.py |
| 15 | +index 9b9a7133..593f802a 100644 |
| 16 | +--- a/test/test_serialization.py |
| 17 | ++++ b/test/test_serialization.py |
| 18 | +@@ -404,7 +404,11 @@ class SerializationMixin: |
| 19 | + b += [a[0].storage()] |
| 20 | + b += [a[0].reshape(-1)[1:4].clone().storage()] |
| 21 | + path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') |
| 22 | +- c = torch.load(path, weights_only=weights_only) |
| 23 | ++ if weights_only: |
| 24 | ++ with self.assertRaisesRegex(RuntimeError, |
| 25 | ++ "Cannot use ``weights_only=True`` with files saved in the legacy .tar format."): |
| 26 | ++ c = torch.load(path, weights_only=weights_only) |
| 27 | ++ c = torch.load(path, weights_only=False) |
| 28 | + self.assertEqual(b, c, atol=0, rtol=0) |
| 29 | + self.assertTrue(isinstance(c[0], torch.FloatTensor)) |
| 30 | + self.assertTrue(isinstance(c[1], torch.FloatTensor)) |
| 31 | +diff --git a/torch/serialization.py b/torch/serialization.py |
| 32 | +index 83f6fa27..21ba1d07 100644 |
| 33 | +--- a/torch/serialization.py |
| 34 | ++++ b/torch/serialization.py |
| 35 | +@@ -33,6 +33,13 @@ STORAGE_KEY_SEPARATOR = ',' |
| 36 | + FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] |
| 37 | + MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] |
| 38 | + |
| 39 | ++UNSAFE_MESSAGE = ( |
| 40 | ++ "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " |
| 41 | ++ "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " |
| 42 | ++ "but it can result in arbitrary code execution. Do it only if you got the file from a " |
| 43 | ++ "trusted source." |
| 44 | ++ ) |
| 45 | ++ |
| 46 | + __all__ = [ |
| 47 | + 'SourceChangeWarning', |
| 48 | + 'mkdtemp', |
| 49 | +@@ -767,11 +774,6 @@ def load( |
| 50 | + >>> torch.load('module.pt', encoding='ascii') |
| 51 | + """ |
| 52 | + torch._C._log_api_usage_once("torch.load") |
| 53 | +- UNSAFE_MESSAGE = ( |
| 54 | +- "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" |
| 55 | +- " will likely succeed, but it can result in arbitrary code execution." |
| 56 | +- "Do it only if you get the file from a trusted source. WeightsUnpickler error: " |
| 57 | +- ) |
| 58 | + # Add ability to force safe only weight loads via environment variable |
| 59 | + if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: |
| 60 | + weights_only = True |
| 61 | +@@ -900,6 +902,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): |
| 62 | + |
| 63 | + with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ |
| 64 | + mkdtemp() as tmpdir: |
| 65 | ++ if pickle_module is _weights_only_unpickler: |
| 66 | ++ raise RuntimeError( |
| 67 | ++ "Cannot use ``weights_only=True`` with files saved in the " |
| 68 | ++ "legacy .tar format. " + UNSAFE_MESSAGE |
| 69 | ++ ) |
| 70 | + |
| 71 | + tar.extract('storages', path=tmpdir) |
| 72 | + with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f: |
| 73 | +-- |
| 74 | +2.45.2 |
| 75 | + |
0 commit comments