Skip to content

Commit b5a6051

Browse files
committed
Added support for CONT frames in Websocket
1 parent e234a49 commit b5a6051

2 files changed

Lines changed: 191 additions & 34 deletions

File tree

adafruit_httpserver/exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,16 @@ def __init__(self, path: str) -> None:
6262
Creates a new ``FileNotExistsError`` for the file at ``path``.
6363
"""
6464
super().__init__(f"File does not exist: {path}")
65+
66+
67+
class WebsocketError(Exception):
68+
"""
69+
Raised when there is a error in WebSocket communication.
70+
"""
71+
72+
def __init__(self, message: str, code: int = None) -> None:
73+
"""
74+
Creates a new ``WebsocketError`` with the given ``message``.
75+
"""
76+
self.code = code
77+
super().__init__(f"WebSocket error: {message}")

adafruit_httpserver/response.py

Lines changed: 178 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from binascii import b2a_base64
1818
from errno import EAGAIN, ECONNRESET, ENOTCONN, ETIMEDOUT
19+
from time import monotonic_ns
1920

2021
try:
2122
try:
@@ -32,6 +33,7 @@
3233
BackslashInPathError,
3334
FileNotExistsError,
3435
ParentDirectoryReferenceError,
36+
WebsocketError,
3537
)
3638
from .headers import Headers
3739
from .interfaces import _ISocket
@@ -414,8 +416,8 @@ def __init__(
414416
*,
415417
permanent: bool = False,
416418
preserve_method: bool = False,
417-
status: Union[Status, Tuple[int, str]] = None,
418-
headers: Union[Headers, Dict[str, str]] = None,
419+
status: Union[Status, Tuple[int, str], None] = None,
420+
headers: Union[Headers, Dict[str, str], None] = None,
419421
cookies: Dict[str, str] = None,
420422
) -> None:
421423
"""
@@ -593,13 +595,21 @@ def route_func(request: Request):
593595
FIN = 0b10000000 # FIN bit indicating the final fragment
594596

595597
# opcodes
596-
CONT = 0 # Continuation frame, TODO: Currently not supported
598+
CONT = 0 # Continuation frame
597599
TEXT = 1 # Frame contains UTF-8 text
598600
BINARY = 2 # Frame contains binary data
599601
CLOSE = 8 # Frame closes the connection
600602
PING = 9 # Frame is a ping, expecting a pong
601603
PONG = 10 # Frame is a pong, in response to a ping
602604

