262 lines
7.3 KiB
Python
262 lines
7.3 KiB
Python
from __future__ import annotations

import asyncio
import os
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, Iterable

from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
|
|
|
|
|
|
# Address the server listens on and the secret passed to wrap_connection_socket()
# for every connection. Each may be overridden by an environment variable of the
# same name so a real secret never has to be committed; the defaults match the
# previous hard-coded values.
SERVER_HOST = os.environ.get("SERVER_HOST", "127.0.0.1")
SERVER_PORT = int(os.environ.get("SERVER_PORT", "9000"))
SHARED_SECRET = os.environ.get("SHARED_SECRET", "change-me")

# Seconds the worker sleeps before producing each word, to simulate slow streaming.
PROCESS_DELAY_SECONDS = 1.5
|
|
|
|
|
|
@dataclass
class MessagePiece:
    """One unit of streamed output passed from the worker thread to the event loop.

    Ordinary pieces carry a text fragment in ``piece``. The last piece of a
    response has ``is_end`` set; ``is_cancel`` marks a response that stopped early.
    """

    # Text fragment to send to the client (empty on terminal pieces).
    piece: str = field(default="")
    # True on the final piece of a response.
    is_end: bool = field(default=False)
    # True when the response was terminated early.
    is_cancel: bool = field(default=False)
|
|
|
|
|
|
@dataclass
class PendingChatCompletionRecord:
    """Per-request state shared by the connection handler and the worker thread.

    ``response_queue`` receives the pieces produced for the request. The
    cancellation flag is set from the event loop and polled from the worker
    thread, so every access to it goes through ``_lock``.
    """

    response_queue: asyncio.Queue[MessagePiece]
    # Set once the request should stop producing output.
    was_cancelled: bool = False
    # Guards ``was_cancelled`` across threads; hidden from repr.
    _lock: threading.Lock = field(repr=False, default_factory=threading.Lock)

    def mark_cancelled(self) -> None:
        """Set the cancellation flag (thread-safe)."""
        guard = self._lock
        with guard:
            self.was_cancelled = True

    def is_cancelled(self) -> bool:
        """Return the current cancellation flag (thread-safe)."""
        with self._lock:
            flag = self.was_cancelled
        return flag
|
|
|
|
|
|
@dataclass
class WorkItem:
    """A chat request handed to the worker thread through the work queue."""

    # Request id taken from the client's request; used to drop it from ``pending``.
    request_id: int
    # Copy of the request's message list; the worker reads the last entry.
    messages: list
    # Shared state used to stream output back and to observe cancellation.
    record: PendingChatCompletionRecord
|
|
|
|
|
|
def extract_last_message(messages: list) -> str:
    """Return the ``content`` of the last message in *messages*.

    A message may be an object exposing a ``content`` attribute or a dict-like
    mapping with a ``"content"`` key.

    Returns:
        The last message's content, or ``""`` when *messages* is empty/``None``,
        when a mapping has no ``"content"`` key, or when the content is ``None``.
    """
    if not messages:
        return ""
    last = messages[-1]
    if hasattr(last, "content"):
        content = last.content
    else:
        content = last.get("content", "")
    # A present-but-null content must not leak ``None`` to callers that treat
    # the result as text (e.g. call ``.split()`` on it).
    return "" if content is None else content
|
|
|
|
|
|
def generate_uppercase_pieces(text: str, delay: float | None = None) -> Iterable[str]:
    """Yield each whitespace-separated word of *text*, upper-cased, with a trailing space.

    The generator sleeps before every word to simulate slow, incremental output.

    Args:
        text: Input text; split on runs of whitespace.
        delay: Seconds to sleep before each word. ``None`` (the default) uses
            the module-level ``PROCESS_DELAY_SECONDS`` at iteration time, so
            existing callers behave as before; ``0`` disables the pause.
    """
    if delay is None:
        delay = PROCESS_DELAY_SECONDS
    for word in text.split():
        if delay > 0:
            time.sleep(delay)
        yield word.upper() + " "
|
|
|
|
|
|
def _emit_piece(
    loop: asyncio.AbstractEventLoop,
    record: PendingChatCompletionRecord,
    piece: MessagePiece,
) -> None:
    """Put *piece* on *record*'s asyncio queue from a thread other than *loop*'s."""
    # asyncio.Queue is not thread-safe, so the put is scheduled onto the loop.
    loop.call_soon_threadsafe(record.response_queue.put_nowait, piece)


