Skip to content

Commit b30e433

Browse files
committed
Refactor of buffer data receiving and added more checks websocket errors
1 parent e521ba6 commit b30e433

1 file changed

Lines changed: 81 additions & 39 deletions

File tree

adafruit_httpserver/response.py

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -660,27 +660,27 @@ def __init__(
660660
self._headers.setdefault("Connection", "Upgrade")
661661
self._headers.setdefault("Sec-WebSocket-Accept", sec_accept_key)
662662
self._headers.setdefault("Content-Type", None)
663-
self._buffer_size = buffer_size
663+
self._buffer = bytearray(buffer_size)
664664
self.closed = False
665665

666666
self._reset_fragmented_message()
667667

668668
request.connection.setblocking(False)
669669

670-
def _start_fragmented_message(self, opcode: int | None, payload: bytes | None):
670+
def _start_fragmented_message(self, opcode: Union[int, None], payload: Union[bytes, None]):
671671
if self.MESSAGE_MAX_FRAGMENTS < 2:
672672
raise WebsocketError("Too many fragments in message", self.POLICY_VIOLATION)
673673

674674
if len(payload) > self.MESSAGE_MAX_SIZE_BYTES:
675675
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
676676

677-
self._message_opcode: int | None = opcode
678-
self._message_payload: bytes | None = payload
679-
self._message_fragments: int | None = 1
677+
self._message_opcode: Union[int, None] = opcode
678+
self._message_payload: Union[bytes, None] = payload
679+
self._message_fragments: Union[int, None] = 1
680680

681681
now = monotonic_ns()
682-
self._message_start_timestamp: float | None = now
683-
self._message_last_frame_timestamp: float | None = now
682+
self._message_start_timestamp: Union[float, None] = now
683+
self._message_last_frame_timestamp: Union[float, None] = now
684684

685685
def _reset_fragmented_message(self):
686686
self._message_opcode = None
@@ -709,48 +709,87 @@ def _cont_fragmented_message(self, payload: bytes):
709709
def _fragmented_message_in_progress(self):
710710
return self._message_opcode is not None
711711

712-
@staticmethod
713-
def _parse_frame_header(header):
714-
fin = header[0] & Websocket.FIN
715-
opcode = header[0] & 0b00001111
716-
has_mask = header[1] & 0b10000000
717-
length = header[1] & 0b01111111
712+
def _recv_exact(self, buffer: bytearray, size: int) -> bytes:
713+
received = 0
714+
view = memoryview(buffer)
715+
while received < size:
716+
remaining = size - received
717+
try:
718+
count = self._request.connection.recv_into(view[received : received + remaining])
719+
except OSError as error:
720+
if error.errno == EAGAIN and received == 0:
721+
raise
722+
if error.errno == EAGAIN:
723+
continue
724+
raise
725+
if count == 0:
726+
if received == 0:
727+
raise OSError(ENOTCONN)
728+
break
729+
received += count
730+
return bytes(view[:received])
718731

719-
if length == 0b01111110:
720-
length = -2
721-
elif length == 0b01111111:
722-
length = -8
732+
def _read_frame_header(self):
733+
header_bytes = self._recv_exact(self._buffer, 2)
723734

724-
return fin, opcode, has_mask, length
735+
if len(header_bytes) < 2:
736+
raise OSError(ENOTCONN)
725737

726-
def _read_frame(self):
727-
buffer = bytearray(self._buffer_size)
738+
fin = header_bytes[0] & Websocket.FIN
739+
opcode = header_bytes[0] & 0b00001111
728740

729-
header_length = self._request.connection.recv_into(buffer, 2)
730-
header_bytes = buffer[:header_length]
741+
mask = header_bytes[1] & 0b10000000
742+
if not mask:
743+
raise WebsocketError("Client frame not masked", self.PROTOCOL_ERROR)
731744

732-
fin, opcode, has_mask, length = self._parse_frame_header(header_bytes)
745+
payload_length = header_bytes[1] & 0b01111111
733746

734-
if not has_mask:
735-
raise WebsocketError("Client frame not masked", self.PROTOCOL_ERROR)
747+
if 125 < payload_length:
748+
if payload_length == 126:
749+
payload_length_bytes = self._recv_exact(self._buffer, 2) # Read next 16 bits
750+
if len(payload_length_bytes) < 2:
751+
raise WebsocketError("Incomplete payload length", self.PROTOCOL_ERROR)
736752

737-
payload = b""
753+
elif payload_length == 127:
754+
payload_length_bytes = self._recv_exact(self._buffer, 8) # Read next 64 bits
755+
if len(payload_length_bytes) < 8:
756+
raise WebsocketError("Incomplete payload length", self.PROTOCOL_ERROR)
757+
else:
758+
raise WebsocketError("Invalid payload length", self.PROTOCOL_ERROR)
738759

739-
if length < 0:
740-
length = self._request.connection.recv_into(buffer, -length)
741-
length = int.from_bytes(buffer[:length], "big")
760+
payload_length = int.from_bytes(payload_length_bytes, "big")
742761

743-
if has_mask:
744-
mask_length = self._request.connection.recv_into(buffer, 4)
745-
mask = buffer[:mask_length]
762+
# In 64-bit payload length, most significant bit must be 0
763+
if payload_length & (1 << 63):
764+
raise WebsocketError("Invalid payload length", self.PROTOCOL_ERROR)
746765

747-
while 0 < length:
748-
payload_length = self._request.connection.recv_into(buffer, length)
749-
payload += buffer[: min(payload_length, length)]
750-
length -= min(payload_length, length)
766+
if opcode == Websocket.CONT and self._fragmented_message_in_progress():
767+
if len(self._message_payload) + payload_length > self.MESSAGE_MAX_SIZE_BYTES:
768+
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
751769

752-
if has_mask:
753-
payload = bytes(byte ^ mask[idx % 4] for idx, byte in enumerate(payload))
770+
elif opcode in {Websocket.TEXT, Websocket.BINARY}:
771+
if payload_length > self.MESSAGE_MAX_SIZE_BYTES:
772+
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
773+
774+
return fin, opcode, payload_length
775+
776+
def _read_frame(self):
777+
fin, opcode, payload_length = self._read_frame_header()
778+
779+
masking_key = self._recv_exact(self._buffer, 4)
780+
if len(masking_key) < 4:
781+
raise WebsocketError("Incomplete mask", self.PROTOCOL_ERROR)
782+
783+
payload = b""
784+
785+
while 0 < payload_length:
786+
chunk = self._recv_exact(self._buffer, min(payload_length, len(self._buffer)))
787+
if not chunk:
788+
break
789+
payload += chunk
790+
payload_length -= len(chunk)
791+
792+
payload = bytes(byte ^ masking_key[idx % 4] for idx, byte in enumerate(payload))
754793

755794
return fin, opcode, payload
756795

@@ -957,6 +996,9 @@ def close(self, code: int = None, reason: str = None):
957996
self._reset_fragmented_message()
958997

959998
payload = self._prepare_close_payload(code, reason)
960-
self.send_message(payload, Websocket.CLOSE, fail_silently=True)
999+
try:
1000+
self.send_message(payload, Websocket.CLOSE, fail_silently=True)
1001+
except (BrokenPipeError, OSError):
1002+
pass
9611003
self._close_connection()
9621004
self.closed = True

0 commit comments

Comments
 (0)