235 lines
7.2 KiB
Python
235 lines
7.2 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import socket
|
|
import struct
|
|
from typing import BinaryIO, Callable, Final
|
|
|
|
from nacl.bindings.crypto_secretstream import (
|
|
crypto_secretstream_xchacha20poly1305_ABYTES,
|
|
crypto_secretstream_xchacha20poly1305_HEADERBYTES,
|
|
crypto_secretstream_xchacha20poly1305_KEYBYTES,
|
|
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
|
|
crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
|
|
crypto_secretstream_xchacha20poly1305_init_pull,
|
|
crypto_secretstream_xchacha20poly1305_init_push,
|
|
crypto_secretstream_xchacha20poly1305_pull,
|
|
crypto_secretstream_xchacha20poly1305_push,
|
|
crypto_secretstream_xchacha20poly1305_state,
|
|
)
|
|
|
|
|
|
# Signature of the per-frame callback accepted by
# SecretStreamSocket.run_receiving_loop: it receives one decoded frame body.
FrameCallback = Callable[[bytes], None]

# BLAKE2b personalization string used by derive_key. It is exactly 16 bytes,
# which is the maximum hashlib.blake2b accepts for ``person``.
KEY_DERIVATION_PERSON: Final = b"vv-secretstream1"
# Largest application frame accepted for sending or receiving (64 MiB).
MAX_FRAME_SIZE: Final = 64 * 1024 * 1024

# Every transport chunk carries exactly this many plaintext bytes: a
# big-endian 16-bit "used" counter followed by payload and zero padding.
TRANSPORT_CHUNK_SIZE: Final = 4096
CHUNK_USED_BYTES: Final = struct.Struct("!H")
# Payload capacity of one chunk once the "used" counter is accounted for.
CHUNK_DATA_SIZE: Final = TRANSPORT_CHUNK_SIZE - CHUNK_USED_BYTES.size
# Checked with an explicit raise (not ``assert``) so the invariant still holds
# under ``python -O``. The upper bound guarantees that a full chunk's payload
# length can be represented in the 16-bit "used" counter.
if not 0 < CHUNK_DATA_SIZE <= 0xFFFF:
    raise ValueError(
        f"CHUNK_DATA_SIZE must be between 1 and 65535, got {CHUNK_DATA_SIZE}."
    )
|
|
# Number of bytes read from the wire per transport chunk: the fixed-size
# plaintext chunk plus the secretstream ABYTES constant (presumably the
# per-message overhead added by encryption; see the libsodium docs).
CIPHERTEXT_CHUNK_SIZE: Final = (
    crypto_secretstream_xchacha20poly1305_ABYTES + TRANSPORT_CHUNK_SIZE
)
# Each application frame is prefixed with its length as a big-endian
# unsigned 32-bit integer.
FRAME_LENGTH: Final = struct.Struct("!I")
|
|
|
|
|
|
class ProtocolError(RuntimeError):
    """Raised when data received from the peer violates the framing rules."""
|
|
|
|
|
|
def derive_key(shared_secret: str) -> bytes:
    """Turn a textual shared secret into a secretstream key.

    The secret is UTF-8 encoded and hashed with BLAKE2b, personalized with
    ``KEY_DERIVATION_PERSON``; the digest length equals the secretstream
    key size.

    Args:
        shared_secret: Secret known to both peers.

    Returns:
        The raw key bytes.
    """
    hasher = hashlib.blake2b(
        digest_size=crypto_secretstream_xchacha20poly1305_KEYBYTES,
        person=KEY_DERIVATION_PERSON,
    )
    hasher.update(shared_secret.encode("utf-8"))
    return hasher.digest()
