from __future__ import annotations import hashlib import socket import struct from typing import BinaryIO, Final from construct import ConstructError, Int32ul, PascalString, PrefixedArray, Struct as CStruct from nacl.bindings.crypto_secretstream import ( crypto_secretstream_xchacha20poly1305_ABYTES, crypto_secretstream_xchacha20poly1305_HEADERBYTES, crypto_secretstream_xchacha20poly1305_KEYBYTES, crypto_secretstream_xchacha20poly1305_TAG_FINAL, crypto_secretstream_xchacha20poly1305_init_pull, crypto_secretstream_xchacha20poly1305_pull, crypto_secretstream_xchacha20poly1305_state, ) KEY_DERIVATION_PERSON: Final = b"vv-secretstream1" SERVER_HOST: Final = "127.0.0.1" SERVER_PORT: Final = 9000 SHARED_SECRET: Final = "change-me" MAX_MESSAGE_SIZE: Final = 64 * 1024 * 1024 TRANSPORT_CHUNK_SIZE: Final = 4096 CHUNK_USED_BYTES: Final = struct.Struct("!H") CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size CIPHERTEXT_CHUNK_SIZE: Final = ( TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES ) MESSAGE_LENGTH: Final = struct.Struct("!I") class ProtocolError(RuntimeError): pass MessageInChat = CStruct( "role" / PascalString(Int32ul, "utf8"), "content" / PascalString(Int32ul, "utf8"), ) IncomingChatBatch = CStruct( "messages" / PrefixedArray(Int32ul, MessageInChat), ) def derive_key(shared_secret: str) -> bytes: return hashlib.blake2b( shared_secret.encode("utf-8"), digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES, person=KEY_DERIVATION_PERSON, ).digest() def read_exact(reader: BinaryIO | SecretStreamReader, size: int) -> bytes: data = bytearray() while len(data) < size: chunk = reader.read(size - len(data)) if chunk == b"": raise EOFError("Connection closed.") data.extend(chunk) return bytes(data) class SecretStreamReader: def __init__(self, reader: BinaryIO, key: bytes): self.reader = reader self.state = crypto_secretstream_xchacha20poly1305_state() self.buffer = bytearray() self.seen_final = False header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES) crypto_secretstream_xchacha20poly1305_init_pull(self.state, header, key) def peek(self, size: int = 1) -> bytes: while len(self.buffer) < size and not self.seen_final: self._read_chunk() return bytes(self.buffer[:size]) def read(self, size: int = -1) -> bytes: while len(self.buffer) < size and not self.seen_final: self._read_chunk() data = bytes(self.buffer[:size]) del self.buffer[:size] return data def _read_chunk(self) -> None: ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE) plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(self.state, ciphertext) if len(plaintext) != TRANSPORT_CHUNK_SIZE: raise ProtocolError( f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}." ) used = CHUNK_USED_BYTES.unpack(plaintext[: CHUNK_USED_BYTES.size])[0] if used > CHUNK_DATA_SIZE: raise ProtocolError( f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}." ) start = CHUNK_USED_BYTES.size end = start + used self.buffer.extend(plaintext[start:end]) if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL: self.seen_final = True def recv_batch(stream: SecretStreamReader) -> object: size_bytes = read_exact(stream, MESSAGE_LENGTH.size) message_size = MESSAGE_LENGTH.unpack(size_bytes)[0] if message_size > MAX_MESSAGE_SIZE: raise ProtocolError( f"Application message size {message_size} exceeds limit {MAX_MESSAGE_SIZE}." ) payload = read_exact(stream, message_size) return IncomingChatBatch.parse(payload) def print_batch(peer: str, batch: object) -> None: print(f"[packet] {peer} sent {len(batch.messages)} message(s)", flush=True) for index, message in enumerate(batch.messages): print( f" [{index}] role={message.role!r} content={message.content!r}", flush=True, ) def handle_client(client_sock: socket.socket, address: tuple[str, int], key: bytes) -> None: peer = f"{address[0]}:{address[1]}" with client_sock: with client_sock.makefile("rb") as raw_reader: stream = SecretStreamReader(raw_reader, key) print(f"[connected] {peer}", flush=True) while True: if stream.peek(1) == b"": print(f"[disconnected] {peer}", flush=True) return batch = recv_batch(stream) print_batch(peer, batch) def run_server(host: str, port: int, shared_secret: str) -> None: key = derive_key(shared_secret) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock: server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server_sock.bind((host, port)) server_sock.listen() print(f"[listening] {host}:{port}", flush=True) while True: client_sock, address = server_sock.accept() try: handle_client(client_sock, address, key) except (ConstructError, ProtocolError) as exc: print(f"[protocol error] {address[0]}:{address[1]}: {exc}", flush=True) except Exception as exc: print(f"[error] {address[0]}:{address[1]}: {exc}", flush=True) if __name__ == "__main__": try: run_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET) except KeyboardInterrupt: print("\n[stopped]", flush=True)