1616import os
1717from binascii import b2a_base64
1818from errno import EAGAIN , ECONNRESET , ENOTCONN , ETIMEDOUT
19+ from time import monotonic_ns
1920
2021try :
2122 try :
3233 BackslashInPathError ,
3334 FileNotExistsError ,
3435 ParentDirectoryReferenceError ,
36+ WebsocketError ,
3537)
3638from .headers import Headers
3739from .interfaces import _ISocket
@@ -414,8 +416,8 @@ def __init__(
414416 * ,
415417 permanent : bool = False ,
416418 preserve_method : bool = False ,
417- status : Union [Status , Tuple [int , str ]] = None ,
418- headers : Union [Headers , Dict [str , str ]] = None ,
419+ status : Union [Status , Tuple [int , str ], None ] = None ,
420+ headers : Union [Headers , Dict [str , str ], None ] = None ,
419421 cookies : Dict [str , str ] = None ,
420422 ) -> None :
421423 """
@@ -593,13 +595,21 @@ def route_func(request: Request):
593595 FIN = 0b10000000 # FIN bit indicating the final fragment
594596
595597 # opcodes
596- CONT = 0 # Continuation frame, TODO: Currently not supported
598+ CONT = 0 # Continuation frame
597599 TEXT = 1 # Frame contains UTF-8 text
598600 BINARY = 2 # Frame contains binary data
599601 CLOSE = 8 # Frame closes the connection
600602 PING = 9 # Frame is a ping, expecting a pong
601603 PONG = 10 # Frame is a pong, in response to a ping
602604
605+ PROTOCOL_ERROR = 1002
606+ POLICY_VIOLATION = 1008
607+ MESSAGE_TOO_BIG = 1009
608+
609+ MESSAGE_MAX_SIZE_BYTES = 4096
610+ MESSAGE_MAX_FRAGMENTS = 16
611+ MESSAGE_FRAGMENT_TIMEOUT_NS = 5 * (10 ** 9 )
612+
603613 @staticmethod
604614 def _check_request_initiates_handshake (request : Request ):
605615 if not all (
@@ -653,8 +663,52 @@ def __init__(
653663 self ._buffer_size = buffer_size
654664 self .closed = False
655665
666+ self ._reset_fragmented_message ()
667+
656668 request .connection .setblocking (False )
657669
670+ def _start_fragmented_message (self , opcode : int | None , payload : bytes | None ):
671+ if self .MESSAGE_MAX_FRAGMENTS < 2 :
672+ raise WebsocketError ("Too many fragments in message" , self .POLICY_VIOLATION )
673+
674+ if len (payload ) > self .MESSAGE_MAX_SIZE_BYTES :
675+ raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
676+
677+ self ._message_opcode : int | None = opcode
678+ self ._message_payload : bytes | None = payload
679+ self ._message_fragments : int | None = 1
680+
681+ now = monotonic_ns ()
682+ self ._message_start_timestamp : float | None = now
683+ self ._message_last_frame_timestamp : float | None = now
684+
685+ def _reset_fragmented_message (self ):
686+ self ._message_opcode = None
687+ self ._message_payload = None
688+ self ._message_fragments = None
689+ self ._message_start_timestamp = None
690+ self ._message_last_frame_timestamp = None
691+
692+ def _cont_fragmented_message (self , payload : bytes ):
693+ if self ._message_fragments + 1 > self .MESSAGE_MAX_FRAGMENTS :
694+ raise WebsocketError ("Too many fragments in message" , self .POLICY_VIOLATION )
695+
696+ if len (self ._message_payload ) + len (payload ) > self .MESSAGE_MAX_SIZE_BYTES :
697+ raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
698+
699+ now = monotonic_ns ()
700+ time_since_last_frame = now - self ._message_last_frame_timestamp
701+
702+ if time_since_last_frame > self .MESSAGE_FRAGMENT_TIMEOUT_NS :
703+ raise WebsocketError ("Fragment timeout exceeded" , self .POLICY_VIOLATION )
704+
705+ self ._message_payload += payload
706+ self ._message_fragments += 1
707+ self ._message_last_frame_timestamp = now
708+
709+ def _fragmented_message_in_progress (self ):
710+ return self ._message_opcode is not None
711+
658712 @staticmethod
659713 def _parse_frame_header (header ):
660714 fin = header [0 ] & Websocket .FIN
@@ -677,13 +731,10 @@ def _read_frame(self):
677731
678732 fin , opcode , has_mask , length = self ._parse_frame_header (header_bytes )
679733
680- # TODO: Handle continuation frames, currently not supported
681- if fin != Websocket .FIN and opcode == Websocket .CONT :
682- return Websocket .CONT , None
734+ if not has_mask :
735+ raise WebsocketError ("Client frame not masked" , self .PROTOCOL_ERROR )
683736
684737 payload = b""
685- if fin == Websocket .FIN and opcode == Websocket .CLOSE :
686- return Websocket .CLOSE , payload
687738
688739 if length < 0 :
689740 length = self ._request .connection .recv_into (buffer , - length )
@@ -701,29 +752,95 @@ def _read_frame(self):
701752 if has_mask :
702753 payload = bytes (byte ^ mask [idx % 4 ] for idx , byte in enumerate (payload ))
703754
704- return opcode , payload
755+ return fin , opcode , payload
705756
706- def _handle_frame (self , opcode : int , payload : bytes ) -> Union [str , bytes , None ]:
707- # TODO: Handle continuation frames, currently not supported
708- if opcode == Websocket .CONT :
709- return None
757+ def _is_control_frame (self , opcode : int ) -> bool :
758+ return opcode in (Websocket .CLOSE , Websocket .PING , Websocket .PONG )
759+
760+ def _handle_control_frame (self , fin : int , opcode : int , payload : bytes ):
761+ if 125 < len (payload ):
762+ raise WebsocketError ("Control frame payload too large" , self .PROTOCOL_ERROR )
763+
764+ if fin != Websocket .FIN :
765+ raise WebsocketError ("Control frame not final" , self .PROTOCOL_ERROR )
710766
711767 if opcode == Websocket .CLOSE :
712- self .close ()
713- return None
768+ if len (payload ) == 1 :
769+ raise WebsocketError (
770+ "Invalid close payload length" , self .PROTOCOL_ERROR
771+ )
772+ close_code = None
773+ close_reason = None
774+ if 2 <= len (payload ):
775+ close_code = int .from_bytes (payload [:2 ], "big" )
776+ if 2 < len (payload ):
777+ try :
778+ close_reason = payload [2 :].decode ("utf-8" )
779+ except UnicodeError as error :
780+ raise WebsocketError (
781+ "Invalid close reason encoding" , self .PROTOCOL_ERROR
782+ ) from error
783+ self .close (code = close_code , reason = close_reason )
784+ return
785+ elif opcode == Websocket .PING :
786+ self .send_message (payload , Websocket .PONG )
787+ return
788+ elif opcode == Websocket .PONG :
789+ return
790+
791+ def _handle_frame (
792+ self , fin : int , opcode : int , payload : bytes
793+ ) -> Union [str , bytes , None ]:
794+
795+ if self ._is_control_frame (opcode ):
796+ return self ._handle_control_frame (fin , opcode , payload )
797+
798+ if not self ._fragmented_message_in_progress ():
799+ if opcode not in (Websocket .TEXT , Websocket .BINARY ):
800+ raise WebsocketError (
801+ "Invalid frame received when no fragmented message in progress" ,
802+ self .PROTOCOL_ERROR ,
803+ )
804+
805+ if fin == Websocket .FIN :
806+ if len (payload ) > self .MESSAGE_MAX_SIZE_BYTES :
807+ raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
808+
809+ if opcode == Websocket .TEXT :
810+ try :
811+ return payload .decode ("utf-8" )
812+ except UnicodeError as error :
813+ raise WebsocketError (
814+ "Invalid UTF-8 in text message" , self .PROTOCOL_ERROR
815+ ) from error
816+ return payload
817+
818+ else :
819+ self ._start_fragmented_message (opcode , payload )
820+ return None
821+
822+ if opcode not in (Websocket .CONT ,):
823+ raise WebsocketError (
824+ "New data frame received while fragmented message in progress" ,
825+ self .PROTOCOL_ERROR ,
826+ )
827+
828+ self ._cont_fragmented_message (payload )
714829
715- if opcode == Websocket .PONG :
830+ if fin != Websocket .FIN :
716831 return None
717- if opcode == Websocket .PING :
718- self .send_message (payload , Websocket .PONG )
719- return payload
720832
721833 try :
722- payload = payload .decode () if opcode == Websocket .TEXT else payload
723- except UnicodeError :
724- pass
725-
726- return payload
834+ if self ._message_opcode == Websocket .TEXT :
835+ try :
836+ return self ._message_payload .decode ("utf-8" )
837+ except UnicodeError as error :
838+ raise WebsocketError (
839+ "Invalid UTF-8 in text message" , self .PROTOCOL_ERROR
840+ ) from error
841+ return self ._message_payload
842+ finally :
843+ self ._reset_fragmented_message ()
727844
728845 def receive (self , fail_silently : bool = False ) -> Union [str , bytes , None ]:
729846 """
@@ -737,12 +854,24 @@ def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]:
737854 raise RuntimeError ("Websocket connection is closed, cannot receive messages" )
738855
739856 try :
740- opcode , payload = self ._read_frame ()
741- frame_data = self ._handle_frame (opcode , payload )
857+ fin , opcode , payload = self ._read_frame ()
742858
743- return frame_data
859+ return self ._handle_frame (fin , opcode , payload )
860+ except WebsocketError as error :
861+ self .close (code = error .code or self .PROTOCOL_ERROR )
862+ return None
744863 except OSError as error :
745- if error .errno == EAGAIN : # No messages available
864+ if error .errno == EAGAIN : # No message/frame available
865+
866+ if not self ._fragmented_message_in_progress ():
867+ return None
868+
869+ time_since_last_frame = (
870+ monotonic_ns () - self ._message_last_frame_timestamp
871+ )
872+
873+ if time_since_last_frame > self .MESSAGE_FRAGMENT_TIMEOUT_NS :
874+ self .close (code = self .POLICY_VIOLATION )
746875 return None
747876 if error .errno == ETIMEDOUT : # Connection timed out
748877 return None
@@ -785,7 +914,7 @@ def send_message(
785914 """
786915 Send a message to the client.
787916
788- :param str message: Message to be sent.
917+ :param Union[ str, bytes] message: Message to be sent.
789918 :param int opcode: Opcode of the message. Defaults to TEXT if message is a string and
790919 BINARY for bytes.
791920 :param bool fail_silently: If True, no error will be raised if the connection is closed.
@@ -814,13 +943,28 @@ def send_message(
814943 def _send (self ) -> None :
815944 self ._send_headers ()
816945
817- def close (self ):
946+ def _prepare_close_payload (self , code : int = None , reason : str = None ) -> bytes :
947+ if code is None :
948+ return b""
949+ payload = bytearray (code .to_bytes (2 , "big" ))
950+ if reason :
951+ payload .extend (reason .encode ("utf-8" ))
952+ if 125 < len (payload ):
953+ payload = payload [:125 ]
954+ return bytes (payload )
955+
956+ def close (self , code : int = None , reason : str = None ):
818957 """
819958 Close the connection.
820959
821960 **Always call this method when you are done sending events.**
822961 """
823- if not self .closed :
824- self .send_message (b"" , Websocket .CLOSE , fail_silently = True )
825- self ._close_connection ()
826- self .closed = True
962+ if self .closed :
963+ return
964+
965+ self ._reset_fragmented_message ()
966+
967+ payload = self ._prepare_close_payload (code , reason )
968+ self .send_message (payload , Websocket .CLOSE , fail_silently = True )
969+ self ._close_connection ()
970+ self .closed = True
0 commit comments