|
|
|
|
|
|
def read_exact(reader: BinaryIO | SecretStreamSocket, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``reader``.

    ``reader.read`` is called repeatedly with the number of bytes still
    missing, so short reads are tolerated.

    Args:
        reader: Any object exposing ``read(n)`` that returns bytes.
        size: Number of bytes to collect.

    Returns:
        The collected bytes.

    Raises:
        EOFError: If ``reader.read`` returns ``b""`` before enough bytes
            have been collected.
    """
    collected = bytearray()
    missing = size

    while missing > 0:
        piece = reader.read(missing)
        if piece == b"":
            raise EOFError("Connection closed.")
        collected.extend(piece)
        missing = size - len(collected)

    return bytes(collected)
|
|
|
|
|
|
def wrap_connection_socket(
    connection_socket: socket.socket,
    shared_secret: str,
) -> SecretStreamSocket:
    """Perform the secretstream handshake and wrap the socket.

    A key is derived from ``shared_secret``. This side's secretstream header
    is written and flushed, then the peer's header is read, and both stream
    states are handed to a new ``SecretStreamSocket``.

    Args:
        connection_socket: Connected socket to wrap. If the handshake fails,
            the socket and both file objects created from it are closed.
        shared_secret: Secret known to both peers.

    Returns:
        A ``SecretStreamSocket`` that owns ``connection_socket``.

    Raises:
        EOFError: If the peer closes the connection before sending its header.
        OSError: If writing or reading the header fails.
        Any exception raised by the secretstream initialization is re-raised.
    """
    key = derive_key(shared_secret)
    reader = connection_socket.makefile("rb")
    writer = connection_socket.makefile("wb")
    push_state = crypto_secretstream_xchacha20poly1305_state()
    pull_state = crypto_secretstream_xchacha20poly1305_state()

    try:
        header = crypto_secretstream_xchacha20poly1305_init_push(push_state, key)
        writer.write(header)
        writer.flush()

        peer_header = read_exact(reader, crypto_secretstream_xchacha20poly1305_HEADERBYTES)
        crypto_secretstream_xchacha20poly1305_init_pull(pull_state, peer_header, key)

        return SecretStreamSocket(
            sock=connection_socket,
            reader=reader,
            writer=writer,
            push_state=push_state,
            pull_state=pull_state,
        )
    except BaseException:
        # Release everything before re-raising, including on KeyboardInterrupt.
        # Each close is guarded separately: closing the writer may try to
        # flush buffered data to a broken socket and raise again, which must
        # neither skip the remaining closes nor replace the original error.
        for resource in (writer, reader, connection_socket):
            try:
                resource.close()
            except (OSError, ValueError):
                pass
        raise
|
|
|
|
|
|
class SecretStreamSocket:
    """Length-prefixed frame transport on top of an encrypted chunk stream.

    Application frames are prefixed with their length (``FRAME_LENGTH``) and
    the resulting bytes are split across fixed-size transport chunks. Each
    chunk's plaintext is ``TRANSPORT_CHUNK_SIZE`` bytes: a ``CHUNK_USED_BYTES``
    counter, that many payload bytes, then zero padding. Chunks are encrypted
    with the secretstream push state and decrypted with the pull state. A
    chunk tagged FINAL marks the orderly end of the peer's stream.

    Instances can be used as context managers; leaving the ``with`` block
    calls ``close()``.
    """

    def __init__(
        self,
        sock: socket.socket,
        reader: BinaryIO,
        writer: BinaryIO,
        push_state: crypto_secretstream_xchacha20poly1305_state,
        pull_state: crypto_secretstream_xchacha20poly1305_state,
    ):
        """Store the already-initialized transport objects.

        Args:
            sock: Underlying socket; closed by ``close()``.
            reader: Binary file object the ciphertext is read from.
            writer: Binary file object the ciphertext is written to.
            push_state: Secretstream state used for encrypting.
            pull_state: Secretstream state used for decrypting.
        """
        self.sock = sock
        self.reader = reader
        self.writer = writer
        self.push_state = push_state
        self.pull_state = pull_state
        # Decrypted payload bytes not yet consumed by read().
        self.buffer = bytearray()
        # True once a chunk tagged FINAL has been received.
        self.seen_final = False
        # True once close() has attempted to send the FINAL chunk.
        self.sent_final = False

    def __enter__(self) -> SecretStreamSocket:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def send_frame(self, frame: bytes) -> None:
        """Encrypt and send one frame, then flush the writer.

        Args:
            frame: Frame body; at most ``MAX_FRAME_SIZE`` bytes.

        Raises:
            RuntimeError: If the FINAL chunk has already been sent.
            ValueError: If ``frame`` is larger than ``MAX_FRAME_SIZE``.
        """
        if self.sent_final:
            raise RuntimeError("Cannot send after final chunk was sent.")
        if len(frame) > MAX_FRAME_SIZE:
            raise ValueError(f"Frame size {len(frame)} exceeds limit {MAX_FRAME_SIZE}.")

        plaintext = FRAME_LENGTH.pack(len(frame)) + frame

        # The length prefix makes plaintext non-empty, so at least one chunk
        # is always sent, even for an empty frame.
        for start in range(0, len(plaintext), CHUNK_DATA_SIZE):
            self._send_transport_chunk(
                plaintext[start : start + CHUNK_DATA_SIZE],
                tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
            )

        self.writer.flush()

    def recv_frame(self) -> bytes | None:
        """Receive the next frame.

        Returns:
            The frame body, or ``None`` when the peer's FINAL chunk has been
            seen and no buffered data remains.

        Raises:
            ProtocolError: If the announced frame size exceeds
                ``MAX_FRAME_SIZE``.
            EOFError: If the stream ends before a complete frame arrived.
        """
        if self._peek_plain(1) == b"":
            return None

        length_bytes = read_exact(self, FRAME_LENGTH.size)
        (frame_size,) = FRAME_LENGTH.unpack(length_bytes)

        if frame_size > MAX_FRAME_SIZE:
            raise ProtocolError(f"Frame size {frame_size} exceeds limit {MAX_FRAME_SIZE}.")

        return read_exact(self, frame_size)

    def run_receiving_loop(self, on_frame: FrameCallback) -> None:
        """Call ``on_frame`` for every received frame until the stream ends."""
        while True:
            frame = self.recv_frame()
            if frame is None:
                return
            on_frame(frame)

    def read(self, size: int) -> bytes:
        """Return and consume up to ``size`` decrypted payload bytes.

        Fewer bytes (possibly ``b""``) are returned only when the FINAL chunk
        has been seen and the buffer runs out.
        """
        data = self._peek_plain(size)
        del self.buffer[: len(data)]
        return data

    def close(self, send_final: bool = True) -> None:
        """Send the FINAL chunk (best effort) and release all resources.

        Args:
            send_final: When true and no FINAL chunk was sent yet, try to
                send one before closing. ``OSError``/``ValueError`` from that
                attempt are ignored.

        ``OSError``/``ValueError`` raised while closing the writer or reader
        are ignored as well. The remaining resources are closed even if an
        earlier step raises an unexpected exception.
        """
        try:
            if send_final and not self.sent_final:
                try:
                    self._send_transport_chunk(
                        b"",
                        tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL,
                    )
                    self.writer.flush()
                except (OSError, ValueError):
                    pass
                finally:
                    # Never retry: a second FINAL must not be attempted.
                    self.sent_final = True
        finally:
            # Nested finally blocks guarantee the socket is closed no matter
            # which of the earlier steps fails.
            try:
                self._close_quietly(self.writer)
            finally:
                try:
                    self._close_quietly(self.reader)
                finally:
                    self.sock.close()

    @staticmethod
    def _close_quietly(stream) -> None:
        """Close ``stream``, ignoring ``OSError`` and ``ValueError``."""
        try:
            stream.close()
        except (OSError, ValueError):
            pass

    def _peek_plain(self, size: int = 1) -> bytes:
        """Return up to ``size`` buffered payload bytes without consuming them.

        Transport chunks are read until ``size`` bytes are buffered or the
        FINAL chunk has been seen.
        """
        while len(self.buffer) < size and not self.seen_final:
            self._read_transport_chunk()

        return bytes(self.buffer[:size])

    def _read_transport_chunk(self) -> None:
        """Read and decrypt one transport chunk and buffer its payload.

        Raises:
            EOFError: If the reader ends before a whole chunk was read.
            ProtocolError: If the decrypted chunk has the wrong size or
                declares more payload than a chunk can hold.
            Exceptions raised by the secretstream pull call propagate
            unchanged.
        """
        ciphertext = read_exact(self.reader, CIPHERTEXT_CHUNK_SIZE)
        plaintext, tag = crypto_secretstream_xchacha20poly1305_pull(
            self.pull_state, ciphertext
        )

        if len(plaintext) != TRANSPORT_CHUNK_SIZE:
            raise ProtocolError(
                f"Expected {TRANSPORT_CHUNK_SIZE} plaintext bytes, got {len(plaintext)}."
            )

        (used,) = CHUNK_USED_BYTES.unpack_from(plaintext)
        if used > CHUNK_DATA_SIZE:
            raise ProtocolError(
                f"Chunk declared {used} payload bytes, limit is {CHUNK_DATA_SIZE}."
            )

        # Only the declared payload is kept; the rest of the chunk is padding.
        start = CHUNK_USED_BYTES.size
        self.buffer.extend(plaintext[start : start + used])

        if tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL:
            self.seen_final = True

    def _send_transport_chunk(self, payload: bytes, tag: int) -> None:
        """Pad, encrypt and write one transport chunk (without flushing).

        Args:
            payload: At most ``CHUNK_DATA_SIZE`` bytes of frame data.
            tag: Secretstream tag to attach to the chunk.

        Raises:
            ValueError: If ``payload`` is longer than ``CHUNK_DATA_SIZE``.
        """
        used = len(payload)
        if used > CHUNK_DATA_SIZE:
            raise ValueError(f"Chunk payload size {used} exceeds limit {CHUNK_DATA_SIZE}.")

        # Zero-filled buffer of the fixed chunk size; unused tail stays zero.
        plaintext = bytearray(TRANSPORT_CHUNK_SIZE)
        CHUNK_USED_BYTES.pack_into(plaintext, 0, used)

        data_start = CHUNK_USED_BYTES.size
        plaintext[data_start : data_start + used] = payload

        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            self.push_state,
            bytes(plaintext),
            tag=tag,
        )
        self.writer.write(ciphertext)
|