from __future__ import annotations import socket from dataclasses import dataclass from typing import BinaryIO from nacl.bindings.crypto_secretstream import ( crypto_secretstream_xchacha20poly1305_TAG_FINAL, crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, crypto_secretstream_xchacha20poly1305_init_push, crypto_secretstream_xchacha20poly1305_push, crypto_secretstream_xchacha20poly1305_state, ) from E import ( CHUNK_DATA_SIZE, CHUNK_USED_BYTES, MAX_MESSAGE_SIZE, MESSAGE_LENGTH, SERVER_HOST, SERVER_PORT, SHARED_SECRET, TRANSPORT_CHUNK_SIZE, IncomingChatBatch, derive_key, ) @dataclass class ClientConnection: sock: socket.socket writer: BinaryIO state: crypto_secretstream_xchacha20poly1305_state def connect_to_server( host: str = SERVER_HOST, port: int = SERVER_PORT, shared_secret: str = SHARED_SECRET, ) -> ClientConnection: sock = socket.create_connection((host, port)) writer = sock.makefile("wb") state = crypto_secretstream_xchacha20poly1305_state() header = crypto_secretstream_xchacha20poly1305_init_push(state, derive_key(shared_secret)) writer.write(header) writer.flush() return ClientConnection(sock=sock, writer=writer, state=state) def send_message(connection: ClientConnection, role: str, content: str) -> None: payload = IncomingChatBatch.build( { "messages": [ { "role": role, "content": content, } ] } ) if len(payload) > MAX_MESSAGE_SIZE: raise ValueError( f"Application message size {len(payload)} exceeds limit {MAX_MESSAGE_SIZE}." ) plaintext = MESSAGE_LENGTH.pack(len(payload)) + payload start = 0 while start < len(plaintext): used = min(CHUNK_DATA_SIZE, len(plaintext) - start) chunk = bytearray(TRANSPORT_CHUNK_SIZE) chunk[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(used) data_start = CHUNK_USED_BYTES.size data_end = data_start + used chunk[data_start:data_end] = plaintext[start : start + used] ciphertext = crypto_secretstream_xchacha20poly1305_push( connection.state, bytes(chunk), tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE, ) connection.writer.write(ciphertext) start += used connection.writer.flush() def close_connection(connection: ClientConnection) -> None: final_chunk = bytearray(TRANSPORT_CHUNK_SIZE) final_chunk[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(0) ciphertext = crypto_secretstream_xchacha20poly1305_push( connection.state, bytes(final_chunk), tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL, ) try: connection.writer.write(ciphertext) connection.writer.flush() finally: connection.writer.close() connection.sock.close() if __name__ == "__main__": connection = connect_to_server() try: send_message(connection, "user", "\n".join([f"What are you doing {i}" for i in range(10000)])) finally: close_connection(connection)