diff --git a/adafruit_httpserver/exceptions.py b/adafruit_httpserver/exceptions.py index 13bba7e..c5ce11f 100644 --- a/adafruit_httpserver/exceptions.py +++ b/adafruit_httpserver/exceptions.py @@ -62,3 +62,16 @@ def __init__(self, path: str) -> None: Creates a new ``FileNotExistsError`` for the file at ``path``. """ super().__init__(f"File does not exist: {path}") + + +class WebsocketError(Exception): + """ + Raised when there is a error in WebSocket communication. + """ + + def __init__(self, message: str, code: int = None) -> None: + """ + Creates a new ``WebsocketError`` with the given ``message``. + """ + self.code = code + super().__init__(f"WebSocket error: {message}") diff --git a/adafruit_httpserver/response.py b/adafruit_httpserver/response.py index 21aa217..9affe6d 100644 --- a/adafruit_httpserver/response.py +++ b/adafruit_httpserver/response.py @@ -16,6 +16,7 @@ import os from binascii import b2a_base64 from errno import EAGAIN, ECONNRESET, ENOTCONN, ETIMEDOUT +from time import monotonic_ns try: try: @@ -32,6 +33,7 @@ BackslashInPathError, FileNotExistsError, ParentDirectoryReferenceError, + WebsocketError, ) from .headers import Headers from .interfaces import _ISocket @@ -414,8 +416,8 @@ def __init__( *, permanent: bool = False, preserve_method: bool = False, - status: Union[Status, Tuple[int, str]] = None, - headers: Union[Headers, Dict[str, str]] = None, + status: Union[Status, Tuple[int, str], None] = None, + headers: Union[Headers, Dict[str, str], None] = None, cookies: Dict[str, str] = None, ) -> None: """ @@ -593,13 +595,21 @@ def route_func(request: Request): FIN = 0b10000000 # FIN bit indicating the final fragment # opcodes - CONT = 0 # Continuation frame, TODO: Currently not supported + CONT = 0 # Continuation frame TEXT = 1 # Frame contains UTF-8 text BINARY = 2 # Frame contains binary data CLOSE = 8 # Frame closes the connection PING = 9 # Frame is a ping, expecting a pong PONG = 10 # Frame is a pong, in response to a ping + PROTOCOL_ERROR = 1002 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + + MESSAGE_MAX_SIZE_BYTES = 4096 + MESSAGE_MAX_FRAGMENTS = 16 + MESSAGE_FRAGMENT_TIMEOUT_NS = 5 * (10**9) + @staticmethod def _check_request_initiates_handshake(request: Request): if not all( @@ -650,80 +660,221 @@ def __init__( self._headers.setdefault("Connection", "Upgrade") self._headers.setdefault("Sec-WebSocket-Accept", sec_accept_key) self._headers.setdefault("Content-Type", None) - self._buffer_size = buffer_size + self._buffer = bytearray(buffer_size) self.closed = False + self._reset_fragmented_message() + request.connection.setblocking(False) - @staticmethod - def _parse_frame_header(header): - fin = header[0] & Websocket.FIN - opcode = header[0] & 0b00001111 - has_mask = header[1] & 0b10000000 - length = header[1] & 0b01111111 + def _start_fragmented_message(self, opcode: Union[int, None], payload: Union[bytes, None]): + if self.MESSAGE_MAX_FRAGMENTS < 2: + raise WebsocketError("Fragmented messages not allowed", self.POLICY_VIOLATION) - if length == 0b01111110: - length = -2 - elif length == 0b01111111: - length = -8 + if len(payload) > self.MESSAGE_MAX_SIZE_BYTES: + raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG) - return fin, opcode, has_mask, length + self._message_opcode: Union[int, None] = opcode + self._message_payload: Union[bytes, None] = payload + self._message_fragments: Union[int, None] = 1 - def _read_frame(self): - buffer = bytearray(self._buffer_size) + now = monotonic_ns() + self._message_start_timestamp: Union[float, None] = now + self._message_last_frame_timestamp: Union[float, None] = now + + def _reset_fragmented_message(self): + self._message_opcode = None + self._message_payload = None + self._message_fragments = None + self._message_start_timestamp = None + self._message_last_frame_timestamp = None + + def _cont_fragmented_message(self, payload: bytes): + if self._message_fragments + 1 > self.MESSAGE_MAX_FRAGMENTS: + raise WebsocketError("Too many fragments in message", self.POLICY_VIOLATION) + + if len(self._message_payload) + len(payload) > self.MESSAGE_MAX_SIZE_BYTES: + raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG) + + now = monotonic_ns() + time_since_last_frame = now - self._message_last_frame_timestamp + + if time_since_last_frame > self.MESSAGE_FRAGMENT_TIMEOUT_NS: + raise WebsocketError("Fragment timeout exceeded", self.POLICY_VIOLATION) + + self._message_payload += payload + self._message_fragments += 1 + self._message_last_frame_timestamp = now + + def _fragmented_message_in_progress(self): + return self._message_opcode is not None + + def _recv_exact(self, buffer: bytearray, size: int) -> bytes: + received = 0 + view = memoryview(buffer) + while received < size: + remaining = size - received + try: + count = self._request.connection.recv_into(view[received : received + remaining]) + except OSError as error: + if error.errno == EAGAIN and received == 0: + raise + if error.errno == EAGAIN: + continue + raise + if count == 0: + if received == 0: + raise OSError(ENOTCONN) + break + received += count + return bytes(view[:received]) + + def _read_frame_header(self): + header_bytes = self._recv_exact(self._buffer, 2) + + if len(header_bytes) < 2: + raise OSError(ENOTCONN) + + fin = header_bytes[0] & Websocket.FIN + opcode = header_bytes[0] & 0b00001111 + + mask = header_bytes[1] & 0b10000000 + if not mask: + raise WebsocketError("Client frame not masked", self.PROTOCOL_ERROR) + + payload_length = header_bytes[1] & 0b01111111 + + if 125 < payload_length: + if payload_length == 126: + payload_length_bytes = self._recv_exact(self._buffer, 2) # Read next 16 bits + if len(payload_length_bytes) < 2: + raise WebsocketError("Incomplete payload length", self.PROTOCOL_ERROR) + + elif payload_length == 127: + payload_length_bytes = self._recv_exact(self._buffer, 8) # Read next 64 bits + if len(payload_length_bytes) < 8: + raise WebsocketError("Incomplete payload length", self.PROTOCOL_ERROR) + else: + raise WebsocketError("Invalid payload length", self.PROTOCOL_ERROR) + + payload_length = int.from_bytes(payload_length_bytes, "big") + + # In 64-bit payload length, most significant bit must be 0 + if payload_length & (1 << 63): + raise WebsocketError("Invalid payload length", self.PROTOCOL_ERROR) + + if opcode == Websocket.CONT and self._fragmented_message_in_progress(): + if len(self._message_payload) + payload_length > self.MESSAGE_MAX_SIZE_BYTES: + raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG) + + elif opcode in {Websocket.TEXT, Websocket.BINARY}: + if payload_length > self.MESSAGE_MAX_SIZE_BYTES: + raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG) - header_length = self._request.connection.recv_into(buffer, 2) - header_bytes = buffer[:header_length] + return fin, opcode, payload_length - fin, opcode, has_mask, length = self._parse_frame_header(header_bytes) + def _read_frame(self): + fin, opcode, payload_length = self._read_frame_header() - # TODO: Handle continuation frames, currently not supported - if fin != Websocket.FIN and opcode == Websocket.CONT: - return Websocket.CONT, None + masking_key = self._recv_exact(self._buffer, 4) + if len(masking_key) < 4: + raise WebsocketError("Incomplete mask", self.PROTOCOL_ERROR) payload = b"" - if fin == Websocket.FIN and opcode == Websocket.CLOSE: - return Websocket.CLOSE, payload - if length < 0: - length = self._request.connection.recv_into(buffer, -length) - length = int.from_bytes(buffer[:length], "big") + while 0 < payload_length: + chunk = self._recv_exact(self._buffer, min(payload_length, len(self._buffer))) + if not chunk: + break + payload += chunk + payload_length -= len(chunk) - if has_mask: - mask_length = self._request.connection.recv_into(buffer, 4) - mask = buffer[:mask_length] + payload = bytes(byte ^ masking_key[idx % 4] for idx, byte in enumerate(payload)) - while 0 < length: - payload_length = self._request.connection.recv_into(buffer, length) - payload += buffer[: min(payload_length, length)] - length -= min(payload_length, length) + return fin, opcode, payload - if has_mask: - payload = bytes(byte ^ mask[idx % 4] for idx, byte in enumerate(payload)) + def _is_control_frame(self, opcode: int) -> bool: + return opcode in {Websocket.CLOSE, Websocket.PING, Websocket.PONG} - return opcode, payload + def _handle_control_frame(self, fin: int, opcode: int, payload: bytes): + if 125 < len(payload): + raise WebsocketError("Control frame payload too large", self.PROTOCOL_ERROR) - def _handle_frame(self, opcode: int, payload: bytes) -> Union[str, bytes, None]: - # TODO: Handle continuation frames, currently not supported - if opcode == Websocket.CONT: - return None + if fin != Websocket.FIN: + raise WebsocketError("Control frame not final", self.PROTOCOL_ERROR) if opcode == Websocket.CLOSE: - self.close() - return None + if len(payload) == 1: + raise WebsocketError("Invalid close payload length", self.PROTOCOL_ERROR) + close_code = None + close_reason = None + if 2 <= len(payload): + close_code = int.from_bytes(payload[:2], "big") + if 2 < len(payload): + try: + close_reason = payload[2:].decode("utf-8") + except UnicodeError as error: + raise WebsocketError( + "Invalid close reason encoding", self.PROTOCOL_ERROR + ) from error + self.close(code=close_code, reason=close_reason) + return + elif opcode == Websocket.PING: + self.send_message(payload, Websocket.PONG) + return + elif opcode == Websocket.PONG: + return + + def _handle_frame(self, fin: int, opcode: int, payload: bytes) -> Union[str, bytes, None]: + if self._is_control_frame(opcode): + return self._handle_control_frame(fin, opcode, payload) + + if not self._fragmented_message_in_progress(): + if opcode not in {Websocket.TEXT, Websocket.BINARY}: + raise WebsocketError( + "Invalid frame received when no fragmented message in progress", + self.PROTOCOL_ERROR, + ) + + if fin == Websocket.FIN: + if len(payload) > self.MESSAGE_MAX_SIZE_BYTES: + raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG) + + if opcode == Websocket.TEXT: + try: + return payload.decode("utf-8") + except UnicodeError as error: + raise WebsocketError( + "Invalid UTF-8 in text message", self.PROTOCOL_ERROR + ) from error + return payload + + else: + self._start_fragmented_message(opcode, payload) + return None + + if opcode != Websocket.CONT: + raise WebsocketError( + "New data frame received while fragmented message in progress", + self.PROTOCOL_ERROR, + ) - if opcode == Websocket.PONG: + self._cont_fragmented_message(payload) + + if fin != Websocket.FIN: return None - if opcode == Websocket.PING: - self.send_message(payload, Websocket.PONG) - return payload try: - payload = payload.decode() if opcode == Websocket.TEXT else payload - except UnicodeError: - pass - - return payload + if self._message_opcode == Websocket.TEXT: + try: + return self._message_payload.decode("utf-8") + except UnicodeError as error: + raise WebsocketError( + "Invalid UTF-8 in text message", self.PROTOCOL_ERROR + ) from error + return self._message_payload + finally: + self._reset_fragmented_message() def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]: """ @@ -737,12 +888,21 @@ def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]: raise RuntimeError("Websocket connection is closed, cannot receive messages") try: - opcode, payload = self._read_frame() - frame_data = self._handle_frame(opcode, payload) + fin, opcode, payload = self._read_frame() - return frame_data + return self._handle_frame(fin, opcode, payload) + except WebsocketError as error: + self.close(code=error.code or self.PROTOCOL_ERROR) + return None except OSError as error: - if error.errno == EAGAIN: # No messages available + if error.errno == EAGAIN: # No message/frame available + if not self._fragmented_message_in_progress(): + return None + + time_since_last_frame = monotonic_ns() - self._message_last_frame_timestamp + + if time_since_last_frame > self.MESSAGE_FRAGMENT_TIMEOUT_NS: + self.close(code=self.POLICY_VIOLATION) return None if error.errno == ETIMEDOUT: # Connection timed out return None @@ -785,7 +945,7 @@ def send_message( """ Send a message to the client. - :param str message: Message to be sent. + :param Union[str, bytes] message: Message to be sent. :param int opcode: Opcode of the message. Defaults to TEXT if message is a string and BINARY for bytes. :param bool fail_silently: If True, no error will be raised if the connection is closed. @@ -814,13 +974,31 @@ def send_message( def _send(self) -> None: self._send_headers() - def close(self): + def _prepare_close_payload(self, code: int = None, reason: str = None) -> bytes: + if code is None: + return b"" + payload = bytearray(code.to_bytes(2, "big")) + if reason: + payload.extend(reason.encode("utf-8")) + if 125 < len(payload): + payload = payload[:125] + return bytes(payload) + + def close(self, code: int = None, reason: str = None): """ Close the connection. **Always call this method when you are done sending events.** """ - if not self.closed: - self.send_message(b"", Websocket.CLOSE, fail_silently=True) - self._close_connection() - self.closed = True + if self.closed: + return + + self._reset_fragmented_message() + + payload = self._prepare_close_payload(code, reason) + try: + self.send_message(payload, Websocket.CLOSE, fail_silently=True) + except (BrokenPipeError, OSError): + pass + self._close_connection() + self.closed = True diff --git a/adafruit_httpserver/server.py b/adafruit_httpserver/server.py index e15a428..b0e3cc8 100644 --- a/adafruit_httpserver/server.py +++ b/adafruit_httpserver/server.py @@ -36,12 +36,8 @@ try: from ssl import SSLContext, create_default_context - try: # ssl imports for C python - from ssl import ( - CERT_NONE, - Purpose, - SSLError, - ) + try: # ssl imports for CPython + from ssl import CERT_NONE, Purpose, SSLError except ImportError: pass SSL_AVAILABLE = True @@ -129,7 +125,7 @@ def __init__( self._timeout = 1 self._auths = [] - self._routes: "List[Route]" = [] + self._routes: List[Route] = [] self.headers = Headers() self._socket_source = socket_source @@ -331,6 +327,8 @@ def _receive_header_bytes(self, sock: _ISocket) -> bytes: try: length = sock.recv_into(self._buffer, len(self._buffer)) received_bytes += self._buffer[:length] + except TimeoutError: + break except OSError as ex: if ex.errno == ETIMEDOUT: break @@ -350,6 +348,8 @@ def _receive_body_bytes( try: length = sock.recv_into(self._buffer, len(self._buffer)) received_body_bytes += self._buffer[:length] + except TimeoutError: + break except OSError as ex: if ex.errno == ETIMEDOUT: break @@ -647,7 +647,7 @@ def _debug_response_sent(response: "Response", time_elapsed: float): req_size = len(response._request.raw_request) status = response._status res_size = response._size - time_elapsed_ms = f"{round(time_elapsed*1000)}ms" + time_elapsed_ms = f"{round(time_elapsed * 1000)}ms" print( f'{client_ip} -- "{method} {path}" {req_size} -- "{status}" {res_size} -- {time_elapsed_ms}'