vibevibing/F.py

112 lines
3.2 KiB
Python

from __future__ import annotations
import socket
from dataclasses import dataclass
from typing import BinaryIO
from nacl.bindings.crypto_secretstream import (
crypto_secretstream_xchacha20poly1305_TAG_FINAL,
crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
crypto_secretstream_xchacha20poly1305_init_push,
crypto_secretstream_xchacha20poly1305_push,
crypto_secretstream_xchacha20poly1305_state,
)
from E import (
CHUNK_DATA_SIZE,
CHUNK_USED_BYTES,
MAX_MESSAGE_SIZE,
MESSAGE_LENGTH,
SERVER_HOST,
SERVER_PORT,
SHARED_SECRET,
TRANSPORT_CHUNK_SIZE,
IncomingChatBatch,
derive_key,
)
@dataclass
class ClientConnection:
    """Handles for one open, encrypted client connection to the chat server.

    Instances are produced by ``connect_to_server`` and passed to
    ``send_message`` and ``close_connection``.
    """

    # TCP socket returned by ``socket.create_connection``.
    sock: socket.socket
    # Binary writer from ``sock.makefile("wb")``; all ciphertext goes through it.
    writer: BinaryIO
    # Secretstream push state; handed to every push call, which updates it.
    state: crypto_secretstream_xchacha20poly1305_state
def connect_to_server(
    host: str = SERVER_HOST,
    port: int = SERVER_PORT,
    shared_secret: str = SHARED_SECRET,
) -> ClientConnection:
    """Open a TCP connection to the server and start an encrypted stream.

    Connects to ``(host, port)``, initialises a fresh XChaCha20-Poly1305
    secretstream push state keyed with ``derive_key(shared_secret)``, and
    writes and flushes the stream header before returning.

    Args:
        host: Server hostname or address.
        port: Server TCP port.
        shared_secret: Secret passed to ``derive_key`` to obtain the stream key.

    Returns:
        A ``ClientConnection`` holding the socket, a buffered binary writer
        over it, and the push state used for all later chunks.

    Raises:
        OSError: If the TCP connection cannot be established.
        Any exception raised while setting up the stream or sending the
        header also propagates. In that case the writer and socket are
        closed first, so no descriptor is leaked.
    """
    sock = socket.create_connection((host, port))
    writer = None
    try:
        writer = sock.makefile("wb")
        state = crypto_secretstream_xchacha20poly1305_state()
        header = crypto_secretstream_xchacha20poly1305_init_push(
            state, derive_key(shared_secret)
        )
        writer.write(header)
        writer.flush()
    except BaseException:
        # Release resources before propagating. Close the writer first: the
        # socket is only really closed once its makefile() objects are closed.
        if writer is not None:
            try:
                writer.close()
            except OSError:
                pass  # keep the original error as the one that propagates
        sock.close()
        raise
    return ClientConnection(sock=sock, writer=writer, state=state)
def send_message(connection: ClientConnection, role: str, content: str) -> None:
    """Encrypt and send one chat message over *connection*.

    Steps:

    1. Serialise the message as an ``IncomingChatBatch`` containing a single
       ``{"role": role, "content": content}`` entry.
    2. Prefix the payload with its length, packed with ``MESSAGE_LENGTH``.
    3. Split the result into ``TRANSPORT_CHUNK_SIZE``-byte transport chunks.
       Each chunk holds a ``CHUNK_USED_BYTES`` header giving the number of
       data bytes, then at most ``CHUNK_DATA_SIZE`` data bytes, then zero
       padding.
    4. Push every chunk through the secretstream with ``TAG_MESSAGE`` and
       write it to the connection's writer.
    5. Flush the writer once, after the last chunk.

    Raises:
        ValueError: If the serialised payload exceeds ``MAX_MESSAGE_SIZE``.
    """
    batch = {"messages": [{"role": role, "content": content}]}
    payload = IncomingChatBatch.build(batch)
    if len(payload) > MAX_MESSAGE_SIZE:
        raise ValueError(
            f"Application message size {len(payload)} exceeds limit {MAX_MESSAGE_SIZE}."
        )

    framed = MESSAGE_LENGTH.pack(len(payload)) + payload
    total = len(framed)
    header_size = CHUNK_USED_BYTES.size

    offset = 0
    while offset < total:
        remaining = total - offset
        used = CHUNK_DATA_SIZE if remaining > CHUNK_DATA_SIZE else remaining
        # bytearray(n) is zero-filled, which supplies the trailing padding.
        block = bytearray(TRANSPORT_CHUNK_SIZE)
        block[:header_size] = CHUNK_USED_BYTES.pack(used)
        block[header_size : header_size + used] = framed[offset : offset + used]
        encrypted = crypto_secretstream_xchacha20poly1305_push(
            connection.state,
            bytes(block),
            tag=crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
        )
        connection.writer.write(encrypted)
        offset += used

    connection.writer.flush()
def close_connection(connection: ClientConnection) -> None:
    """Terminate the encrypted stream and release the connection.

    Sends one last ``TRANSPORT_CHUNK_SIZE`` chunk whose ``CHUNK_USED_BYTES``
    header is 0 (it carries no data). The chunk is pushed with ``TAG_FINAL``
    to mark the end of the secretstream, then written and flushed. Finally
    the writer and the socket are closed.

    Both the writer and the socket are closed on every path, including when
    encrypting or sending the final chunk fails or when closing the writer
    itself raises. Any such error still propagates to the caller.
    """
    try:
        final_chunk = bytearray(TRANSPORT_CHUNK_SIZE)
        final_chunk[: CHUNK_USED_BYTES.size] = CHUNK_USED_BYTES.pack(0)
        ciphertext = crypto_secretstream_xchacha20poly1305_push(
            connection.state,
            bytes(final_chunk),
            tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL,
        )
        connection.writer.write(ciphertext)
        connection.writer.flush()
    finally:
        try:
            # Closing the buffered writer flushes any still-buffered bytes,
            # so it can raise again (e.g. after a failed flush above).
            connection.writer.close()
        finally:
            # Must run even if writer.close() raised; otherwise the socket
            # descriptor stays open until garbage collection.
            connection.sock.close()
if __name__ == "__main__":
    # Demo: send one large multi-line user message, and always terminate the
    # stream and close the connection afterwards, even if sending fails.
    demo_text = "\n".join(f"What are you doing {i}" for i in range(10000))
    connection = connect_to_server()
    try:
        send_message(connection, "user", demo_text)
    finally:
        close_connection(connection)