605+
PROTOCOL_ERROR = 1002
606+
POLICY_VIOLATION = 1008
607+
MESSAGE_TOO_BIG = 1009
608+
609+
MESSAGE_MAX_SIZE_BYTES = 4096
610+
MESSAGE_MAX_FRAGMENTS = 16
611+
MESSAGE_FRAGMENT_TIMEOUT_NS = 5 * (10**9)
612+
603613
@staticmethod
604614
def _check_request_initiates_handshake(request: Request):
605615
if not all(
@@ -653,8 +663,52 @@ def __init__(
653663
self._buffer_size = buffer_size
654664
self.closed = False
655665

666+
self._reset_fragmented_message()
667+
656668
request.connection.setblocking(False)
657669

670+
def _start_fragmented_message(self, opcode: int | None, payload: bytes | None):
671+
if self.MESSAGE_MAX_FRAGMENTS < 2:
672+
raise WebsocketError("Too many fragments in message", self.POLICY_VIOLATION)
673+
674+
if len(payload) > self.MESSAGE_MAX_SIZE_BYTES:
675+
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
676+
677+
self._message_opcode: int | None = opcode
678+
self._message_payload: bytes | None = payload
679+
self._message_fragments: int | None = 1
680+
681+
now = monotonic_ns()
682+
self._message_start_timestamp: float | None = now
683+
self._message_last_frame_timestamp: float | None = now
684+
685+
def _reset_fragmented_message(self):
686+
self._message_opcode = None
687+
self._message_payload = None
688+
self._message_fragments = None
689+
self._message_start_timestamp = None
690+
self._message_last_frame_timestamp = None
691+
692+
def _cont_fragmented_message(self, payload: bytes):
693+
if self._message_fragments + 1 > self.MESSAGE_MAX_FRAGMENTS:
694+
raise WebsocketError("Too many fragments in message", self.POLICY_VIOLATION)
695+
696+
if len(self._message_payload) + len(payload) > self.MESSAGE_MAX_SIZE_BYTES:
697+
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
698+
699+
now = monotonic_ns()
700+
time_since_last_frame = now - self._message_last_frame_timestamp
701+
702+
if time_since_last_frame > self.MESSAGE_FRAGMENT_TIMEOUT_NS:
703+
raise WebsocketError("Fragment timeout exceeded", self.POLICY_VIOLATION)
704+
705+
self._message_payload += payload
706+
self._message_fragments += 1
707+
self._message_last_frame_timestamp = now
708+
709+
def _fragmented_message_in_progress(self):
710+
return self._message_opcode is not None
711+
658712
@staticmethod
659713
def _parse_frame_header(header):
660714
fin = header[0] & Websocket.FIN
@@ -677,13 +731,10 @@ def _read_frame(self):
677731

678732
fin, opcode, has_mask, length = self._parse_frame_header(header_bytes)
679733

680-
# TODO: Handle continuation frames, currently not supported
681-
if fin != Websocket.FIN and opcode == Websocket.CONT:
682-
return Websocket.CONT, None
734+
if not has_mask:
735+
raise WebsocketError("Client frame not masked", self.PROTOCOL_ERROR)
683736

684737
payload = b""
685-
if fin == Websocket.FIN and opcode == Websocket.CLOSE:
686-
return Websocket.CLOSE, payload
687738

688739
if length < 0:
689740
length = self._request.connection.recv_into(buffer, -length)
@@ -701,29 +752,95 @@ def _read_frame(self):
701752
if has_mask:
702753
payload = bytes(byte ^ mask[idx % 4] for idx, byte in enumerate(payload))
703754

704-
return opcode, payload
755+
return fin, opcode, payload
705756

706-
def _handle_frame(self, opcode: int, payload: bytes) -> Union[str, bytes, None]:
707-
# TODO: Handle continuation frames, currently not supported
708-
if opcode == Websocket.CONT:
709-
return None
757+
def _is_control_frame(self, opcode: int) -> bool:
758+
return opcode in (Websocket.CLOSE, Websocket.PING, Websocket.PONG)
759+
760+
def _handle_control_frame(self, fin: int, opcode: int, payload: bytes):
761+
if 125 < len(payload):
762+
raise WebsocketError("Control frame payload too large", self.PROTOCOL_ERROR)
763+
764+
if fin != Websocket.FIN:
765+
raise WebsocketError("Control frame not final", self.PROTOCOL_ERROR)
710766

711767
if opcode == Websocket.CLOSE:
712-
self.close()
713-
return None
768+
if len(payload) == 1:
769+
raise WebsocketError(
770+
"Invalid close payload length", self.PROTOCOL_ERROR
771+
)
772+
close_code = None
773+
close_reason = None
774+
if 2 <= len(payload):
775+
close_code = int.from_bytes(payload[:2], "big")
776+
if 2 < len(payload):
777+
try:
778+
close_reason = payload[2:].decode("utf-8")
779+
except UnicodeError as error:
780+
raise WebsocketError(
781+
"Invalid close reason encoding", self.PROTOCOL_ERROR
782+
) from error
783+
self.close(code=close_code, reason=close_reason)
784+
return
785+
elif opcode == Websocket.PING:
786+
self.send_message(payload, Websocket.PONG)
787+
return
788+
elif opcode == Websocket.PONG:
789+
return
790+
791+
def _handle_frame(
792+
self, fin: int, opcode: int, payload: bytes
793+
) -> Union[str, bytes, None]:
794+
795+
if self._is_control_frame(opcode):
796+
return self._handle_control_frame(fin, opcode, payload)
797+
798+
if not self._fragmented_message_in_progress():
799+
if opcode not in (Websocket.TEXT, Websocket.BINARY):
800+
raise WebsocketError(
801+
"Invalid frame received when no fragmented message in progress",
802+
self.PROTOCOL_ERROR,
803+
)
804+
805+
if fin == Websocket.FIN:
806+
if len(payload) > self.MESSAGE_MAX_SIZE_BYTES:
807+
raise WebsocketError("Message size too big", self.MESSAGE_TOO_BIG)
808+
809+
if opcode == Websocket.TEXT:
810+
try:
811+
return payload.decode("utf-8")
812+
except UnicodeError as error:
813+
raise WebsocketError(
814+
"Invalid UTF-8 in text message", self.PROTOCOL_ERROR
815+
) from error
816+
return payload
817+
818+
else:
819+
self._start_fragmented_message(opcode, payload)
820+
return None
821+
822+
if opcode not in (Websocket.CONT,):
823+
raise WebsocketError(
824+
"New data frame received while fragmented message in progress",
825+
self.PROTOCOL_ERROR,
826+
)
827+
828+
self._cont_fragmented_message(payload)
714829

715-
if opcode == Websocket.PONG:
830+
if fin != Websocket.FIN:
716831
return None
717-
if opcode == Websocket.PING:
718-
self.send_message(payload, Websocket.PONG)
719-
return payload
720832

721833
try:
722-
payload = payload.decode() if opcode == Websocket.TEXT else payload
723-
except UnicodeError:
724-
pass
725-
726-
return payload
834+
if self._message_opcode == Websocket.TEXT:
835+
try:
836+
return self._message_payload.decode("utf-8")
837+
except UnicodeError as error:
838+
raise WebsocketError(
839+
"Invalid UTF-8 in text message", self.PROTOCOL_ERROR
840+
) from error
841+
return self._message_payload
842+
finally:
843+
self._reset_fragmented_message()
727844

728845
def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]:
729846
"""
@@ -737,12 +854,24 @@ def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]:
737854
raise RuntimeError("Websocket connection is closed, cannot receive messages")
738855

739856
try:
740-
opcode, payload = self._read_frame()
741-
frame_data = self._handle_frame(opcode, payload)
857+
fin, opcode, payload = self._read_frame()
742858

743-
return frame_data
859+
return self._handle_frame(fin, opcode, payload)
860+
except WebsocketError as error:
861+
self.close(code=error.code or self.PROTOCOL_ERROR)
862+
return None
744863
except OSError as error:
745-
if error.errno == EAGAIN: # No messages available
864+
if error.errno == EAGAIN: # No message/frame available
865+
866+
if not self._fragmented_message_in_progress():
867+
return None
868+
869+
time_since_last_frame = (
870+
monotonic_ns() - self._message_last_frame_timestamp
871+
)
872+
873+
if time_since_last_frame > self.MESSAGE_FRAGMENT_TIMEOUT_NS:
874+
self.close(code=self.POLICY_VIOLATION)
746875
return None
747876
if error.errno == ETIMEDOUT: # Connection timed out
748877
return None
@@ -785,7 +914,7 @@ def send_message(
785914
"""
786915
Send a message to the client.
787916
788-
:param str message: Message to be sent.
917+
:param Union[str, bytes] message: Message to be sent.
789918
:param int opcode: Opcode of the message. Defaults to TEXT if message is a string and
790919
BINARY for bytes.
791920
:param bool fail_silently: If True, no error will be raised if the connection is closed.
@@ -814,13 +943,28 @@ def send_message(
814943
def _send(self) -> None:
815944
self._send_headers()
816945

817-
def close(self):
946+
def _prepare_close_payload(self, code: int = None, reason: str = None) -> bytes:
947+
if code is None:
948+
return b""
949+
payload = bytearray(code.to_bytes(2, "big"))
950+
if reason:
951+
payload.extend(reason.encode("utf-8"))
952+
if 125 < len(payload):
953+
payload = payload[:125]
954+
return bytes(payload)
955+
956+
def close(self, code: int = None, reason: str = None):
818957
"""
819958
Close the connection.
820959
821960
**Always call this method when you are done sending events.**
822961
"""
823-
if not self.closed:
824-
self.send_message(b"", Websocket.CLOSE, fail_silently=True)
825-
self._close_connection()
826-
self.closed = True
962+
if self.closed:
963+
return
964+
965+
self._reset_fragmented_message()
966+
967+
payload = self._prepare_close_payload(code, reason)
968+
self.send_message(payload, Websocket.CLOSE, fail_silently=True)
969+
self._close_connection()
970+
self.closed = True

0 commit comments

Comments
 (0)