from __future__ import annotations import asyncio import queue import threading import time from dataclasses import dataclass, field from typing import Dict, Iterable from api import Request, Response from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket SERVER_HOST = "127.0.0.1" SERVER_PORT = 9000 SHARED_SECRET = "change-me" PROCESS_DELAY_SECONDS = 1.5 @dataclass class MessagePiece: piece: str = "" is_end: bool = False is_cancel: bool = False @dataclass class PendingChatCompletionRecord: response_queue: asyncio.Queue[MessagePiece] was_cancelled: bool = False _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) def mark_cancelled(self) -> None: with self._lock: self.was_cancelled = True def is_cancelled(self) -> bool: with self._lock: return self.was_cancelled @dataclass class WorkItem: request_id: int messages: list record: PendingChatCompletionRecord def extract_last_message(messages: list) -> str: if not messages: return "" last = messages[-1] if hasattr(last, "content"): return last.content return last.get("content", "") def generate_uppercase_pieces(text: str) -> Iterable[str]: words = text.split() for word in words: time.sleep(PROCESS_DELAY_SECONDS) yield word.upper() + " " def worker_loop( work_queue: queue.Queue[WorkItem | None], pending: Dict[int, PendingChatCompletionRecord], pending_lock: threading.Lock, loop: asyncio.AbstractEventLoop, ) -> None: while True: item = work_queue.get() if item is None: return record = item.record cancelled = False try: if record.is_cancelled(): cancelled = True else: text = extract_last_message(item.messages) for piece in generate_uppercase_pieces(text): loop.call_soon_threadsafe( record.response_queue.put_nowait, MessagePiece(piece=piece, is_end=False, is_cancel=False), ) if record.is_cancelled(): cancelled = True break if cancelled: loop.call_soon_threadsafe( record.response_queue.put_nowait, MessagePiece(is_end=True, is_cancel=True), ) else: loop.call_soon_threadsafe( record.response_queue.put_nowait, MessagePiece(is_end=True, is_cancel=False), ) finally: with pending_lock: pending.pop(item.request_id, None) work_queue.task_done() def build_response(request_id: int, piece: MessagePiece) -> bytes: if piece.is_cancel: return Response.build( { "request_id": request_id, "kind": "cancel", "payload": {}, } ) if piece.is_end: return Response.build( { "request_id": request_id, "kind": "end", "payload": {}, } ) return Response.build( { "request_id": request_id, "kind": "chat", "payload": {"piece": piece.piece}, } ) async def forward_pieces( request_id: int, response_queue: asyncio.Queue[MessagePiece], transport: SecretStreamSocket, send_lock: asyncio.Lock, ) -> None: while True: piece = await response_queue.get() response_bytes = build_response(request_id, piece) async with send_lock: await transport.send_frame(response_bytes) if piece.is_cancel or piece.is_end: return async def handle_client( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, shared_secret: str, work_queue: queue.Queue[WorkItem | None], pending: Dict[int, PendingChatCompletionRecord], pending_lock: threading.Lock, ) -> None: transport = await wrap_connection_socket(reader, writer, shared_secret) send_lock = asyncio.Lock() response_tasks: set[asyncio.Task[None]] = set() owned_request_ids: set[int] = set() try: while True: frame = await transport.recv_frame() if frame is None: return request = Request.parse(frame) request_id = request.request_id if request.kind == "chat": print(f"[request] chat request_id={request_id}", flush=True) response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue() record = PendingChatCompletionRecord( response_queue=response_queue, ) with pending_lock: if request_id in pending: raise ProtocolError( f"Duplicate request_id {request_id} received on this connection" ) pending[request_id] = record owned_request_ids.add(request_id) work_queue.put(WorkItem(request_id, list(request.payload.messages), record)) task = asyncio.create_task( forward_pieces(request_id, response_queue, transport, send_lock) ) response_tasks.add(task) task.add_done_callback(response_tasks.discard) elif request.kind == "cancel": print(f"[request] cancel request_id={request_id}", flush=True) with pending_lock: record = pending.get(request_id) if record is not None: record.mark_cancelled() else: raise ProtocolError(f"Unknown request kind {request.kind!r}") except (ProtocolError, EOFError, asyncio.IncompleteReadError): return finally: with pending_lock: for request_id in owned_request_ids: record = pending.get(request_id) if record is not None: record.mark_cancelled() for task in response_tasks: task.cancel() if response_tasks: await asyncio.gather(*response_tasks, return_exceptions=True) await transport.close() async def run_server() -> None: pending: Dict[int, PendingChatCompletionRecord] = {} pending_lock = threading.Lock() work_queue: queue.Queue[WorkItem | None] = queue.Queue() loop = asyncio.get_running_loop() worker = threading.Thread( target=worker_loop, args=(work_queue, pending, pending_lock, loop), daemon=True, ) worker.start() async def client_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: await handle_client( reader, writer, SHARED_SECRET, work_queue, pending, pending_lock, ) server = await asyncio.start_server(client_handler, SERVER_HOST, SERVER_PORT) addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or []) print(f"[listening] {addr}") async with server: await server.serve_forever() def main() -> None: asyncio.run(run_server()) if __name__ == "__main__": main()