# Listing metadata: 179 lines, 5.7 KiB, Python.
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import socket
|
|
import struct
|
|
from typing import BinaryIO, Final
|
|
|
|
from construct import ConstructError, Int32ul, PascalString, PrefixedArray, Struct as CStruct
|
|
from nacl.bindings.crypto_secretstream import (
|
|
crypto_secretstream_xchacha20poly1305_ABYTES,
|
|
crypto_secretstream_xchacha20poly1305_HEADERBYTES,
|
|
crypto_secretstream_xchacha20poly1305_KEYBYTES,
|
|
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
|
|
crypto_secretstream_xchacha20poly1305_init_pull,
|
|
crypto_secretstream_xchacha20poly1305_pull,
|
|
crypto_secretstream_xchacha20poly1305_state,
|
|
)
|
|
|
|
|
|
# --- Server configuration -------------------------------------------------

# BLAKE2b personalisation string mixed into the key derivation (see derive_key).
KEY_DERIVATION_PERSON: Final[bytes] = b"vv-secretstream1"

# Address the listening socket is bound to.
SERVER_HOST: Final[str] = "127.0.0.1"
SERVER_PORT: Final[int] = 9000

# NOTE(review): placeholder secret baked into the source; presumably meant to
# be replaced before real use -- confirm how deployments override it.
SHARED_SECRET: Final[str] = "change-me"

# Upper bound, in bytes, accepted for one length-prefixed application message.
MAX_MESSAGE_SIZE: Final[int] = 64 * 1024 * 1024
|
|
|
|
# --- Wire format ----------------------------------------------------------

# Size in bytes of every decrypted transport chunk (prefix + payload + filler).
TRANSPORT_CHUNK_SIZE: Final[int] = 4096

# Each plaintext chunk starts with this prefix: an unsigned 16-bit,
# network-byte-order count of the payload bytes actually used in the chunk.
CHUNK_USED_BYTES: Final[struct.Struct] = struct.Struct("!H")

# Largest payload a single chunk can carry once the prefix is accounted for.
CHUNK_DATA_SIZE: Final[int] = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size

# Size in bytes of one encrypted chunk as read from the socket: the plaintext
# size plus the secretstream ABYTES constant.
CIPHERTEXT_CHUNK_SIZE: Final[int] = (
    crypto_secretstream_xchacha20poly1305_ABYTES + TRANSPORT_CHUNK_SIZE
)

# Application messages are framed by an unsigned 32-bit, network-byte-order
# length prefix inside the decrypted stream.
MESSAGE_LENGTH: Final[struct.Struct] = struct.Struct("!I")
|
|
|
|
|
|
class ProtocolError(RuntimeError):
    """Raised when received data violates the transport or message framing."""
|
|
|
|
|
|
# One chat message on the wire: two UTF-8 strings ("role", then "content"),
# each preceded by an Int32ul length field.
MessageInChat = CStruct(
    "role" / PascalString(Int32ul, "utf8"),
    "content" / PascalString(Int32ul, "utf8"),
)
|
|
|
|
# A batch of chat messages: an Int32ul element count followed by that many
# MessageInChat records, exposed as the "messages" field.
IncomingChatBatch = CStruct(
    "messages" / PrefixedArray(Int32ul, MessageInChat),
)
|
|
|
|
|
|
def derive_key(shared_secret: str) -> bytes:
    """Derive the secretstream key from a textual shared secret.

    The secret is UTF-8 encoded and hashed with BLAKE2b, personalised with
    ``KEY_DERIVATION_PERSON``; the digest length is the secretstream key size.

    Args:
        shared_secret: Secret text known to both peers.

    Returns:
        The raw key bytes.
    """
    secret_bytes = shared_secret.encode("utf-8")
    hasher = hashlib.blake2b(
        secret_bytes,
        person=KEY_DERIVATION_PERSON,
        digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES,
    )
    return hasher.digest()
