2020import json
2121
2222from .headers import Headers
23+ from .interfaces import _IFieldStorage , _IXSSSafeFieldStorage
2324
2425
25- class _IFieldStorage :
26- """Interface with shared methods for QueryParams and FormData."""
27-
28- _storage : Dict [str , List [Union [str , bytes ]]]
29-
30- def _add_field_value (self , field_name : str , value : Union [str , bytes ]) -> None :
31- if field_name not in self ._storage :
32- self ._storage [field_name ] = [value ]
33- else :
34- self ._storage [field_name ].append (value )
35-
36- @staticmethod
37- def _encode_html_entities (value : str ) -> str :
38- """Encodes unsafe HTML characters."""
39- return (
40- str (value )
41- .replace ("&" , "&" )
42- .replace ("<" , "<" )
43- .replace (">" , ">" )
44- .replace ('"' , """ )
45- .replace ("'" , "'" )
46- )
47-
48- def get (
49- self , field_name : str , default : Any = None , * , safe = True
50- ) -> Union [str , bytes , None ]:
51- """Get the value of a field."""
52- if safe :
53- return self ._encode_html_entities (
54- self ._storage .get (field_name , [default ])[0 ]
55- )
56-
57- _debug_warning_nonencoded_output ()
58- return self ._storage .get (field_name , [default ])[0 ]
59-
60- def get_list (self , field_name : str ) -> List [Union [str , bytes ]]:
61- """Get the list of values of a field."""
62- return self ._storage .get (field_name , [])
63-
64- @property
65- def fields (self ):
66- """Returns a list of field names."""
67- return list (self ._storage .keys ())
68-
69- def __getitem__ (self , field_name : str ):
70- return self .get (field_name )
71-
72- def __iter__ (self ):
73- return iter (self ._storage )
74-
75- def __len__ (self ):
76- return len (self ._storage )
77-
78- def __contains__ (self , key : str ):
79- return key in self ._storage
80-
81- def __repr__ (self ) -> str :
82- return f"{ self .__class__ .__name__ } ({ repr (self ._storage )} )"
83-
84-
85- class QueryParams (_IFieldStorage ):
26+ class QueryParams (_IXSSSafeFieldStorage ):
8627 """
8728 Class for parsing and storing GET query parameters requests.
8829
8930 Examples::
9031
9132 query_params = QueryParams("foo=bar&baz=qux&baz=quux")
92- # QueryParams({"foo": "bar", "baz": ["qux", "quux"]})
33+ # QueryParams({"foo": [ "bar"] , "baz": ["qux", "quux"]})
9334
9435 query_params.get("foo") # "bar"
9536 query_params["foo"] # "bar"
@@ -111,8 +52,80 @@ def __init__(self, query_string: str) -> None:
11152 elif query_param :
11253 self ._add_field_value (query_param , "" )
11354
55+ def _add_field_value (self , field_name : str , value : str ) -> None :
56+ super ()._add_field_value (field_name , value )
57+
58+ def get (
59+ self , field_name : str , default : str = None , * , safe = True
60+ ) -> Union [str , None ]:
61+ return super ().get (field_name , default , safe = safe )
62+
63+ def get_list (self , field_name : str , * , safe = True ) -> List [str ]:
64+ return super ().get_list (field_name , safe = safe )
65+
11466
115- class FormData (_IFieldStorage ):
class File:
    """
    Class representing a file uploaded via POST.

    Examples::

        file = request.form_data.files.get("uploaded_file")
        # File(filename='foo.txt', content_type='text/plain', size=14)

        file.content
        # "Hello, world!\\n"
    """

    filename: str
    """Filename of the file."""

    content_type: str
    """Content type of the file."""

    content: Union[str, bytes]
    """Content of the file."""

    def __init__(
        self, filename: str, content_type: str, content: Union[str, bytes]
    ) -> None:
        """
        Store the metadata and payload of a single uploaded file.

        :param str filename: Name of the file as sent by the client.
        :param str content_type: Content type declared for the file.
        :param content: File payload, either decoded text or raw bytes.
        """
        self.filename = filename
        self.content_type = content_type
        self.content = content

    @property
    def size(self) -> int:
        """Length of the file content (characters for ``str``, bytes for ``bytes``)."""
        return len(self.content)

    def __repr__(self) -> str:
        """Return a debug representation showing filename, content type and size."""
        # Each field is converted with ``!r`` exactly once. Pre-computing ``repr()``
        # and then formatting with the ``=`` debug specifier would apply ``repr``
        # twice and yield output such as ``size='14'``.
        return (
            f"{self.__class__.__name__}("
            f"filename={self.filename!r}, "
            f"content_type={self.content_type!r}, "
            f"size={self.size!r})"
        )
108+
109+
class Files(_IFieldStorage):
    """Class for files uploaded via POST."""

    _storage: Dict[str, List[File]]

    def __init__(self) -> None:
        """Start with no stored files."""
        self._storage = {}

    def _add_field_value(self, field_name: str, value: File) -> None:
        """Register ``value`` under ``field_name`` by delegating to the base storage."""
        super()._add_field_value(field_name, value)

    def get(self, field_name: str, default: Any = None) -> Union[File, Any, None]:
        """
        Look up ``field_name`` through the base storage's ``get``.

        ``default`` is forwarded unchanged to the base implementation.
        """
        found = super().get(field_name, default)
        return found

    def get_list(self, field_name: str) -> List[File]:
        """Return whatever the base storage's ``get_list`` yields for ``field_name``."""
        stored = super().get_list(field_name)
        return stored
126+
127+
128+ class FormData (_IXSSSafeFieldStorage ):
116129 """
117130 Class for parsing and storing form data from POST requests.
118131
@@ -124,7 +137,7 @@ class FormData(_IFieldStorage):
124137 form_data = FormData(b"foo=bar&baz=qux&baz=quuz", "application/x-www-form-urlencoded")
125138 # or
126139 form_data = FormData(b"foo=bar\\ r\\ nbaz=qux\\ r\\ nbaz=quux", "text/plain")
127- # FormData({"foo": "bar", "baz": "qux"})
140+ # FormData({"foo": [ "bar"] , "baz": [ "qux", "quux"] })
128141
129142 form_data.get("foo") # "bar"
130143 form_data["foo"] # "bar"
@@ -135,10 +148,12 @@ class FormData(_IFieldStorage):
135148 """
136149
137150 _storage : Dict [str , List [Union [str , bytes ]]]
151+ files : Files
138152
139153 def __init__ (self , data : bytes , content_type : str ) -> None :
140154 self .content_type = content_type
141155 self ._storage = {}
156+ self .files = Files ()
142157
143158 if content_type .startswith ("application/x-www-form-urlencoded" ):
144159 self ._parse_x_www_form_urlencoded (data )
@@ -162,11 +177,25 @@ def _parse_multipart_form_data(self, data: bytes, boundary: str) -> None:
162177 blocks = data .split (b"--" + boundary .encode ())[1 :- 1 ]
163178
164179 for block in blocks :
165- disposition , content = block .split (b"\r \n \r \n " , 1 )
166- field_name = disposition .split (b'"' , 2 )[1 ].decode ()
167- value = content [:- 2 ]
180+ header_bytes , content_bytes = block .split (b"\r \n \r \n " , 1 )
181+ headers = Headers (header_bytes .decode ("utf-8" ).strip ())
168182
169- self ._add_field_value (field_name , value )
183+ field_name = headers .get_parameter ("Content-Disposition" , "name" )
184+ filename = headers .get_parameter ("Content-Disposition" , "filename" )
185+ content_type = headers .get_directive ("Content-Type" , "text/plain" )
186+ charset = headers .get_parameter ("Content-Type" , "charset" , "utf-8" )
187+
188+ content = content_bytes [:- 2 ] # remove trailing \r\n
189+ value = content .decode (charset ) if content_type == "text/plain" else content
190+
191+ # TODO: Other text content types (e.g. application/json) should be decoded as well.
192+
193+ if filename is not None :
194+ self .files ._add_field_value ( # pylint: disable=protected-access
195+ field_name , File (filename , content_type , value )
196+ )
197+ else :
198+ self ._add_field_value (field_name , value )
170199
171200 def _parse_text_plain (self , data : bytes ) -> None :
172201 lines = data .decode ("utf-8" ).split ("\r \n " )[:- 1 ]
@@ -176,6 +205,21 @@ def _parse_text_plain(self, data: bytes) -> None:
176205
177206 self ._add_field_value (field_name , value )
178207
208+ def _add_field_value (self , field_name : str , value : Union [str , bytes ]) -> None :
209+ super ()._add_field_value (field_name , value )
210+
211+ def get (
212+ self , field_name : str , default : Union [str , bytes ] = None , * , safe = True
213+ ) -> Union [str , bytes , None ]:
214+ return super ().get (field_name , default , safe = safe )
215+
216+ def get_list (self , field_name : str , * , safe = True ) -> List [Union [str , bytes ]]:
217+ return super ().get_list (field_name , safe = safe )
218+
219+ def __repr__ (self ) -> str :
220+ class_name = self .__class__ .__name__
221+ return f"{ class_name } ({ repr (self ._storage )} , files={ repr (self .files ._storage )} )"
222+
179223
180224class Request :
181225 """
@@ -358,12 +402,3 @@ def _parse_request_header(
358402 headers = Headers (headers_string )
359403
360404 return method , path , query_params , http_version , headers
361-
362-
363- def _debug_warning_nonencoded_output ():
364- """Warns about XSS risks."""
365- print (
366- "WARNING: Setting safe to False makes XSS vulnerabilities possible by "
367- "allowing access to raw untrusted values submitted by users. If this data is reflected "
368- "or shown within HTML without proper encoding it could enable Cross-Site Scripting."
369- )
0 commit comments