|
| 1 | +From a8c49a5fac46df180ba95810dcbb56c00dbd9c76 Mon Sep 17 00:00:00 2001 |
| 2 | +From: sunflowersxu <166728538+sunflowersxu@users.noreply.github.com> |
| 3 | +Date: Thu, 13 Jun 2024 01:47:14 +0800 |
| 4 | +Subject: [PATCH] Mitigate tarball directory traversal risks (#6164) |
| 5 | + |
| 6 | +Hi, this pr is cleaner version than #6145 |
| 7 | + |
| 8 | +Signed-off-by: sunriseXu <15927176697@163.com> |
| 9 | +Co-authored-by: sunriseXu <15927176697@163.com> |
| 10 | +Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> |
| 11 | +--- |
| 12 | + third_party/onnx/onnx/hub.py | 43 +++++++++++++++++++++++++++++++++++- |
| 13 | + 1 file changed, 42 insertions(+), 1 deletion(-) |
| 14 | + |
| 15 | +diff --git a/third_party/onnx/onnx/hub.py b/third_party/onnx/onnx/hub.py |
| 16 | +index e5ca9e2c..dc888742 100644 |
| 17 | +--- a/third_party/onnx/onnx/hub.py |
| 18 | ++++ b/third_party/onnx/onnx/hub.py |
| 19 | +@@ -271,6 +271,35 @@ def load( |
| 20 | + return onnx.load(cast(IO[bytes], BytesIO(model_bytes))) |
| 21 | + |
| 22 | + |
| 23 | ++def _tar_members_filter(tar: tarfile.TarFile, base: str) -> list[tarfile.TarInfo]: |
| 24 | ++ """Check that the content of ``tar`` will be extracted safely |
| 25 | ++ |
| 26 | ++ Args: |
| 27 | ++ tar: The tarball file |
| 28 | ++ base: The directory where the tarball will be extracted |
| 29 | ++ |
| 30 | ++ Returns: |
| 31 | ++ list of tarball members |
| 32 | ++ """ |
| 33 | ++ result = [] |
| 34 | ++ for member in tar: |
| 35 | ++ member_path = os.path.join(base, member.name) |
| 36 | ++ abs_base = os.path.abspath(base) |
| 37 | ++ abs_member = os.path.abspath(member_path) |
| 38 | ++ if not abs_member.startswith(abs_base): |
| 39 | ++ raise RuntimeError( |
| 40 | ++ f"The tarball member {member_path} in downloading model contains " |
| 41 | ++ f"directory traversal sequence which may contain harmful payload." |
| 42 | ++ ) |
| 43 | ++ elif member.issym() or member.islnk(): |
| 44 | ++ raise RuntimeError( |
| 45 | ++ f"The tarball member {member_path} in downloading model contains " |
| 46 | ++ f"symbolic links which may contain harmful payload." |
| 47 | ++ ) |
| 48 | ++ result.append(member) |
| 49 | ++ return result |
| 50 | ++ |
| 51 | ++ |
| 52 | + def download_model_with_test_data( |
| 53 | + model: str, |
| 54 | + repo: str = "onnx/models:main", |
| 55 | +@@ -280,6 +309,7 @@ def download_model_with_test_data( |
| 56 | + ) -> Optional[str]: |
| 57 | + """ |
| 58 | + Downloads a model along with test data by name from the onnx model hub and returns the directory to which the files have been extracted. |
| 59 | ++ Users are responsible for making sure the model comes from a trusted source, and the data is safe to be extracted. |
| 60 | + |
| 61 | + :param model: The name of the onnx model in the manifest. This field is case-sensitive |
| 62 | + :param repo: The location of the model repo in format "user/repo[:branch]". |
| 63 | +@@ -342,7 +372,18 @@ def download_model_with_test_data( |
| 64 | + local_model_with_data_dir_path = local_model_with_data_path[ |
| 65 | + 0 : len(local_model_with_data_path) - 7 |
| 66 | + ] |
| 67 | +- model_with_data_zipped.extractall(local_model_with_data_dir_path) |
| 68 | ++ # Mitigate tarball directory traversal risks |
| 69 | ++ if hasattr(tarfile, "data_filter"): |
| 70 | ++ model_with_data_zipped.extractall( |
| 71 | ++ path=local_model_with_data_dir_path, filter="data" |
| 72 | ++ ) |
| 73 | ++ else: |
| 74 | ++ model_with_data_zipped.extractall( |
| 75 | ++ path=local_model_with_data_dir_path, |
| 76 | ++ members=_tar_members_filter( |
| 77 | ++ model_with_data_zipped, local_model_with_data_dir_path |
| 78 | ++ ), |
| 79 | ++ ) |
| 80 | + model_with_data_path = ( |
| 81 | + local_model_with_data_dir_path |
| 82 | + + "/" |
| 83 | +-- |
| 84 | +2.39.4 |
| 85 | + |
0 commit comments