vibevibing/E.py

179 lines
5.7 KiB
Python

from __future__ import annotations
import hashlib
import socket
import struct
from typing import BinaryIO, Final
from construct import ConstructError, Int32ul, PascalString, PrefixedArray, Struct as CStruct
from nacl.bindings.crypto_secretstream import (
crypto_secretstream_xchacha20poly1305_ABYTES,
crypto_secretstream_xchacha20poly1305_HEADERBYTES,
crypto_secretstream_xchacha20poly1305_KEYBYTES,
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
crypto_secretstream_xchacha20poly1305_init_pull,
crypto_secretstream_xchacha20poly1305_pull,
crypto_secretstream_xchacha20poly1305_state,
)
# Personalization bytes passed as ``person=`` to BLAKE2b in derive_key.
KEY_DERIVATION_PERSON: Final = b"vv-secretstream1"
# Address and port handed to run_server when this file runs as a script.
SERVER_HOST: Final = "127.0.0.1"
SERVER_PORT: Final = 9000
# NOTE(review): placeholder secret committed in source; presumably meant to be
# replaced per deployment -- confirm, and consider loading it from configuration.
SHARED_SECRET: Final = "change-me"
# Largest application message size recv_batch accepts (64 MiB).
MAX_MESSAGE_SIZE: Final = 64 * 1024 * 1024
# Every decrypted transport chunk must be exactly this many plaintext bytes.
TRANSPORT_CHUNK_SIZE: Final = 4096
# Network-byte-order unsigned 16-bit prefix: how many payload bytes a chunk carries.
CHUNK_USED_BYTES: Final = struct.Struct("!H")
# Payload capacity of one chunk: the chunk size minus the length prefix.
CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size
# Number of bytes read from the wire per encrypted chunk: the plaintext chunk
# size plus the secretstream ABYTES constant.
CIPHERTEXT_CHUNK_SIZE: Final = (
TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES
)
# Network-byte-order unsigned 32-bit length prefix in front of each application message.
MESSAGE_LENGTH: Final = struct.Struct("!I")
class ProtocolError(RuntimeError):
    """Raised when received data breaks the framing rules of this protocol."""
# Wire layout of one chat message: a "role" string followed by a "content"
# string, each a UTF-8 PascalString with an Int32ul length field.
MessageInChat = CStruct(
"role" / PascalString(Int32ul, "utf8"),
"content" / PascalString(Int32ul, "utf8"),
)
# Wire layout of one batch: an Int32ul-prefixed array of MessageInChat
# records, exposed on the parsed result as ``messages``.
IncomingChatBatch = CStruct(
"messages" / PrefixedArray(Int32ul, MessageInChat),
)
def derive_key(shared_secret: str) -> bytes:
    """Derive the secretstream key from a textual shared secret.

    The secret is UTF-8 encoded and hashed with BLAKE2b, personalized with
    ``KEY_DERIVATION_PERSON`` and with the digest sized to the secretstream
    key length.

    Args:
        shared_secret: The secret text to derive the key from.

    Returns:
        The raw digest bytes used as the decryption key.
    """
    hasher = hashlib.blake2b(
        digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES,
        person=KEY_DERIVATION_PERSON,
    )
    hasher.update(shared_secret.encode("utf-8"))
    return hasher.digest()