|
|
|
|
|
|
def read_exact(reader: BinaryIO | SecretStreamReader, size: int) -> bytes:
    """Read exactly *size* bytes from *reader*.

    ``reader.read`` may return fewer bytes than requested, so it is called
    repeatedly until the requested amount has been collected.

    Args:
        reader: Object whose ``read(n)`` returns up to ``n`` bytes and
            ``b""`` once no more data is available.
        size: Number of bytes to return; a non-positive value yields ``b""``.

    Returns:
        The collected bytes.

    Raises:
        EOFError: If the reader runs dry before *size* bytes were collected.
    """
    collected = bytearray()
    remaining = size

    while remaining > 0:
        piece = reader.read(remaining)
        if piece == b"":
            raise EOFError("Connection closed.")
        collected += piece
        remaining = size - len(collected)

    return bytes(collected)
|
|
|
|
|
|
class SecretStreamReader:
    """Buffered, decrypting reader over a secretstream-encrypted byte stream.

    The wrapped reader must supply a secretstream header followed by
    ciphertext chunks of exactly ``CIPHERTEXT_CHUNK_SIZE`` bytes.  Each chunk
    decrypts to ``TRANSPORT_CHUNK_SIZE`` bytes laid out as a
    ``CHUNK_USED_BYTES`` length prefix, that many payload bytes, and filler
    that is discarded.  A chunk carrying the FINAL tag ends the stream; no
    further chunks are read after it.
    """

    def __init__(self, reader: BinaryIO, key: bytes):
        """Consume the stream header from *reader* and set up decryption.

        Args:
            reader: Binary source of the encrypted stream.
            key: Secretstream key shared with the sender.

        Raises:
            EOFError: If *reader* ends before the full header was read.
        """
        self.reader = reader
        self.state = crypto_secretstream_xchacha20poly1305_state()
        # Decrypted payload bytes not yet handed to the caller.
        self.buffer = bytearray()
        # Becomes True once the chunk tagged FINAL has been decrypted.
        self.seen_final = False

        header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES)
        crypto_secretstream_xchacha20poly1305_init_pull(self.state, header, key)

    def peek(self, size: int = 1) -> bytes:
        """Return up to *size* upcoming payload bytes without consuming them.

        Fewer bytes (possibly ``b""``) are returned only when the stream has
        ended.  A negative *size* means "everything up to the end of the
        stream".
        """
        self._fill(size)
        if size < 0:
            return bytes(self.buffer)
        return bytes(self.buffer[:size])

    def read(self, size: int = -1) -> bytes:
        """Consume and return up to *size* payload bytes.

        Fewer bytes (possibly ``b""``) are returned only when the stream has
        ended.  A negative *size* -- the default -- reads everything up to
        the end of the stream.
        """
        self._fill(size)
        if size < 0:
            # A negative slice bound would silently drop trailing bytes, so
            # translate "read all" into the concrete buffered length.
            size = len(self.buffer)

        data = bytes(self.buffer[:size])
        del self.buffer[:size]
        return data

    def _fill(self, size: int) -> None:
        """Decrypt chunks until *size* bytes are buffered or the stream ends.

        A negative *size* keeps decrypting until the FINAL chunk is seen.
        """
        while not self.seen_final and (size < 0 or len(self.buffer) < size):
            self._read_chunk()

    def _read_chunk(self) -> None:
        """Read, decrypt and validate one chunk, appending its payload.

        Raises:
            EOFError: If the underlying reader ends mid-chunk.
            ProtocolError: If the decrypted chunk has the wrong size or
                declares more payload than a chunk can hold.

        NOTE(review): decryption/authentication failures are raised by the
        nacl binding itself; presumably ``nacl.exceptions.CryptoError`` --
        confirm against the PyNaCl documentation.
        """
        ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE)
        plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(self.state, ciphertext)

        if len(plaintext) != TRANSPORT_CHUNK_SIZE:
            raise ProtocolError(
                f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}."
            )

        (used,) = CHUNK_USED_BYTES.unpack_from(plaintext)
        if used > CHUNK_DATA_SIZE:
            raise ProtocolError(
                f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}."
            )

        # Keep only the declared payload; the rest of the chunk is filler.
        start = CHUNK_USED_BYTES.size
        self.buffer.extend(plaintext[start:start + used])

        if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL:
            self.seen_final = True
