from __future__ import annotations import hashlib import socket import struct from typing import BinaryIO, Callable, Final from nacl.bindings.crypto_secretstream import ( crypto_secretstream_xchacha20poly1305_ABYTES, crypto_secretstream_xchacha20poly1305_HEADERBYTES, crypto_secretstream_xchacha20poly1305_KEYBYTES, crypto_secretstream_xchacha20poly1305_TAG_FINAL, crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, crypto_secretstream_xchacha20poly1305_init_pull, crypto_secretstream_xchacha20poly1305_init_push, crypto_secretstream_xchacha20poly1305_pull, crypto_secretstream_xchacha20poly1305_push, crypto_secretstream_xchacha20poly1305_state, ) FrameCallback = Callable[[bytes], None] KEY_DERIVATION_PERSON: Final = b"vv-secretstream1" MAX_FRAME_SIZE: Final = 64 * 1024 * 1024 TRANSPORT_CHUNK_SIZE: Final = 4096 CHUNK_USED_BYTES: Final = struct.Struct("!H") CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size assert CHUNK_DATA_SIZE > 0 CIPHERTEXT_CHUNK_SIZE: Final = ( TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES ) FRAME_LENGTH: Final = struct.Struct("!I") class ProtocolError(RuntimeError): pass def derive_key(shared_secret: str) -> bytes: return hashlib.blake2b( shared_secret.encode("utf-8"), digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES, person=KEY_DERIVATION_PERSON, ).digest() def read_exact(reader: BinaryIO | SecretStreamSocket, size: int) -> bytes: data = bytearray() while len(data) < size: chunk = reader.read(size - len(data)) if chunk == b"": raise EOFError("Connection closed.") data.extend(chunk) return bytes(data) def wrap_connection_socket( connection_socket: socket.socket, shared_secret: str, ) -> SecretStreamSocket: key = derive_key(shared_secret) reader = connection_socket.makefile("rb") writer = connection_socket.makefile("wb") push_state = crypto_secretstream_xchacha20poly1305_state() pull_state = crypto_secretstream_xchacha20poly1305_state() try: header = crypto_secretstream_xchacha20poly1305_init_push(push_state, key) writer.write(header) writer.flush() peer_header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES) crypto_secretstream_xchacha20poly1305_init_pull(pull_state, peer_header, key) return SecretStreamSocket( sock=connection_socket, reader=reader, writer=writer, push_state=push_state, pull_state=pull_state, ) except Exception: # Python sucks, exceptions suck writer.close() reader.close() connection_socket.close() raise class SecretStreamSocket: def __init__( self, sock: socket.socket, reader: BinaryIO, writer: BinaryIO, push_state: crypto_secretstream_xchacha20poly1305_state, pull_state: crypto_secretstream_xchacha20poly1305_state, ): self.sock = sock self.reader = reader self.writer = writer self.push_state = push_state self.pull_state = pull_state self.buffer = bytearray() self.seen_final = False self.sent_final = False def __enter__(self) -> SecretStreamSocket: return self def __exit__(self, exc_type, exc, tb) -> None: self.close() def send_frame(self, frame: bytes) -> None: if self.sent_final: raise RuntimeError("Cannot send after final chunk was sent.") if len(frame) > MAX_FRAME_SIZE: raise ValueError(f"Frame size {len(frame)} exceeds limit {MAX_FRAME_SIZE}.") plaintext = FRAME_LENGTH.pack(len(frame)) + frame start = 0 while start < len(plaintext): used = min(CHUNK_DATA_SIZE, len(plaintext) - start) self._send_transport_chunk( plaintext[start : start + used], tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, ) start += used self.writer.flush() def recv_frame(self) -> bytes | None: if self._peek_plain(1) == b"": return None length_bytes = read_exact(self, FRAME_LENGTH.size) frame_size = FRAME_LENGTH.unpack(length_bytes)[0] if frame_size > MAX_FRAME_SIZE: raise ProtocolError(f"Frame size {frame_size} exceeds limit {MAX_FRAME_SIZE}.") return read_exact(self, frame_size) def run_receiving_loop(self, on_frame: FrameCallback) -> None: while True: frame = self.recv_frame() if frame is None: return on_frame(frame) def read(self, size) -> bytes: while len(self.buffer) < size and not self.seen_final: self._read_transport_chunk() data = bytes(self.buffer[:size]) del self.buffer[:size] return data def close(self, send_final: bool = True) -> None: if send_final and not self.sent_final: try: self._send_transport_chunk( b"", tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL, ) self.writer.flush() except (OSError, ValueError): pass finally: self.sent_final = True try: self.writer.close() except (OSError, ValueError): pass try: self.reader.close() except (OSError, ValueError): pass self.sock.close() def _peek_plain(self, size: int = 1) -> bytes: while len(self.buffer) < size and not self.seen_final: self._read_transport_chunk() return bytes(self.buffer[:size]) def _read_transport_chunk(self) -> None: ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE) plaintext, tag = crypto_secretstream_xchacha20poly1305_pull( self.pull_state, ciphertext ) if len(plaintext) != TRANSPORT_CHUNK_SIZE: raise ProtocolError( f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}." ) used = CHUNK_USED_BYTES.unpack(plaintext[: CHUNK_USED_BYTES.size])[0] if used > CHUNK_DATA_SIZE: raise ProtocolError( f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}." ) start = CHUNK_USED_BYTES.size end = start + used self.buffer.extend(plaintext[start:end]) if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL: self.seen_final = True def _send_transport_chunk(self, payload: bytes, tag: int) -> None: used = len(payload) if used > CHUNK_DATA_SIZE: raise ValueError(f"Chunk payload size {used} exceeds limit {CHUNK_DATA_SIZE}.") plaintext = bytearray(TRANSPORT_CHUNK_SIZE) plaintext[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(used) data_start = CHUNK_USED_BYTES.size data_end = data_start + used plaintext[data_start:data_end] = payload ciphertext = crypto_secretstream_xchacha20poly1305_push( self.push_state, bytes(plaintext), tag=tag, ) self.writer.write(ciphertext)