def read_exact(reader: BinaryIO | SecretStreamReader, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``, looping over short reads.

    Args:
        reader: Any object whose ``read(n)`` returns bytes.
        size: Number of bytes to collect.

    Returns:
        The collected bytes.

    Raises:
        EOFError: If ``reader.read`` returns ``b""`` before ``size`` bytes
            have been collected.
    """
    collected = bytearray()
    remaining = size
    while remaining > 0:
        piece = reader.read(remaining)
        if piece == b"":
            raise EOFError("Connection closed.")
        collected += piece
        # Recompute from the buffer length so the loop ends on the same
        # condition regardless of how many bytes the reader handed back.
        remaining = size - len(collected)
    return bytes(collected)
class SecretStreamReader:
    """Decrypting, buffered reader over a secretstream-encrypted byte stream.

    The underlying stream starts with a secretstream header followed by
    ciphertext chunks of ``CIPHERTEXT_CHUNK_SIZE`` bytes. Each chunk must
    decrypt to ``TRANSPORT_CHUNK_SIZE`` bytes: a ``CHUNK_USED_BYTES`` length
    prefix, then that many payload bytes. Bytes after the payload are not
    used (presumably padding). Payload bytes accumulate in ``self.buffer``,
    which ``peek`` and ``read`` serve from. Once a chunk carrying the FINAL
    tag has been seen, no further chunks are read.
    """

    def __init__(self, reader: BinaryIO, key: bytes):
        """Read the stream header from ``reader`` and set up decryption.

        Args:
            reader: Binary stream positioned at the secretstream header.
            key: Decryption key passed to the secretstream ``init_pull`` call.

        Raises:
            EOFError: If ``reader`` ends before a full header was read.
        """
        self.reader = reader
        self.state = crypto_secretstream_xchacha20poly1305_state()
        self.buffer = bytearray()
        self.seen_final = False
        header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES)
        crypto_secretstream_xchacha20poly1305_init_pull(self.state, header, key)

    def peek(self, size: int = 1) -> bytes:
        """Return up to ``size`` payload bytes without consuming them.

        A negative ``size`` (or ``None``) means everything up to the final
        chunk. Fewer bytes than requested are returned only when the final
        chunk has been seen; ``b""`` therefore signals a cleanly ended stream.
        """
        available = self._fill(size)
        return bytes(self.buffer[:available])

    def read(self, size: int = -1) -> bytes:
        """Consume and return up to ``size`` payload bytes.

        A negative ``size`` (the default) or ``None`` reads everything up to
        the final chunk. Previously a negative size fell through to a
        ``buffer[:-1]`` slice and silently held back the last buffered byte
        without reading any further chunks.
        """
        available = self._fill(size)
        data = bytes(self.buffer[:available])
        del self.buffer[:available]
        return data

    def _fill(self, size: int | None) -> int:
        """Buffer enough payload for a request and return the usable byte count.

        Decrypts chunks until ``size`` bytes are buffered or the final chunk
        has been seen; for a negative or ``None`` size, until the final chunk.
        """
        wants_all = size is None or size < 0
        while not self.seen_final and (wants_all or len(self.buffer) < size):
            self._read_chunk()
        if wants_all:
            return len(self.buffer)
        return min(size, len(self.buffer))

    def _read_chunk(self) -> None:
        """Read, decrypt and validate one chunk, appending its payload.

        Raises:
            EOFError: If the underlying reader ends before a full ciphertext
                chunk was read.
            ProtocolError: If the decrypted chunk has the wrong size or
                declares more payload than a chunk can hold.
        """
        ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE)
        plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(self.state, ciphertext)
        if len(plaintext) != TRANSPORT_CHUNK_SIZE:
            raise ProtocolError(
                f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}."
            )
        used = CHUNK_USED_BYTES.unpack(plaintext[: CHUNK_USED_BYTES.size])[0]
        if used > CHUNK_DATA_SIZE:
            raise ProtocolError(
                f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}."
            )
        # Keep only the declared payload; the remainder of the chunk is dropped.
        start = CHUNK_USED_BYTES.size
        end = start + used
        self.buffer.extend(plaintext[start:end])
        if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL:
            self.seen_final = True
def recv_batch(stream: SecretStreamReader) -> object:
    """Read one length-prefixed application message and parse it as a batch.

    Args:
        stream: Decrypting reader positioned at a ``MESSAGE_LENGTH`` prefix.

    Returns:
        The result of ``IncomingChatBatch.parse`` on the message payload.

    Raises:
        ProtocolError: If the declared size exceeds ``MAX_MESSAGE_SIZE``.
        EOFError: If the stream ends before the prefix or payload is complete.
    """
    prefix = read_exact(stream, MESSAGE_LENGTH.size)
    (message_size,) = MESSAGE_LENGTH.unpack(prefix)
    if message_size <= MAX_MESSAGE_SIZE:
        return IncomingChatBatch.parse(read_exact(stream, message_size))
    # Reject oversized declarations before attempting to buffer the payload.
    raise ProtocolError(
        f"Application message size {message_size} exceeds limit {MAX_MESSAGE_SIZE}."
    )
def print_batch(peer: str, batch: object) -> None:
    """Print a summary line for ``batch`` followed by one line per message.

    Args:
        peer: Label identifying the sender, used in the summary line.
        batch: Object with a ``messages`` sequence whose items expose
            ``role`` and ``content`` attributes.
    """
    summary = f"[packet] {peer} sent {len(batch.messages)} message(s)"
    print(summary, flush=True)
    for index, message in enumerate(batch.messages):
        line = f" [{index}] role={message.role!r} content={message.content!r}"
        print(line, flush=True)
def handle_client(client_sock: socket.socket, address: tuple[str, int], key: bytes) -> None:
    """Serve one connection: decrypt its stream and print every batch received.

    The socket and its file wrapper are closed when this function exits,
    whether normally or through an exception.

    Args:
        client_sock: Connected socket for this client.
        address: ``(host, port)`` of the client, used only for log labels.
        key: Decryption key for the client's secretstream.
    """
    peer = f"{address[0]}:{address[1]}"
    with client_sock, client_sock.makefile("rb") as raw_reader:
        stream = SecretStreamReader(raw_reader, key)
        print(f"[connected] {peer}", flush=True)
        # An empty peek means the stream ended, so there are no more batches.
        while stream.peek(1) != b"":
            print_batch(peer, recv_batch(stream))
        print(f"[disconnected] {peer}", flush=True)
def run_server(host: str, port: int, shared_secret: str) -> None:
    """Listen on ``host:port`` and serve clients one at a time, forever.

    Errors raised while handling a client are printed and the loop moves on
    to the next connection; they never stop the server.

    Args:
        host: Address to bind the listening socket to.
        port: TCP port to bind.
        shared_secret: Secret text from which the decryption key is derived.
    """
    key = derive_key(shared_secret)
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with listener:
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((host, port))
        listener.listen()
        print(f"[listening] {host}:{port}", flush=True)
        while True:
            client_sock, address = listener.accept()
            try:
                handle_client(client_sock, address, key)
            except (ConstructError, ProtocolError) as exc:
                # Malformed data from the peer: report it as a protocol problem.
                print(f"[protocol error] {address[0]}:{address[1]}: {exc}", flush=True)
            except Exception as exc:
                # Top-level boundary: anything else is reported, not fatal.
                print(f"[error] {address[0]}:{address[1]}: {exc}", flush=True)
if __name__ == "__main__":
    # Script entry point: serve with the module defaults until Ctrl-C.
    try:
        run_server(
            host=SERVER_HOST,
            port=SERVER_PORT,
            shared_secret=SHARED_SECRET,
        )
    except KeyboardInterrupt:
        print("\n[stopped]", flush=True)