# vibevibing/secret_stream_socket.py
#
# (Listing metadata: 235 lines, 7.2 KiB, Python.)

from __future__ import annotations
import hashlib
import socket
import struct
from typing import BinaryIO, Callable, Final
from nacl.bindings.crypto_secretstream import (
crypto_secretstream_xchacha20poly1305_ABYTES,
crypto_secretstream_xchacha20poly1305_HEADERBYTES,
crypto_secretstream_xchacha20poly1305_KEYBYTES,
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
crypto_secretstream_xchacha20poly1305_init_pull,
crypto_secretstream_xchacha20poly1305_init_push,
crypto_secretstream_xchacha20poly1305_pull,
crypto_secretstream_xchacha20poly1305_push,
crypto_secretstream_xchacha20poly1305_state,
)
# Signature of the callback given to SecretStreamSocket.run_receiving_loop().
FrameCallback = Callable[[bytes], None]

# BLAKE2b personalisation string used by derive_key(). It is 16 bytes long,
# which is the maximum personalisation length hashlib.blake2b accepts.
KEY_DERIVATION_PERSON: Final = b"vv-secretstream1"

# Application-level framing: each frame is preceded by a big-endian 32-bit
# length, and frames longer than MAX_FRAME_SIZE are refused on both the
# sending and the receiving side.
MAX_FRAME_SIZE: Final = 64 * 1024 * 1024
FRAME_LENGTH: Final = struct.Struct("!I")

# Transport chunk layout: every chunk's plaintext is exactly
# TRANSPORT_CHUNK_SIZE bytes -- a big-endian 16-bit count of meaningful
# bytes, that many payload bytes, then padding -- so every encrypted chunk
# on the wire has the same length.
TRANSPORT_CHUNK_SIZE: Final = 4096
CHUNK_USED_BYTES: Final = struct.Struct("!H")
CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size
assert CHUNK_DATA_SIZE > 0

# Ciphertext size of one chunk: padded plaintext plus secretstream's
# per-message overhead (ABYTES).
CIPHERTEXT_CHUNK_SIZE: Final = (
    TRANSPORT_CHUNK_SIZE + crypto_secretstream_xchacha20poly1305_ABYTES
)
class ProtocolError(RuntimeError):
    """Raised when data received from the peer breaks the framing protocol.

    Used for an announced frame size above MAX_FRAME_SIZE and for decrypted
    transport chunks whose size or declared payload count is invalid.
    """
def derive_key(shared_secret: str) -> bytes:
    """Derive the secretstream key from a textual shared secret.

    The secret is UTF-8 encoded and hashed with BLAKE2b, personalised with
    KEY_DERIVATION_PERSON for domain separation. The digest length equals
    the secretstream key size, so the result can be used as the key directly.
    """
    # NOTE(review): BLAKE2b is a fast hash, not a password-hashing function.
    # If shared_secret can be a low-entropy passphrase, this derivation gives
    # no resistance to offline guessing. Confirm secrets are high-entropy.
    hasher = hashlib.blake2b(
        digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES,
        person=KEY_DERIVATION_PERSON,
    )
    hasher.update(shared_secret.encode("utf-8"))
    return hasher.digest()
