collection_service/dedicated_ai_server/secret_stream_socket.py

213 lines
6.9 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
import struct
from typing import Callable, Final
from nacl.bindings.crypto_secretstream import (
crypto_secretstream_xchacha20poly1305_ABYTES,
crypto_secretstream_xchacha20poly1305_HEADERBYTES,
crypto_secretstream_xchacha20poly1305_KEYBYTES,
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
crypto_secretstream_xchacha20poly1305_init_pull,
crypto_secretstream_xchacha20poly1305_init_push,
crypto_secretstream_xchacha20poly1305_pull,
crypto_secretstream_xchacha20poly1305_push,
crypto_secretstream_xchacha20poly1305_state,
)
# Signature of the callback that SecretStreamSocket.run_receiving_loop invokes
# with each received frame.
FrameCallback = Callable[[bytes], None]
# Personalization string passed to hashlib.blake2b in derive_key.
KEY_DERIVATION_PERSON: Final = b"vv-secretstream1"
# Largest frame payload, in bytes, accepted by send_frame and recv_frame
# (64 MiB).
MAX_FRAME_SIZE: Final = 64 * 1024 * 1024
# Plaintext size of every transport chunk; _send_transport_chunk always pads
# to this length and _read_transport_chunk rejects anything else.
TRANSPORT_CHUNK_SIZE: Final = 4096
# Header at the start of each chunk's plaintext: network-order unsigned 16-bit
# count of the payload bytes actually used in that chunk.
CHUNK_USED_BYTES: Final = struct.Struct("!H")
# Payload capacity of one chunk once the "used" header is accounted for.
CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size
assert CHUNK_DATA_SIZE > 0
# Number of bytes read from the wire per chunk: the fixed plaintext size plus
# the secretstream per-message overhead (ABYTES).
CIPHERTEXT_CHUNK_SIZE: Final = (
TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES
)
# Prefix in front of every frame: network-order unsigned 32-bit payload length.
FRAME_LENGTH: Final = struct.Struct("!I")
class ProtocolError(RuntimeError):
    """Raised when data received from the peer violates the framing rules.

    Used for a decrypted chunk of the wrong size, a chunk header that
    declares more payload than a chunk can hold, and a frame length above
    ``MAX_FRAME_SIZE``.
    """
def derive_key(shared_secret: str) -> bytes:
    """Derive the secretstream key from a shared secret string.

    The secret is UTF-8 encoded and hashed with BLAKE2b, personalized with
    ``KEY_DERIVATION_PERSON`` and sized to the secretstream key length.

    Args:
        shared_secret: Secret text known to both peers.

    Returns:
        A digest of ``crypto_secretstream_xchacha20poly1305_KEYBYTES`` bytes.
    """
    hasher = hashlib.blake2b(
        digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES,
        person=KEY_DERIVATION_PERSON,
    )
    hasher.update(shared_secret.encode("utf-8"))
    return hasher.digest()