|
|
|
|
|
|
def recv_batch(stream: SecretStreamReader) -> object:
    """Read one length-prefixed application message and parse it as a batch.

    The message is framed by a ``MESSAGE_LENGTH`` prefix; the declared size
    is checked against ``MAX_MESSAGE_SIZE`` before the payload is read.

    Args:
        stream: Decrypted stream positioned at the start of a message.

    Returns:
        The result of ``IncomingChatBatch.parse`` on the payload.

    Raises:
        ProtocolError: If the declared size exceeds ``MAX_MESSAGE_SIZE``.
        EOFError: If the stream ends before the full message was read.
    """
    prefix = read_exact(stream, MESSAGE_LENGTH.size)
    (declared_size,) = MESSAGE_LENGTH.unpack(prefix)

    # Reject oversized messages before buffering any of their payload.
    if declared_size <= MAX_MESSAGE_SIZE:
        body = read_exact(stream, declared_size)
        return IncomingChatBatch.parse(body)

    raise ProtocolError(
        f"Application message size {declared_size} exceeds limit {MAX_MESSAGE_SIZE}."
    )
|
|
|
|
|
|
def print_batch(peer: str, batch: object) -> None:
    """Print a summary line for *batch* followed by one line per message.

    Args:
        peer: Label identifying the sender, used in the summary line.
        batch: Object with a ``messages`` sequence whose items expose
            ``role`` and ``content`` attributes.
    """
    entries = batch.messages
    print(f"[packet] {peer} sent {len(entries)} message(s)", flush=True)

    for position, entry in enumerate(entries):
        line = f" [{position}] role={entry.role!r} content={entry.content!r}"
        print(line, flush=True)
|
|
|
|
|
|
def handle_client(client_sock: socket.socket, address: tuple[str, int], key: bytes) -> None:
    """Serve one connection: decrypt its stream and print every batch received.

    The socket is closed when this function returns or raises.

    Args:
        client_sock: Accepted connection; ownership passes to this function.
        address: ``(host, port)`` of the remote peer, used for log lines.
        key: Secretstream key used to decrypt the connection.
    """
    peer = f"{address[0]}:{address[1]}"

    with client_sock, client_sock.makefile("rb") as raw_reader:
        stream = SecretStreamReader(raw_reader, key)
        print(f"[connected] {peer}", flush=True)

        # peek() only comes back empty once the stream's final chunk has been
        # consumed, so an empty peek means the sender finished cleanly.
        while stream.peek(1) != b"":
            print_batch(peer, recv_batch(stream))

        print(f"[disconnected] {peer}", flush=True)
|
|
|
|
|
|
def run_server(host: str, port: int, shared_secret: str) -> None:
    """Listen on ``host:port`` and serve clients one at a time, forever.

    Errors raised while serving a client are printed and the server moves on
    to the next connection; errors from the listening socket itself propagate.

    Args:
        host: Address to bind the listening socket to.
        port: TCP port to listen on.
        shared_secret: Secret text from which the stream key is derived.
    """
    key = derive_key(shared_secret)

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with listener:
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((host, port))
        listener.listen()
        print(f"[listening] {host}:{port}", flush=True)

        while True:
            client_sock, address = listener.accept()
            peer = f"{address[0]}:{address[1]}"
            try:
                handle_client(client_sock, address, key)
            except (ConstructError, ProtocolError) as exc:
                print(f"[protocol error] {peer}: {exc}", flush=True)
            except Exception as exc:
                # Top-level boundary: one failing client must not stop the server.
                print(f"[error] {peer}: {exc}", flush=True)
|
|
|
|
|
|
def _main() -> None:
    """Run the server with the module-level settings until interrupted."""
    try:
        run_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server; exit quietly.
        print("\n[stopped]", flush=True)


if __name__ == "__main__":
    _main()
|