264 lines
7.4 KiB
Python

from __future__ import annotations

import asyncio
import os
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, Iterable

from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
# Address the server listens on. Both can be overridden through environment
# variables so a deployment does not have to edit the source.
SERVER_HOST = os.environ.get("SERVER_HOST", "127.0.0.1")
SERVER_PORT = int(os.environ.get("SERVER_PORT", "9000"))
# Secret handed to wrap_connection_socket for every client connection.
# The built-in value is only a placeholder: set SHARED_SECRET in the
# environment for any real deployment.
SHARED_SECRET = os.environ.get("SHARED_SECRET", "change-me")
# Simulated processing time the worker spends on each word.
PROCESS_DELAY_SECONDS = 0.4
@dataclass
class MessagePiece:
    """One item of a streamed chat response, sent from the worker thread
    to the connection's forwarding task.

    A regular item carries a text fragment in ``piece``. The last item of a
    stream has ``is_end`` set; it additionally has ``is_cancel`` set when the
    stream stopped because the request was cancelled.
    """

    piece: str = ""  # text fragment; empty on terminal markers
    is_end: bool = False  # set on the final item of a stream
    is_cancel: bool = False  # set when the stream was cut short by a cancel
@dataclass
class PendingChatCompletionRecord:
    """State for one in-flight chat request, shared across threads.

    The worker thread streams MessagePiece items into ``response_queue``
    (scheduled onto the event loop). The connection handler on the event
    loop sets ``was_cancelled``, and the worker polls it. ``_lock`` guards
    that flag because the two sides run on different threads.
    """

    response_queue: asyncio.Queue[MessagePiece]
    was_cancelled: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def mark_cancelled(self) -> None:
        """Flag the request as cancelled; calling it again is harmless."""
        with self._lock:
            self.was_cancelled = True

    def is_cancelled(self) -> bool:
        """Return whether the request has been flagged as cancelled."""
        with self._lock:
            flag = self.was_cancelled
        return flag
@dataclass
class WorkItem:
    """A chat request handed from a connection handler to the worker thread."""

    # Id taken from the client's request; the worker uses it to drop the
    # entry from the shared pending map when it finishes.
    request_id: int
    # Copy of the request's message list; the worker reads the last entry.
    messages: list
    # Shared per-request state: response queue and cancellation flag.
    record: PendingChatCompletionRecord
def extract_last_message(messages: list) -> str:
    """Return the ``content`` of the last message in *messages*.

    A message may be an object exposing a ``content`` attribute or a
    dict-like mapping with a ``"content"`` key.

    Returns:
        The last message's content, or ``""`` when *messages* is empty or the
        last message has no content (missing key or ``None``). The result is
        therefore always safe to treat as a string; the worker calls
        ``.split()`` on it.
    """
    if not messages:
        return ""
    last = messages[-1]
    if hasattr(last, "content"):
        content = last.content
    else:
        content = last.get("content")
    # A None content would otherwise violate the -> str contract and crash
    # the caller when it splits the text.
    return "" if content is None else content
def generate_uppercase_pieces(text: str, delay: float | None = None) -> Iterable[str]:
    """Yield each whitespace-separated word of *text*, upper-cased, plus a space.

    Simulates a slow generator by sleeping before every word.

    Args:
        text: Input text; it is split with ``str.split()``.
        delay: Seconds to sleep before yielding each word. ``None`` (the
            default) uses the module-level ``PROCESS_DELAY_SECONDS``.

    Yields:
        ``word.upper() + " "`` for each word, in order.
    """
    if delay is None:
        delay = PROCESS_DELAY_SECONDS
    for word in text.split():
        time.sleep(delay)
        yield word.upper() + " "
def _stream_item(item: WorkItem, loop: asyncio.AbstractEventLoop) -> bool:
    """Publish *item*'s response pieces to its record's queue via *loop*.

    Returns True if streaming stopped (or never started) because the record
    was cancelled, and False if every piece was published.
    """
    record = item.record
    if record.is_cancelled():
        return True
    text = extract_last_message(item.messages)
    for piece in generate_uppercase_pieces(text):
        # Check before publishing: producing a piece can take a while, and
        # once the client has cancelled nothing more should reach it.
        if record.is_cancelled():
            print("[debug] record was cancelled")
            return True
        # asyncio.Queue is not thread-safe; hand the put to the loop's thread.
        loop.call_soon_threadsafe(
            record.response_queue.put_nowait,
            MessagePiece(piece=piece, is_end=False, is_cancel=False),
        )
        print("[debug] got a new piece")
    return False