async def read_exact(reader: asyncio.StreamReader, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    Args:
        reader: Stream to read from.
        size: Number of bytes to read. ``0`` yields ``b""`` without touching
            the stream.

    Returns:
        The ``size`` bytes that were read.

    Raises:
        asyncio.IncompleteReadError: The stream reached EOF before ``size``
            bytes arrived. It is a subclass of ``EOFError``, so callers
            that catch ``EOFError`` see a closed connection as before.
    """
    # readexactly() reports EOF by raising, never by returning short data,
    # so an empty result can only mean size == 0. Treating it as a closed
    # connection would misreport a legitimate zero-length read.
    return await reader.readexactly(size)
async def wrap_connection_socket(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
) -> SecretStreamSocket:
    """Perform the secretstream handshake and wrap the connection.

    The local stream header is written first, then the peer's header is
    read; each side therefore sends before it receives.

    NOTE(review): the same derived key initializes both the push and the
    pull state, so chunks a peer sent could presumably be replayed back to
    it and still authenticate. Confirm whether reflection matters to
    callers.

    Args:
        reader: Stream to receive the peer's header and later chunks from.
        writer: Stream to send the local header and later chunks to.
        shared_secret: Secret text from which the stream key is derived.

    Returns:
        A ``SecretStreamSocket`` ready to send and receive frames.

    Raises:
        EOFError: The peer closed the connection before sending a full
            header (raised by ``read_exact``).
        Exception: Any other handshake failure is re-raised unchanged.
            In every failure case, including cancellation, ``writer`` is
            closed before the exception propagates.
    """
    try:
        key = derive_key(shared_secret)
        push_state = crypto_secretstream_xchacha20poly1305_state()
        pull_state = crypto_secretstream_xchacha20poly1305_state()
        header = crypto_secretstream_xchacha20poly1305_init_push(push_state, key)
        writer.write(header)
        await writer.drain()
        peer_header = await read_exact(
            reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES
        )
        crypto_secretstream_xchacha20poly1305_init_pull(pull_state, peer_header, key)
        return SecretStreamSocket(
            reader=reader,
            writer=writer,
            push_state=push_state,
            pull_state=pull_state,
        )
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError (for
        # example from a handshake timeout) must not leave the writer open.
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            # Best-effort shutdown: an error while closing must not replace
            # the handshake failure that is re-raised below.
            pass
        raise
class SecretStreamSocket:
    """Length-prefixed frame transport over a secretstream-protected stream.

    Wire layout, as implemented by the methods below:

    * Every transport chunk is exactly ``TRANSPORT_CHUNK_SIZE`` plaintext
      bytes: a ``CHUNK_USED_BYTES`` header with the number of payload bytes,
      the payload itself, then zero padding. Each chunk goes through
      ``crypto_secretstream_xchacha20poly1305_push`` and is read back from
      the wire as ``CIPHERTEXT_CHUNK_SIZE`` bytes.
    * The concatenated chunk payloads form a byte stream of frames, each a
      ``FRAME_LENGTH`` prefix followed by that many bytes.
    * A chunk tagged ``TAG_FINAL`` marks the end of the sender's stream.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        push_state: crypto_secretstream_xchacha20poly1305_state,
        pull_state: crypto_secretstream_xchacha20poly1305_state,
    ):
        """Store the streams and the already-initialized secretstream states.

        Args:
            reader: Stream that ciphertext chunks are read from.
            writer: Stream that ciphertext chunks are written to.
            push_state: Secretstream state used for outgoing chunks.
            pull_state: Secretstream state used for incoming chunks.
        """
        self.reader = reader
        self.writer = writer
        self.push_state = push_state
        self.pull_state = pull_state
        # Decrypted payload bytes not yet consumed by read()/recv_frame().
        self.buffer = bytearray()
        # True once a chunk tagged TAG_FINAL has been received.
        self.seen_final = False
        # True once close() has attempted to send the local TAG_FINAL chunk.
        self.sent_final = False

    async def send_frame(self, frame: bytes) -> None:
        """Send one frame, split across as many transport chunks as needed.

        Args:
            frame: Payload to send; may be empty.

        Raises:
            RuntimeError: The final chunk has already been sent.
            ValueError: ``frame`` is longer than ``MAX_FRAME_SIZE``.
        """
        if self.sent_final:
            raise RuntimeError("Cannot send after final chunk was sent.")
        if len(frame) > MAX_FRAME_SIZE:
            raise ValueError(
                f"Frame size {len(frame)} exceeds limit {MAX_FRAME_SIZE}."
            )
        plaintext = FRAME_LENGTH.pack(len(frame)) + frame
        # _send_transport_chunk contains no await of its own, so this loop
        # queues every chunk of the frame without yielding to the event
        # loop. drain() stays outside the loop on purpose: suspending
        # between chunks would let another sender interleave its chunks
        # into this frame.
        for start in range(0, len(plaintext), CHUNK_DATA_SIZE):
            await self._send_transport_chunk(
                plaintext[start : start + CHUNK_DATA_SIZE],
                tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
            )
        await self.writer.drain()

    async def recv_frame(self) -> bytes | None:
        """Receive the next frame.

        Returns:
            The frame payload (possibly ``b""``), or ``None`` once the
            peer's final chunk has been seen and no buffered data remains.

        Raises:
            ProtocolError: The declared frame length exceeds
                ``MAX_FRAME_SIZE``.
            EOFError: The stream ended in the middle of a frame, or the
                connection closed before a final chunk arrived.
        """
        # Pull at least one byte so a clean end of stream (final chunk and
        # nothing buffered) can be told apart from the start of a frame.
        await self._fill_plain_buffer(1)
        if self.seen_final and not self.buffer:
            return None
        length_bytes = await self.read_exact(FRAME_LENGTH.size)
        (frame_size,) = FRAME_LENGTH.unpack(length_bytes)
        if frame_size > MAX_FRAME_SIZE:
            raise ProtocolError(
                f"Frame size {frame_size} exceeds limit {MAX_FRAME_SIZE}."
            )
        return await self.read_exact(frame_size)

    async def run_receiving_loop(self, on_frame: FrameCallback) -> None:
        """Deliver frames to ``on_frame`` until the peer ends the stream.

        Args:
            on_frame: Called synchronously with each received frame.
        """
        while True:
            frame = await self.recv_frame()
            if frame is None:
                return
            on_frame(frame)

    async def read(self, size: int) -> bytes:
        """Read up to ``size`` decrypted payload bytes.

        Fewer than ``size`` bytes are returned only when the peer's final
        chunk arrives before enough data has been buffered.
        """
        await self._fill_plain_buffer(size)
        data = bytes(self.buffer[:size])
        del self.buffer[:size]
        return data

    async def read_exact(self, size: int) -> bytes:
        """Read exactly ``size`` decrypted payload bytes.

        Raises:
            EOFError: The stream ended before ``size`` bytes were available.
        """
        data = await self.read(size)
        if len(data) != size:
            raise EOFError("Connection closed.")
        return data

    async def close(self, send_final: bool = True) -> None:
        """Close the connection, optionally announcing the end of stream.

        Args:
            send_final: When true and no final chunk has been sent yet, an
                empty chunk tagged ``TAG_FINAL`` is sent first.
                ``OSError`` and ``ValueError`` raised while sending it are
                ignored so that closing an already-broken connection still
                succeeds.
        """
        try:
            if send_final and not self.sent_final:
                try:
                    await self._send_transport_chunk(
                        b"",
                        tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL,
                    )
                    await self.writer.drain()
                except (OSError, ValueError):
                    # Deliberate best-effort: the peer may already be gone.
                    pass
                finally:
                    self.sent_final = True
        finally:
            # Always release the transport, even when sending the final
            # chunk fails with an unexpected error or the task is cancelled
            # during drain().
            self.writer.close()
        await self.writer.wait_closed()

    async def _fill_plain_buffer(self, size: int) -> None:
        """Read chunks until ``size`` bytes are buffered or the stream ends."""
        while len(self.buffer) < size and not self.seen_final:
            await self._read_transport_chunk()

    async def _read_transport_chunk(self) -> None:
        """Read one chunk from the wire and append its payload to the buffer.

        Raises:
            ProtocolError: The decrypted chunk has the wrong size, or its
                header declares more payload than a chunk can hold.
            EOFError: The connection closed before a full chunk arrived.
        """
        # This is the module-level read_exact(reader, size) helper, not the
        # read_exact method of this class.
        ciphertext = await read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE)
        plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(
            self.pull_state, ciphertext
        )
        if len(plaintext) != TRANSPORT_CHUNK_SIZE:
            raise ProtocolError(
                f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}."
            )
        (used,) = CHUNK_USED_BYTES.unpack_from(plaintext)
        if used > CHUNK_DATA_SIZE:
            raise ProtocolError(
                f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}."
            )
        start = CHUNK_USED_BYTES.size
        # Everything past start + used is padding and is discarded.
        self.buffer.extend(plaintext[start : start + used])
        if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL:
            self.seen_final = True

    async def _send_transport_chunk(self, payload: bytes, tag: int) -> None:
        """Pad, encrypt and queue one chunk on the writer.

        The chunk is only handed to ``writer.write``; callers are
        responsible for awaiting ``writer.drain()``.

        Args:
            payload: At most ``CHUNK_DATA_SIZE`` bytes of stream data.
            tag: Secretstream tag to attach to the chunk.

        Raises:
            ValueError: ``payload`` does not fit into one chunk.
        """
        used = len(payload)
        if used > CHUNK_DATA_SIZE:
            raise ValueError(
                f"Chunk payload size {used} exceeds limit {CHUNK_DATA_SIZE}."
            )
        # Zero-filled buffer of the fixed chunk size; whatever the header
        # and payload do not overwrite remains as padding.
        plaintext = bytearray(TRANSPORT_CHUNK_SIZE)
        data_start = CHUNK_USED_BYTES.size
        plaintext[:data_start] = CHUNK_USED_BYTES.pack(used)
        plaintext[data_start : data_start + used] = payload
        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            self.push_state,
            bytes(plaintext),
            tag=tag,
        )
        self.writer.write(ciphertext)