from __future__ import annotations import asyncio import hashlib import struct from typing import Callable, Final from nacl.bindings.crypto_secretstream import ( crypto_secretstream_xchacha20poly1305_ABYTES, crypto_secretstream_xchacha20poly1305_HEADERBYTES, crypto_secretstream_xchacha20poly1305_KEYBYTES, crypto_secretstream_xchacha20poly1305_TAG_FINAL, crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, crypto_secretstream_xchacha20poly1305_init_pull, crypto_secretstream_xchacha20poly1305_init_push, crypto_secretstream_xchacha20poly1305_pull, crypto_secretstream_xchacha20poly1305_push, crypto_secretstream_xchacha20poly1305_state, ) FrameCallback = Callable[[bytes], None] KEY_DERIVATION_PERSON: Final = b"vv-secretstream1" MAX_FRAME_SIZE: Final = 64 * 1024 * 1024 TRANSPORT_CHUNK_SIZE: Final = 4096 CHUNK_USED_BYTES: Final = struct.Struct("!H") CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size assert CHUNK_DATA_SIZE > 0 CIPHERTEXT_CHUNK_SIZE: Final = ( TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES ) FRAME_LENGTH: Final = struct.Struct("!I") class ProtocolError(RuntimeError): pass def derive_key(shared_secret: str) -> bytes: return hashlib.blake2b( shared_secret.encode("utf-8"), digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES, person=KEY_DERIVATION_PERSON, ).digest() async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes: data = await reader.readexactly(size) if data == b"": raise EOFError("Connection closed.") return data async def wrap_connection_socket( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, shared_secret: str, ) -> SecretStreamSocket: key = derive_key(shared_secret) push_state = crypto_secretstream_xchacha20poly1305_state() pull_state = crypto_secretstream_xchacha20poly1305_state() try: header = crypto_secretstream_xchacha20poly1305_init_push(push_state, key) writer.write(header) await writer.drain() peer_header = await read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES) crypto_secretstream_xchacha20poly1305_init_pull(pull_state, peer_header, key) return SecretStreamSocket( reader=reader, writer=writer, push_state=push_state, pull_state=pull_state, ) except Exception: writer.close() await writer.wait_closed() raise class SecretStreamSocket: def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, push_state: crypto_secretstream_xchacha20poly1305_state, pull_state: crypto_secretstream_xchacha20poly1305_state, ): self.reader = reader self.writer = writer self.push_state = push_state self.pull_state = pull_state self.buffer = bytearray() self.seen_final = False self.sent_final = False async def send_frame(self, frame: bytes) -> None: if self.sent_final: raise RuntimeError("Cannot send after final chunk was sent.") if len(frame) > MAX_FRAME_SIZE: raise ValueError(f"Frame size {len(frame)} exceeds limit {MAX_FRAME_SIZE}.") plaintext = FRAME_LENGTH.pack(len(frame)) + frame start = 0 while start < len(plaintext): used = min(CHUNK_DATA_SIZE, len(plaintext) - start) await self._send_transport_chunk( plaintext[start : start + used], tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, ) start += used await self.writer.drain() async def recv_frame(self) -> bytes | None: await self._fill_plain_buffer(1) if not self.buffer and self.seen_final: return None length_bytes = await self.read_exact(FRAME_LENGTH.size) frame_size = FRAME_LENGTH.unpack(length_bytes)[0] if frame_size > MAX_FRAME_SIZE: raise ProtocolError(f"Frame size {frame_size} exceeds limit {MAX_FRAME_SIZE}.") return await self.read_exact(frame_size) async def run_receiving_loop(self, on_frame: FrameCallback) -> None: while True: frame = await self.recv_frame() if frame is None: return on_frame(frame) async def read(self, size: int) -> bytes: await self._fill_plain_buffer(size) data = bytes(self.buffer[:size]) del self.buffer[:size] return data async def read_exact(self, size: int) -> bytes: data = await self.read(size) if len(data) != size: raise EOFError("Connection closed.") return data async def close(self, send_final: bool = True) -> None: if send_final and not self.sent_final: try: await self._send_transport_chunk( b"", tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL, ) await self.writer.drain() except (OSError, ValueError): pass finally: self.sent_final = True self.writer.close() await self.writer.wait_closed() async def _fill_plain_buffer(self, size: int) -> None: while len(self.buffer) < size and not self.seen_final: await self._read_transport_chunk() async def _read_transport_chunk(self) -> None: ciphertext = await read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE) plaintext, tag = crypto_secretstream_xchacha20poly1305_pull( self.pull_state, ciphertext ) if len(plaintext) != TRANSPORT_CHUNK_SIZE: raise ProtocolError( f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}." ) used = CHUNK_USED_BYTES.unpack(plaintext[: CHUNK_USED_BYTES.size])[0] if used > CHUNK_DATA_SIZE: raise ProtocolError( f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}." ) start = CHUNK_USED_BYTES.size end = start + used self.buffer.extend(plaintext[start:end]) if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL: self.seen_final = True async def _send_transport_chunk(self, payload: bytes, tag: int) -> None: used = len(payload) if used > CHUNK_DATA_SIZE: raise ValueError(f"Chunk payload size {used} exceeds limit {CHUNK_DATA_SIZE}.") plaintext = bytearray(TRANSPORT_CHUNK_SIZE) plaintext[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(used) data_start = CHUNK_USED_BYTES.size data_end = data_start + used plaintext[data_start:data_end] = payload ciphertext = crypto_secretstream_xchacha20poly1305_push( self.push_state, bytes(plaintext), tag=tag, ) self.writer.write(ciphertext)