112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
import socket
|
|
from dataclasses import dataclass
|
|
from typing import BinaryIO
|
|
|
|
from nacl.bindings.crypto_secretstream import (
|
|
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
|
|
crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
|
|
crypto_secretstream_xchacha20poly1305_init_push,
|
|
crypto_secretstream_xchacha20poly1305_push,
|
|
crypto_secretstream_xchacha20poly1305_state,
|
|
)
|
|
|
|
from E import (
|
|
CHUNK_DATA_SIZE,
|
|
CHUNK_USED_BYTES,
|
|
MAX_MESSAGE_SIZE,
|
|
MESSAGE_LENGTH,
|
|
SERVER_HOST,
|
|
SERVER_PORT,
|
|
SHARED_SECRET,
|
|
TRANSPORT_CHUNK_SIZE,
|
|
IncomingChatBatch,
|
|
derive_key,
|
|
)
|
|
|
|
|
|
@dataclass()
class ClientConnection:
    """One open client connection with its encrypted outgoing stream.

    Groups the TCP socket, the binary writer used to send bytes over it, and
    the secretstream push state that every outgoing chunk is encrypted with.
    """

    # TCP socket connected to the server.
    sock: socket.socket
    # Binary file object used for writing (created via ``sock.makefile("wb")``
    # in ``connect_to_server``).
    writer: BinaryIO
    # secretstream XChaCha20-Poly1305 push state shared by all chunk pushes.
    state: crypto_secretstream_xchacha20poly1305_state
|
|
|
|
|
|
def connect_to_server(
    host: str = SERVER_HOST,
    port: int = SERVER_PORT,
    shared_secret: str = SHARED_SECRET,
) -> ClientConnection:
    """Connect to the server and start an encrypted secretstream.

    Opens a TCP connection and wraps it in a binary writer. It then derives the
    stream key from *shared_secret*, initialises a secretstream push state, and
    sends the resulting stream header before returning.

    Args:
        host: Server hostname or address.
        port: Server TCP port.
        shared_secret: Secret passed to ``derive_key`` to obtain the stream key.

    Returns:
        A ``ClientConnection`` holding the socket, its writer, and the push
        state to be used for all subsequent chunks.

    Raises:
        Whatever ``socket.create_connection``, ``derive_key``, the
        secretstream initialisation, or the header write raises. If the
        failure happens after the socket was opened, the writer and socket
        are closed before the exception propagates.
    """
    sock = socket.create_connection((host, port))
    writer = None
    try:
        writer = sock.makefile("wb")
        state = crypto_secretstream_xchacha20poly1305_state()
        key = derive_key(shared_secret)
        header = crypto_secretstream_xchacha20poly1305_init_push(state, key)
        writer.write(header)
        writer.flush()
    except BaseException:
        # Setup failed part-way: release the OS resources instead of leaking
        # them, then re-raise the original error. Cleanup is best-effort so a
        # secondary close failure does not mask the root cause.
        if writer is not None:
            try:
                writer.close()
            except OSError:
                pass
        sock.close()
        raise
    return ClientConnection(sock=sock, writer=writer, state=state)
|
|
|
|
|
|
def send_message(connection: ClientConnection, role: str, content: str) -> None:
    """Serialise one chat message, encrypt it, and write it to *connection*.

    The message is wrapped as a single-entry ``{"messages": [...]}`` batch and
    serialised with ``IncomingChatBatch.build``. The serialised bytes are
    prefixed with their length (``MESSAGE_LENGTH``) and cut into pieces of at
    most ``CHUNK_DATA_SIZE`` bytes. Each piece is framed as a
    ``CHUNK_USED_BYTES`` count header followed by the piece, zero-padded up to
    ``TRANSPORT_CHUNK_SIZE``. The frame is then pushed through the connection's
    secretstream state with ``TAG_MESSAGE``. The writer is flushed once, after
    the last chunk.

    Args:
        connection: Open connection whose push state and writer are used.
        role: Value stored under the message's ``"role"`` key.
        content: Value stored under the message's ``"content"`` key.

    Raises:
        ValueError: If the serialised batch is larger than ``MAX_MESSAGE_SIZE``.
    """
    entry = {"role": role, "content": content}
    payload = IncomingChatBatch.build({"messages": [entry]})

    payload_size = len(payload)
    if payload_size > MAX_MESSAGE_SIZE:
        raise ValueError(
            f"Application message size {payload_size} exceeds limit {MAX_MESSAGE_SIZE}."
        )

    # Length prefix + payload; presumably lets the receiver know how many
    # bytes to reassemble across chunks.
    framed = MESSAGE_LENGTH.pack(payload_size) + payload
    framed_size = len(framed)

    cursor = 0
    while cursor < framed_size:
        used = min(CHUNK_DATA_SIZE, framed_size - cursor)
        piece = framed[cursor : cursor + used]
        # Header (count of real data bytes) + data, zero-padded to the fixed
        # transport chunk size so every chunk has the same plaintext length.
        frame = (CHUNK_USED_BYTES.pack(used) + piece).ljust(TRANSPORT_CHUNK_SIZE, b"\x00")
        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            connection.state,
            frame,
            tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
        )
        connection.writer.write(ciphertext)
        cursor += used

    connection.writer.flush()
|
|
|
|
|
|
def close_connection(connection: ClientConnection) -> None:
    """Send the final encrypted chunk and release the connection's resources.

    Builds a transport chunk whose ``CHUNK_USED_BYTES`` header is 0 (no data,
    rest zero-filled) and pushes it with the secretstream ``TAG_FINAL`` tag. It
    then writes and flushes the ciphertext. The writer and the socket are
    closed in every case, including when encryption, writing, or flushing
    fails.

    Args:
        connection: The connection to finish and close.

    Raises:
        Any error from the final push, write, flush, or close propagates after
        the socket has been closed.
    """
    try:
        final_chunk = bytearray(TRANSPORT_CHUNK_SIZE)
        final_chunk[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(0)
        # The push is inside the try so a failure here still releases the
        # writer and socket below.
        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            connection.state,
            bytes(final_chunk),
            tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL,
        )
        connection.writer.write(ciphertext)
        connection.writer.flush()
    finally:
        # writer.close() may re-flush buffered data and raise; nest the
        # socket close so it always runs.
        try:
            connection.writer.close()
        finally:
            connection.sock.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: send one large multi-line user message, then close the stream.
    demo_text = "\n".join(f"What are you doing {i}" for i in range(10000))
    connection = connect_to_server()
    try:
        send_message(connection, "user", demo_text)
    finally:
        close_connection(connection)
|