def worker_loop(
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Process queued chat requests on a background thread.

    Runs until a ``None`` sentinel is taken from *work_queue*. For each
    WorkItem it streams the response pieces into the record's queue on
    *loop*. It then always posts exactly one terminal MessagePiece
    (``is_end=True``). ``is_cancel=True`` is set if the request was cancelled
    before all pieces were published, or if processing failed. This
    guarantees the forwarding task on the event loop is never left waiting.
    Finally the item's entry is removed from *pending* under *pending_lock*,
    and ``task_done()`` is called for every item taken from the queue,
    including the sentinel.

    An exception while handling one item is logged and reported to that
    client as a cancellation rather than ending the thread. Ending the thread
    would stall every later request.
    """
    while True:
        item = work_queue.get()
        if item is None:
            work_queue.task_done()
            return
        record = item.record
        try:
            try:
                cancelled = _stream_item(item, loop)
            except Exception as exc:  # per-item boundary: keep the worker alive
                if loop.is_closed():
                    # The event loop is gone (server shut down); there is
                    # nobody left to notify.
                    return
                print(
                    f"[error] request_id={item.request_id} failed: {exc!r}",
                    flush=True,
                )
                cancelled = True
            loop.call_soon_threadsafe(
                record.response_queue.put_nowait,
                MessagePiece(is_end=True, is_cancel=cancelled),
            )
        finally:
            with pending_lock:
                pending.pop(item.request_id, None)
            work_queue.task_done()
def build_response(request_id: int, piece: MessagePiece) -> bytes:
    """Encode one MessagePiece as a response for *request_id* via Response.build.

    A cancel marker takes precedence over an end marker. Any other piece
    becomes a ``"chat"`` response whose payload carries the text fragment.
    """
    if piece.is_cancel:
        kind, payload = "cancel", {}
    elif piece.is_end:
        kind, payload = "end", {}
    else:
        kind, payload = "chat", {"piece": piece.piece}
    return Response.build({"request_id": request_id, "kind": kind, "payload": payload})
async def forward_pieces(
    request_id: int,
    response_queue: asyncio.Queue[MessagePiece],
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
) -> None:
    """Relay pieces from *response_queue* to the client as frames.

    Each piece is encoded with build_response and sent while holding
    *send_lock*, which serializes it with other senders sharing that lock.
    Returns right after sending the first piece whose ``is_end`` or
    ``is_cancel`` flag is set.
    """
    finished = False
    while not finished:
        message = await response_queue.get()
        frame = build_response(request_id, message)
        async with send_lock:
            await transport.send_frame(frame)
        finished = message.is_cancel or message.is_end
async def handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
) -> None:
    """Serve one client connection until it closes or breaks the protocol.

    After the secret-stream handshake, each incoming frame is parsed as a
    Request:

    * ``"chat"``: a WorkItem is queued for the worker thread, and a
      forward_pieces task relays that request's pieces back to the client.
      A request id that is already in flight on this connection is a
      ProtocolError.
    * ``"cancel"``: the matching in-flight request of this connection is
      flagged as cancelled. Unknown ids are ignored.
    * anything else: ProtocolError.

    Protocol errors, EOF and connection errors end the connection quietly.
    On exit, every in-flight request of this connection is flagged as
    cancelled, its forwarding task is cancelled and awaited, and the
    transport is closed.

    Args:
        reader, writer: Stream pair for the accepted socket.
        shared_secret: Secret for wrap_connection_socket.
        work_queue: Queue consumed by worker_loop.
        pending: Shared registry, keyed by request id, that the worker pops
            when an item finishes. It is shared by all connections, so it is
            not used for lookups here.
        pending_lock: Lock guarding *pending*.
    """
    try:
        transport = await wrap_connection_socket(reader, writer, shared_secret)
    except (ProtocolError, EOFError, asyncio.IncompleteReadError, ConnectionError) as exc:
        # No transport exists yet to close, so close the raw writer ourselves
        # to avoid leaking the socket.
        print(f"[handshake] failed: {exc!r}", flush=True)
        writer.close()
        return

    send_lock = asyncio.Lock()
    response_tasks: set[asyncio.Task[None]] = set()
    # In-flight requests of *this* connection. Request ids are scoped to the
    # connection, so duplicate detection, cancel and cleanup must never touch
    # another connection's record that happens to share an id.
    active: dict[int, PendingChatCompletionRecord] = {}

    def _forget(request_id: int, record: PendingChatCompletionRecord) -> None:
        # Free the id for reuse once its terminal frame has been forwarded
        # (identity check guards against removing a newer record).
        if active.get(request_id) is record:
            del active[request_id]

    try:
        while True:
            frame = await transport.recv_frame()
            if frame is None:
                return
            request = Request.parse(frame)
            request_id = request.request_id
            if request.kind == "chat":
                print(f"[request] chat request_id={request_id}", flush=True)
                if request_id in active:
                    raise ProtocolError(
                        f"Duplicate request_id {request_id} received on this connection"
                    )
                response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
                record = PendingChatCompletionRecord(response_queue=response_queue)
                active[request_id] = record
                with pending_lock:
                    pending[request_id] = record
                work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
                task = asyncio.create_task(
                    forward_pieces(request_id, response_queue, transport, send_lock)
                )
                response_tasks.add(task)
                task.add_done_callback(response_tasks.discard)
                task.add_done_callback(
                    lambda _task, rid=request_id, rec=record: _forget(rid, rec)
                )
            elif request.kind == "cancel":
                print(f"[request] cancel request_id={request_id}", flush=True)
                record = active.get(request_id)
                if record is not None:
                    record.mark_cancelled()
            else:
                raise ProtocolError(f"Unknown request kind {request.kind!r}")
    except (ProtocolError, EOFError, asyncio.IncompleteReadError, ConnectionError):
        return
    finally:
        # Stop the worker from producing output nobody will read.
        for record in active.values():
            record.mark_cancelled()
        for task in response_tasks:
            task.cancel()
        if response_tasks:
            await asyncio.gather(*response_tasks, return_exceptions=True)
        await transport.close()
async def run_server(
    host: str | None = None,
    port: int | None = None,
    shared_secret: str | None = None,
) -> None:
    """Start the worker thread and serve clients until the server is cancelled.

    Args:
        host: Address to bind; defaults to ``SERVER_HOST``.
        port: Port to bind; defaults to ``SERVER_PORT``.
        shared_secret: Connection secret; defaults to ``SHARED_SECRET``.

    A single daemon worker thread processes all chat requests. When serving
    stops (normally or by cancellation), the worker is sent its ``None``
    sentinel so it exits after its current item. It is not joined, so
    shutdown never blocks on a long-running generation.
    """
    host = SERVER_HOST if host is None else host
    port = SERVER_PORT if port is None else port
    shared_secret = SHARED_SECRET if shared_secret is None else shared_secret

    pending: Dict[int, PendingChatCompletionRecord] = {}
    pending_lock = threading.Lock()
    work_queue: queue.Queue[WorkItem | None] = queue.Queue()
    loop = asyncio.get_running_loop()
    worker = threading.Thread(
        target=worker_loop,
        args=(work_queue, pending, pending_lock, loop),
        daemon=True,
    )
    worker.start()

    async def client_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await handle_client(
            reader,
            writer,
            shared_secret,
            work_queue,
            pending,
            pending_lock,
        )

    try:
        server = await asyncio.start_server(client_handler, host, port)
        addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[listening] {addr}", flush=True)
        async with server:
            await server.serve_forever()
    finally:
        # Ask the worker to exit once it has drained the items queued ahead
        # of the sentinel.
        work_queue.put(None)
def main() -> None:
    """Run the chat server on a fresh asyncio event loop until it stops."""
    server = run_server()
    asyncio.run(server)


if __name__ == "__main__":
    main()