""" wsproto/extensions ~~~~~~~~~~~~~~~~~~ WebSocket extensions. """ from __future__ import annotations import zlib from abc import ABC, abstractmethod from typing import Optional from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits class Extension(ABC): name: str def enabled(self) -> bool: return False @abstractmethod def offer(self) -> bool | str: pass def accept(self, offer: str) -> bool | str | None: pass def finalize(self, offer: str) -> None: pass def frame_inbound_header( self, proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, payload_length: int, ) -> CloseReason | RsvBits: return RsvBits(False, False, False) def frame_inbound_payload_data( self, proto: FrameDecoder | FrameProtocol, data: bytes, ) -> bytes | CloseReason: return data def frame_inbound_complete( self, proto: FrameDecoder | FrameProtocol, fin: bool, ) -> bytes | CloseReason | None: pass def frame_outbound( self, proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, data: bytes, fin: bool, ) -> tuple[RsvBits, bytes]: return (rsv, data) class PerMessageDeflate(Extension): name = "permessage-deflate" DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 DEFAULT_SERVER_MAX_WINDOW_BITS = 15 def __init__( self, client_no_context_takeover: bool = False, client_max_window_bits: int | None = None, server_no_context_takeover: bool = False, server_max_window_bits: int | None = None, ) -> None: self.client_no_context_takeover = client_no_context_takeover self.server_no_context_takeover = server_no_context_takeover self._client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS self._server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS if client_max_window_bits is not None: self.client_max_window_bits = client_max_window_bits if server_max_window_bits is not None: self.server_max_window_bits = server_max_window_bits self._compressor: Optional[zlib._Compress] = None # noqa self._decompressor: Optional[zlib._Decompress] = None # noqa # This refers to the current frame self._inbound_is_compressible: bool | None = None # This refers to the ongoing message (which might span multiple # frames). Only the first frame in a fragmented message is flagged for # compression, so this carries that bit forward. self._inbound_compressed: bool | None = None self._enabled = False @property def client_max_window_bits(self) -> int: return self._client_max_window_bits @client_max_window_bits.setter def client_max_window_bits(self, value: int) -> None: if value < 9 or value > 15: msg = "Window size must be between 9 and 15 inclusive" raise ValueError(msg) self._client_max_window_bits = value @property def server_max_window_bits(self) -> int: return self._server_max_window_bits @server_max_window_bits.setter def server_max_window_bits(self, value: int) -> None: if value < 9 or value > 15: msg = "Window size must be between 9 and 15 inclusive" raise ValueError(msg) self._server_max_window_bits = value def _compressible_opcode(self, opcode: Opcode) -> bool: return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) def enabled(self) -> bool: return self._enabled def offer(self) -> bool | str: parameters = [ f"client_max_window_bits={self.client_max_window_bits}", f"server_max_window_bits={self.server_max_window_bits}", ] if self.client_no_context_takeover: parameters.append("client_no_context_takeover") if self.server_no_context_takeover: parameters.append("server_no_context_takeover") return "; ".join(parameters) def finalize(self, offer: str) -> None: bits = [b.strip() for b in offer.split(";")] for bit in bits[1:]: if bit.startswith("client_no_context_takeover"): self.client_no_context_takeover = True elif bit.startswith("server_no_context_takeover"): self.server_no_context_takeover = True elif bit.startswith("client_max_window_bits"): self.client_max_window_bits = int(bit.split("=", 1)[1].strip()) elif bit.startswith("server_max_window_bits"): self.server_max_window_bits = int(bit.split("=", 1)[1].strip()) self._enabled = True def _parse_params(self, params: str) -> tuple[int | None, int | None]: client_max_window_bits = None server_max_window_bits = None bits = [b.strip() for b in params.split(";")] for bit in bits[1:]: if bit.startswith("client_no_context_takeover"): self.client_no_context_takeover = True elif bit.startswith("server_no_context_takeover"): self.server_no_context_takeover = True elif bit.startswith("client_max_window_bits"): if "=" in bit: client_max_window_bits = int(bit.split("=", 1)[1].strip()) else: client_max_window_bits = self.client_max_window_bits elif bit.startswith("server_max_window_bits"): if "=" in bit: server_max_window_bits = int(bit.split("=", 1)[1].strip()) else: server_max_window_bits = self.server_max_window_bits return client_max_window_bits, server_max_window_bits def accept(self, offer: str) -> bool | None | str: client_max_window_bits, server_max_window_bits = self._parse_params(offer) parameters = [] if self.client_no_context_takeover: parameters.append("client_no_context_takeover") if self.server_no_context_takeover: parameters.append("server_no_context_takeover") try: if client_max_window_bits is not None: parameters.append(f"client_max_window_bits={client_max_window_bits}") self.client_max_window_bits = client_max_window_bits if server_max_window_bits is not None: parameters.append(f"server_max_window_bits={server_max_window_bits}") self.server_max_window_bits = server_max_window_bits except ValueError: return None else: self._enabled = True return "; ".join(parameters) def frame_inbound_header( self, proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, payload_length: int, ) -> CloseReason | RsvBits: if rsv.rsv1 and opcode.iscontrol(): return CloseReason.PROTOCOL_ERROR if rsv.rsv1 and opcode is Opcode.CONTINUATION: return CloseReason.PROTOCOL_ERROR self._inbound_is_compressible = self._compressible_opcode(opcode) if self._inbound_compressed is None: self._inbound_compressed = rsv.rsv1 if self._inbound_compressed: assert self._inbound_is_compressible if proto.client: bits = self.server_max_window_bits else: bits = self.client_max_window_bits if self._decompressor is None: self._decompressor = zlib.decompressobj(-int(bits)) return RsvBits(True, False, False) def frame_inbound_payload_data( self, proto: FrameDecoder | FrameProtocol, data: bytes, ) -> bytes | CloseReason: if not self._inbound_compressed or not self._inbound_is_compressible: return data assert self._decompressor is not None try: return self._decompressor.decompress(bytes(data)) except zlib.error: return CloseReason.INVALID_FRAME_PAYLOAD_DATA def frame_inbound_complete( self, proto: FrameDecoder | FrameProtocol, fin: bool, ) -> bytes | CloseReason | None: if not fin: return None if not self._inbound_is_compressible: self._inbound_compressed = None return None if not self._inbound_compressed: self._inbound_compressed = None return None assert self._decompressor is not None try: data = self._decompressor.decompress(b"\x00\x00\xff\xff") data += self._decompressor.flush() except zlib.error: return CloseReason.INVALID_FRAME_PAYLOAD_DATA if proto.client: no_context_takeover = self.server_no_context_takeover else: no_context_takeover = self.client_no_context_takeover if no_context_takeover: self._decompressor = None self._inbound_compressed = None return data def frame_outbound( self, proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, data: bytes, fin: bool, ) -> tuple[RsvBits, bytes]: if not self._compressible_opcode(opcode): return (rsv, data) if opcode is not Opcode.CONTINUATION: rsv = RsvBits(True, rsv[1], rsv[2]) if self._compressor is None: assert opcode is not Opcode.CONTINUATION if proto.client: bits = self.client_max_window_bits else: bits = self.server_max_window_bits self._compressor = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits), ) data = self._compressor.compress(bytes(data)) if fin: data += self._compressor.flush(zlib.Z_SYNC_FLUSH) data = data[:-4] if proto.client: no_context_takeover = self.client_no_context_takeover else: no_context_takeover = self.server_no_context_takeover if no_context_takeover: self._compressor = None return (rsv, data) def __repr__(self) -> str: descr = [f"client_max_window_bits={self.client_max_window_bits}"] if self.client_no_context_takeover: descr.append("client_no_context_takeover") descr.append(f"server_max_window_bits={self.server_max_window_bits}") if self.server_no_context_takeover: descr.append("server_no_context_takeover") return "<{} {}>".format(self.__class__.__name__, "; ".join(descr)) #: SUPPORTED_EXTENSIONS maps all supported extension names to their class. #: This can be used to iterate all supported extensions of wsproto, instantiate #: new extensions based on their name, or check if a given extension is #: supported or not. SUPPORTED_EXTENSIONS = {PerMessageDeflate.name: PerMessageDeflate}