def read_exact(reader: BinaryIO | SecretStreamSocket, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    ``reader.read(n)`` may return fewer than ``n`` bytes, so it is called
    repeatedly, each time asking only for what is still missing.

    Raises:
        EOFError: if ``reader.read`` returns ``b""`` before ``size`` bytes
            have been collected.
    """
    collected = bytearray()
    while (missing := size - len(collected)) > 0:
        piece = reader.read(missing)
        if piece == b"":
            raise EOFError("Connection closed.")
        collected += piece
    return bytes(collected)
def wrap_connection_socket(
    connection_socket: socket.socket,
    shared_secret: str,
) -> SecretStreamSocket:
    """Run the secretstream handshake on ``connection_socket`` and wrap it.

    The handshake is symmetric: this side writes the header returned by
    ``init_push`` for its outgoing stream, then reads exactly ``HEADERBYTES``
    from the peer and uses them to initialise its incoming (pull) stream.
    The same key, derived from ``shared_secret`` by derive_key(), is used in
    both directions.

    If the handshake fails, the socket and any file objects created for it
    are closed before the exception propagates.

    Args:
        connection_socket: An already-connected socket to the peer.
        shared_secret: Secret text known to both peers.

    Returns:
        A SecretStreamSocket ready for send_frame()/recv_frame().

    Raises:
        EOFError: If the peer closes the connection before its header arrives.
        OSError: If a socket operation fails during the handshake.
    """
    # NOTE(review): the same key is used for both directions. Anyone able to
    # tamper with the connection could echo this side's own header and chunks
    # back to it, and they would decrypt successfully (a reflection attack).
    # Per-direction keys would prevent this but change the wire protocol.
    key = derive_key(shared_secret)
    # Everything opened so far; closed in reverse order if the handshake fails.
    opened: list = [connection_socket]
    try:
        reader = connection_socket.makefile("rb")
        opened.append(reader)
        writer = connection_socket.makefile("wb")
        opened.append(writer)
        push_state = crypto_secretstream_xchacha20poly1305_state()
        pull_state = crypto_secretstream_xchacha20poly1305_state()
        header = crypto_secretstream_xchacha20poly1305_init_push(push_state, key)
        writer.write(header)
        # Flush before blocking on the peer's header. If the peer follows the
        # same sequence, a header left in our write buffer would deadlock both.
        writer.flush()
        peer_header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES)
        crypto_secretstream_xchacha20poly1305_init_pull(pull_state, peer_header, key)
        return SecretStreamSocket(
            sock=connection_socket,
            reader=reader,
            writer=writer,
            push_state=push_state,
            pull_state=pull_state,
        )
    except BaseException:
        # BaseException (not just Exception) so that KeyboardInterrupt or
        # SystemExit during the blocking handshake does not leak the socket.
        # Each close() is guarded on its own: closing the writer flushes it,
        # which can fail again on a broken connection. That failure must not
        # skip the remaining closes or replace the original exception.
        for resource in reversed(opened):
            try:
                resource.close()
            except (OSError, ValueError):
                pass
        raise
class SecretStreamSocket:
    """Encrypted, length-prefixed frame transport over a socket.

    Outgoing frames are prefixed with FRAME_LENGTH, split into transport
    chunks of at most CHUNK_DATA_SIZE payload bytes and encrypted with the
    secretstream ``push_state``. Each chunk's plaintext is exactly
    TRANSPORT_CHUNK_SIZE bytes: a CHUNK_USED_BYTES count, the payload, then
    zero padding. Every ciphertext chunk is therefore CIPHERTEXT_CHUNK_SIZE
    bytes. Incoming chunks are decrypted with ``pull_state`` into ``buffer``,
    and frames are reassembled from that byte stream.

    The end of each direction is marked by a chunk tagged TAG_FINAL. If the
    connection closes before the peer's final chunk arrives, reading raises
    EOFError instead of reporting a clean end of stream.

    Instances are normally built by wrap_connection_socket(). The object is a
    context manager whose exit calls close().
    """

    def __init__(
        self,
        sock: socket.socket,
        reader: BinaryIO,
        writer: BinaryIO,
        push_state: crypto_secretstream_xchacha20poly1305_state,
        pull_state: crypto_secretstream_xchacha20poly1305_state,
    ):
        self.sock = sock
        # Ciphertext source and sink (wrap_connection_socket passes
        # sock.makefile("rb") and sock.makefile("wb")).
        self.reader = reader
        self.writer = writer
        # Secretstream states for the outgoing and incoming directions.
        self.push_state = push_state
        self.pull_state = pull_state
        # Decrypted payload bytes not yet consumed by read().
        self.buffer = bytearray()
        # Set once the peer's TAG_FINAL chunk has been decrypted.
        self.seen_final = False
        # Set once our TAG_FINAL chunk has been sent (or its send attempted).
        self.sent_final = False

    def __enter__(self) -> SecretStreamSocket:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Closes with send_final=True even when leaving via an exception.
        self.close()

    def send_frame(self, frame: bytes) -> None:
        """Encrypt and send one frame, then flush the writer.

        Raises:
            RuntimeError: If the final chunk has already been sent.
            ValueError: If ``frame`` is longer than MAX_FRAME_SIZE.
        """
        if self.sent_final:
            raise RuntimeError("Cannot send after final chunk was sent.")
        if len(frame) > MAX_FRAME_SIZE:
            raise ValueError(f"Frame size {len(frame)} exceeds limit {MAX_FRAME_SIZE}.")
        # The receiver finds frame boundaries in the decrypted byte stream
        # from this length prefix alone, not from chunk boundaries.
        plaintext = FRAME_LENGTH.pack(len(frame)) + frame
        for start in range(0, len(plaintext), CHUNK_DATA_SIZE):
            self._send_transport_chunk(
                plaintext[start : start + CHUNK_DATA_SIZE],
                tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
            )
        self.writer.flush()

    def recv_frame(self) -> bytes | None:
        """Receive one frame, or return ``None`` at a clean end of stream.

        ``None`` is returned only after the peer's final chunk has been seen
        and all buffered payload has been consumed.

        Raises:
            ProtocolError: If the announced frame size exceeds MAX_FRAME_SIZE
                or a decrypted chunk is malformed.
            EOFError: If the stream ends in the middle of a frame, or the
                connection closes before the peer's final chunk arrives.
        """
        if self._peek_plain(1) == b"":
            return None
        (frame_size,) = FRAME_LENGTH.unpack(read_exact(self, FRAME_LENGTH.size))
        if frame_size > MAX_FRAME_SIZE:
            raise ProtocolError(f"Frame size {frame_size} exceeds limit {MAX_FRAME_SIZE}.")
        return read_exact(self, frame_size)

    def run_receiving_loop(self, on_frame: FrameCallback) -> None:
        """Pass every received frame to ``on_frame`` until a clean end of stream.

        Exceptions raised by recv_frame() or by ``on_frame`` propagate.
        """
        while (frame := self.recv_frame()) is not None:
            on_frame(frame)

    def read(self, size: int) -> bytes:
        """Return and consume up to ``size`` decrypted payload bytes.

        Blocks until ``size`` bytes are buffered or the peer's final chunk
        has been received. It returns fewer bytes (possibly ``b""``) only at
        end of stream, which lets read_exact() be used on this object.
        """
        self._fill_buffer(size)
        data = bytes(self.buffer[:size])
        del self.buffer[:size]
        return data

    def close(self, send_final: bool = True) -> None:
        """Close the stream, its file objects and the socket.

        When ``send_final`` is true and no final chunk has been sent yet, an
        empty TAG_FINAL chunk is sent first so the peer sees a clean end of
        stream. That send is best effort: OSError and ValueError (e.g. from a
        broken or already-closed connection) are ignored.

        The writer, reader and socket are closed in every case. This includes
        the case where sending the final chunk raises some other exception,
        which then propagates after the teardown.
        """
        try:
            if send_final and not self.sent_final:
                try:
                    self._send_transport_chunk(
                        b"",
                        tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL,
                    )
                    self.writer.flush()
                except (OSError, ValueError):
                    pass
                finally:
                    # Never attempt a second final chunk, even after a failure.
                    self.sent_final = True
        finally:
            # Closing the buffered writer flushes it, which may fail on a dead
            # connection. Ignore that so the reader and socket still get closed.
            for stream in (self.writer, self.reader):
                try:
                    stream.close()
                except (OSError, ValueError):
                    pass
            self.sock.close()

    def _fill_buffer(self, size: int) -> None:
        """Decrypt chunks until ``size`` bytes are buffered or the final chunk is seen."""
        while len(self.buffer) < size and not self.seen_final:
            self._read_transport_chunk()

    def _peek_plain(self, size: int = 1) -> bytes:
        """Return up to ``size`` buffered payload bytes without consuming them."""
        self._fill_buffer(size)
        return bytes(self.buffer[:size])

    def _read_transport_chunk(self) -> None:
        """Read, decrypt and unpack one transport chunk into ``buffer``.

        Sets ``seen_final`` when the chunk carries TAG_FINAL.

        Raises:
            EOFError: If the connection closes before a whole chunk arrives.
            ProtocolError: If the decrypted chunk has the wrong size or
                declares more payload than CHUNK_DATA_SIZE.

        Whatever PyNaCl's pull raises when authentication fails (e.g. for a
        tampered chunk) propagates unchanged.
        """
        ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE)
        plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(
            self.pull_state, ciphertext
        )
        if len(plaintext) != TRANSPORT_CHUNK_SIZE:
            raise ProtocolError(
                f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}."
            )
        used = CHUNK_USED_BYTES.unpack(plaintext[: CHUNK_USED_BYTES.size])[0]
        if used > CHUNK_DATA_SIZE:
            raise ProtocolError(
                f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}."
            )
        start = CHUNK_USED_BYTES.size
        end = start + used
        # Only the declared payload is kept; the padding after it is dropped.
        self.buffer.extend(plaintext[start:end])
        if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL:
            self.seen_final = True

    def _send_transport_chunk(self, payload: bytes, tag: int) -> None:
        """Pad ``payload`` into one fixed-size chunk, encrypt it with ``tag`` and write it.

        The writer is not flushed here.

        Raises:
            ValueError: If ``payload`` is longer than CHUNK_DATA_SIZE.
        """
        used = len(payload)
        if used > CHUNK_DATA_SIZE:
            raise ValueError(f"Chunk payload size {used} exceeds limit {CHUNK_DATA_SIZE}.")
        # Zero-initialised, so the bytes after the payload are zero padding.
        plaintext = bytearray(TRANSPORT_CHUNK_SIZE)
        plaintext[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(used)
        data_start = CHUNK_USED_BYTES.size
        data_end = data_start + used
        plaintext[data_start:data_end] = payload
        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            self.push_state,
            bytes(plaintext),
            tag=tag,
        )
        self.writer.write(ciphertext)