@@ -660,27 +660,27 @@ def __init__(
660660 self ._headers .setdefault ("Connection" , "Upgrade" )
661661 self ._headers .setdefault ("Sec-WebSocket-Accept" , sec_accept_key )
662662 self ._headers .setdefault ("Content-Type" , None )
663- self ._buffer_size = buffer_size
663+ self ._buffer = bytearray ( buffer_size )
664664 self .closed = False
665665
666666 self ._reset_fragmented_message ()
667667
668668 request .connection .setblocking (False )
669669
670- def _start_fragmented_message (self , opcode : int | None , payload : bytes | None ):
670+ def _start_fragmented_message (self , opcode : Union [ int , None ] , payload : Union [ bytes , None ] ):
671671 if self .MESSAGE_MAX_FRAGMENTS < 2 :
672672 raise WebsocketError ("Too many fragments in message" , self .POLICY_VIOLATION )
673673
674674 if len (payload ) > self .MESSAGE_MAX_SIZE_BYTES :
675675 raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
676676
677- self ._message_opcode : int | None = opcode
678- self ._message_payload : bytes | None = payload
679- self ._message_fragments : int | None = 1
677+ self ._message_opcode : Union [ int , None ] = opcode
678+ self ._message_payload : Union [ bytes , None ] = payload
679+ self ._message_fragments : Union [ int , None ] = 1
680680
681681 now = monotonic_ns ()
682- self ._message_start_timestamp : float | None = now
683- self ._message_last_frame_timestamp : float | None = now
682+ self ._message_start_timestamp : Union [ float , None ] = now
683+ self ._message_last_frame_timestamp : Union [ float , None ] = now
684684
685685 def _reset_fragmented_message (self ):
686686 self ._message_opcode = None
@@ -709,48 +709,87 @@ def _cont_fragmented_message(self, payload: bytes):
709709 def _fragmented_message_in_progress (self ):
710710 return self ._message_opcode is not None
711711
712- @staticmethod
713- def _parse_frame_header (header ):
714- fin = header [0 ] & Websocket .FIN
715- opcode = header [0 ] & 0b00001111
716- has_mask = header [1 ] & 0b10000000
717- length = header [1 ] & 0b01111111
712+ def _recv_exact (self , buffer : bytearray , size : int ) -> bytes :
713+ received = 0
714+ view = memoryview (buffer )
715+ while received < size :
716+ remaining = size - received
717+ try :
718+ count = self ._request .connection .recv_into (view [received : received + remaining ])
719+ except OSError as error :
720+ if error .errno == EAGAIN and received == 0 :
721+ raise
722+ if error .errno == EAGAIN :
723+ continue
724+ raise
725+ if count == 0 :
726+ if received == 0 :
727+ raise OSError (ENOTCONN )
728+ break
729+ received += count
730+ return bytes (view [:received ])
718731
719- if length == 0b01111110 :
720- length = - 2
721- elif length == 0b01111111 :
722- length = - 8
732+ def _read_frame_header (self ):
733+ header_bytes = self ._recv_exact (self ._buffer , 2 )
723734
724- return fin , opcode , has_mask , length
735+ if len (header_bytes ) < 2 :
736+ raise OSError (ENOTCONN )
725737
726- def _read_frame ( self ):
727- buffer = bytearray ( self . _buffer_size )
738+ fin = header_bytes [ 0 ] & Websocket . FIN
739+ opcode = header_bytes [ 0 ] & 0b00001111
728740
729- header_length = self ._request .connection .recv_into (buffer , 2 )
730- header_bytes = buffer [:header_length ]
741+ mask = header_bytes [1 ] & 0b10000000
742+ if not mask :
743+ raise WebsocketError ("Client frame not masked" , self .PROTOCOL_ERROR )
731744
732- fin , opcode , has_mask , length = self . _parse_frame_header ( header_bytes )
745+ payload_length = header_bytes [ 1 ] & 0b01111111
733746
734- if not has_mask :
735- raise WebsocketError ("Client frame not masked" , self .PROTOCOL_ERROR )
747+ if 125 < payload_length :
748+ if payload_length == 126 :
749+ payload_length_bytes = self ._recv_exact (self ._buffer , 2 ) # Read next 16 bits
750+ if len (payload_length_bytes ) < 2 :
751+ raise WebsocketError ("Incomplete payload length" , self .PROTOCOL_ERROR )
736752
737- payload = b""
753+ elif payload_length == 127 :
754+ payload_length_bytes = self ._recv_exact (self ._buffer , 8 ) # Read next 64 bits
755+ if len (payload_length_bytes ) < 8 :
756+ raise WebsocketError ("Incomplete payload length" , self .PROTOCOL_ERROR )
757+ else :
758+ raise WebsocketError ("Invalid payload length" , self .PROTOCOL_ERROR )
738759
739- if length < 0 :
740- length = self ._request .connection .recv_into (buffer , - length )
741- length = int .from_bytes (buffer [:length ], "big" )
760+ payload_length = int .from_bytes (payload_length_bytes , "big" )
742761
743- if has_mask :
744- mask_length = self . _request . connection . recv_into ( buffer , 4 )
745- mask = buffer [: mask_length ]
762+ # In 64-bit payload length, most significant bit must be 0
763+ if payload_length & ( 1 << 63 ):
764+ raise WebsocketError ( "Invalid payload length" , self . PROTOCOL_ERROR )
746765
747- while 0 < length :
748- payload_length = self ._request .connection .recv_into (buffer , length )
749- payload += buffer [: min (payload_length , length )]
750- length -= min (payload_length , length )
766+ if opcode == Websocket .CONT and self ._fragmented_message_in_progress ():
767+ if len (self ._message_payload ) + payload_length > self .MESSAGE_MAX_SIZE_BYTES :
768+ raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
751769
752- if has_mask :
753- payload = bytes (byte ^ mask [idx % 4 ] for idx , byte in enumerate (payload ))
770+ elif opcode in {Websocket .TEXT , Websocket .BINARY }:
771+ if payload_length > self .MESSAGE_MAX_SIZE_BYTES :
772+ raise WebsocketError ("Message size too big" , self .MESSAGE_TOO_BIG )
773+
774+ return fin , opcode , payload_length
775+
776+ def _read_frame (self ):
777+ fin , opcode , payload_length = self ._read_frame_header ()
778+
779+ masking_key = self ._recv_exact (self ._buffer , 4 )
780+ if len (masking_key ) < 4 :
781+ raise WebsocketError ("Incomplete mask" , self .PROTOCOL_ERROR )
782+
783+ payload = b""
784+
785+ while 0 < payload_length :
786+ chunk = self ._recv_exact (self ._buffer , min (payload_length , len (self ._buffer )))
787+ if not chunk :
788+ break
789+ payload += chunk
790+ payload_length -= len (chunk )
791+
792+ payload = bytes (byte ^ masking_key [idx % 4 ] for idx , byte in enumerate (payload ))
754793
755794 return fin , opcode , payload
756795
@@ -957,6 +996,9 @@ def close(self, code: int = None, reason: str = None):
957996 self ._reset_fragmented_message ()
958997
959998 payload = self ._prepare_close_payload (code , reason )
960- self .send_message (payload , Websocket .CLOSE , fail_silently = True )
999+ try :
1000+ self .send_message (payload , Websocket .CLOSE , fail_silently = True )
1001+ except (BrokenPipeError , OSError ):
1002+ pass
9611003 self ._close_connection ()
9621004 self .closed = True
0 commit comments