def _process_work_item(item: WorkItem, loop: asyncio.AbstractEventLoop) -> None:
    """Stream the upper-cased last message of *item*, then exactly one terminal piece.

    The terminal piece has ``is_cancel=True`` when the request was cancelled or
    when processing raised, so the forwarding task on the loop always finishes.
    """
    record = item.record
    try:
        if not record.is_cancelled():
            text = extract_last_message(item.messages)
            for chunk in generate_uppercase_pieces(text):
                # Producing a word is slow; a cancel that arrived meanwhile
                # suppresses the now-stale piece.
                if record.is_cancelled():
                    break
                _emit_piece(loop, record, MessagePiece(piece=chunk, is_end=False, is_cancel=False))
        cancelled = record.is_cancelled()
    except Exception as exc:  # the single worker thread must survive a bad item
        print(f"[worker] request_id={item.request_id} failed: {exc!r}", flush=True)
        cancelled = True

    try:
        _emit_piece(loop, record, MessagePiece(is_end=True, is_cancel=cancelled))
    except RuntimeError:
        # call_soon_threadsafe raises once the event loop is closed (shutdown);
        # there is no consumer left to notify.
        pass


def worker_loop(
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Process queued chat requests one at a time until a ``None`` sentinel arrives.

    Intended to run on its own thread. Output for each WorkItem is delivered to
    its record's asyncio queue via *loop*. Whatever the outcome, the request is
    then removed from *pending* (under *pending_lock*) and marked done on
    *work_queue*.
    """
    while True:
        item = work_queue.get()
        if item is None:
            # Balance the get() so work_queue.join() can complete after shutdown.
            work_queue.task_done()
            return
        try:
            _process_work_item(item, loop)
        finally:
            with pending_lock:
                pending.pop(item.request_id, None)
            work_queue.task_done()
|
|
|
|
|
|
def build_response(request_id: int, piece: MessagePiece) -> bytes:
    """Encode *piece* as a ``Response`` frame for *request_id*.

    Cancel pieces map to kind ``"cancel"`` (taking precedence over ``is_end``),
    end pieces to ``"end"``, both with an empty payload; any other piece maps
    to ``"chat"`` with its text under ``payload["piece"]``.
    """
    if piece.is_cancel:
        kind, payload = "cancel", {}
    elif piece.is_end:
        kind, payload = "end", {}
    else:
        kind, payload = "chat", {"piece": piece.piece}

    message = {
        "request_id": request_id,
        "kind": kind,
        "payload": payload,
    }
    return Response.build(message)
|
|
|
|
|
|
async def forward_pieces(
    request_id: int,
    response_queue: asyncio.Queue[MessagePiece],
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
) -> None:
    """Relay one request's pieces to the client until a terminal piece has been sent.

    Each piece is encoded with build_response() and written while holding
    *send_lock*, which serializes sends with other responses sharing the same
    transport. Returns right after sending an end or cancel piece.
    """
    finished = False
    while not finished:
        next_piece = await response_queue.get()
        frame = build_response(request_id, next_piece)
        async with send_lock:
            await transport.send_frame(frame)
        finished = next_piece.is_cancel or next_piece.is_end
|
|
|
|
|
|
async def handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
) -> None:
    """Serve one client connection until it disconnects or violates the protocol.

    After the wrap_connection_socket() handshake, frames are read in a loop:

    * ``"chat"`` registers a PendingChatCompletionRecord in the shared *pending*
      table, queues a WorkItem for the worker and starts a task forwarding the
      streamed pieces back to this client.
    * ``"cancel"`` marks the request cancelled, but only if this connection
      started it.

    Handshake failures, protocol errors and peer disconnects end the connection
    quietly. On exit every outstanding request of this connection is marked
    cancelled, its forwarding tasks are cancelled and awaited, and the transport
    is closed.
    """
    disconnect_errors = (
        ProtocolError,
        EOFError,
        asyncio.IncompleteReadError,
        ConnectionError,
    )

    try:
        transport = await wrap_connection_socket(reader, writer, shared_secret)
    except disconnect_errors:
        # Nothing has been registered yet; just release the socket.
        writer.close()
        return

    send_lock = asyncio.Lock()
    response_tasks: set[asyncio.Task[None]] = set()
    # Records started by THIS connection. ``pending`` is shared by every
    # connection, so cancels and cleanup go through this map instead, never
    # touching another client's request that happens to use the same id.
    owned_records: dict[int, PendingChatCompletionRecord] = {}

    def release(request_id: int, record: PendingChatCompletionRecord) -> None:
        """Forget a request whose forwarding task has finished."""
        # If sending failed, stop the worker from producing output nobody will
        # read; if the response completed, the flag is simply no longer consulted.
        record.mark_cancelled()
        if owned_records.get(request_id) is record:
            del owned_records[request_id]

    try:
        while True:
            frame = await transport.recv_frame()
            if frame is None:
                return

            request = Request.parse(frame)
            request_id = request.request_id

            if request.kind == "chat":
                print(f"[request] chat request_id={request_id}", flush=True)
                # Read the payload before registering anything so a malformed
                # request cannot leave an orphaned entry in ``pending``.
                messages = list(request.payload.messages)
                response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
                record = PendingChatCompletionRecord(
                    response_queue=response_queue,
                )

                with pending_lock:
                    # NOTE(review): ``pending`` is keyed by request_id alone and
                    # shared across connections, so an id that is in flight on
                    # another connection is rejected here as well.
                    if request_id in pending:
                        raise ProtocolError(
                            f"Duplicate request_id {request_id} received on this connection"
                        )
                    pending[request_id] = record

                owned_records[request_id] = record
                work_queue.put(WorkItem(request_id, messages, record))

                task = asyncio.create_task(
                    forward_pieces(request_id, response_queue, transport, send_lock)
                )
                response_tasks.add(task)
                task.add_done_callback(response_tasks.discard)
                task.add_done_callback(
                    lambda _task, rid=request_id, rec=record: release(rid, rec)
                )
            elif request.kind == "cancel":
                print(f"[request] cancel request_id={request_id}", flush=True)
                record = owned_records.get(request_id)
                if record is not None:
                    record.mark_cancelled()
            else:
                raise ProtocolError(f"Unknown request kind {request.kind!r}")
    except disconnect_errors:
        return
    finally:
        # No await in this loop, so done-callbacks cannot mutate the dict mid-iteration.
        for record in owned_records.values():
            record.mark_cancelled()

        for task in response_tasks:
            task.cancel()
        if response_tasks:
            await asyncio.gather(*response_tasks, return_exceptions=True)

        await transport.close()
|
|
|
|
|
|
async def run_server() -> None:
    """Start the worker thread and serve clients on SERVER_HOST:SERVER_PORT until stopped.

    All connections share one work queue, one ``pending`` table and a single
    worker thread. When serving ends for any reason (including a failed bind or
    cancellation of ``serve_forever``), a ``None`` sentinel is queued so the
    worker exits once it reaches it.
    """
    pending: Dict[int, PendingChatCompletionRecord] = {}
    pending_lock = threading.Lock()
    work_queue: queue.Queue[WorkItem | None] = queue.Queue()

    loop = asyncio.get_running_loop()
    worker = threading.Thread(
        target=worker_loop,
        args=(work_queue, pending, pending_lock, loop),
        # Still a daemon: an item mid-way through its sleeps must not block
        # interpreter exit after the sentinel has been queued.
        daemon=True,
    )
    worker.start()

    async def client_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await handle_client(
            reader,
            writer,
            SHARED_SECRET,
            work_queue,
            pending,
            pending_lock,
        )

    try:
        server = await asyncio.start_server(client_handler, SERVER_HOST, SERVER_PORT)
        addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
        print(f"[listening] {addr}", flush=True)

        async with server:
            await server.serve_forever()
    finally:
        # Ask the worker thread to stop instead of abandoning it.
        work_queue.put(None)
|
|
|
|
|
|
def main() -> None:
    """Script entry point: drive run_server() to completion with asyncio.run()."""
    server_coroutine = run_server()
    asyncio.run(server_coroutine)


if __name__ == "__main__":  # only when executed as a script, not on import
    main()
|