# Chat server: streams generations from a locally cached GLM model to
# clients over a shared-secret stream socket.
from __future__ import annotations

import asyncio
import queue
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, NoReturn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
|
|
|
|
|
|
SERVER_HOST = "127.0.0.1"
|
|
SERVER_PORT = 9000
|
|
SHARED_SECRET = "change-me"
|
|
|
|
MODEL_ID = "zai-org/GLM-4.7-Flash"
|
|
MAX_NEW_TOKENS = 256
|
|
PIECE_CHUNK_SIZE = 64
|
|
|
|
|
|
@dataclass
|
|
class MessagePiece:
|
|
piece: str = ""
|
|
is_end: bool = False
|
|
is_cancel: bool = False
|
|
|
|
|
|
@dataclass
|
|
class PendingChatCompletionRecord:
|
|
response_queue: asyncio.Queue[MessagePiece]
|
|
was_cancelled: bool = False
|
|
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
|
|
|
|
|
|
|
|
|
def mark_cancelled(self) -> None:
|
|
with self._lock:
|
|
self.was_cancelled = True
|
|
|
|
def is_cancelled(self) -> bool:
|
|
with self._lock:
|
|
return self.was_cancelled
|
|
|
|
|
|
@dataclass
|
|
class WorkItem:
|
|
request_id: int
|
|
messages: list
|
|
record: PendingChatCompletionRecord
|
|
|
|
|
|
@dataclass
|
|
class ModelBundle:
|
|
tokenizer: Any
|
|
model: Any
|
|
|
|
|
|
def fail(message: str) -> None:
|
|
print(message, flush=True)
|
|
raise SystemExit(1)
|
|
|
|
|
|
def build_context(messages: list) -> list[dict[str, str]]:
|
|
context: list[dict[str, str]] = []
|
|
for message in messages:
|
|
if hasattr(message, "role") and hasattr(message, "content"):
|
|
role = message.role
|
|
content = message.content
|
|
else:
|
|
role = message.get("role", "")
|
|
content = message.get("content", "")
|
|
context.append({"role": role, "content": content})
|
|
return context
|
|
|
|
|
|
def load_local_model() -> ModelBundle:
|
|
try:
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
MODEL_ID,
|
|
local_files_only=True,
|
|
)
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
MODEL_ID,
|
|
local_files_only=True,
|
|
dtype=torch.bfloat16,
|
|
device_map="auto",
|
|
)
|
|
except OSError as exc:
|
|
fail(
|
|
f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
|
|
f"Run B2.py first. Original error: {exc}"
|
|
)
|
|
|
|
return ModelBundle(
|
|
tokenizer=tokenizer,
|
|
model=model,
|
|
)
|
|
|
|
|
|
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
|
|
context = build_context(messages)
|
|
inputs = bundle.tokenizer.apply_chat_template(
|
|
context,
|
|
tokenize=True,
|
|
add_generation_prompt=True,
|
|
return_dict=True,
|
|
return_tensors="pt",
|
|
)
|
|
inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
|
|
|
|
input_ids = inputs.get("input_ids")
|
|
input_len = int(input_ids.shape[1]) if input_ids is not None else 0
|
|
|
|
with torch.inference_mode():
|
|
output_ids = bundle.model.generate(
|
|
**inputs,
|
|
max_new_tokens=MAX_NEW_TOKENS,
|
|
do_sample=False,
|
|
)
|
|
|
|
generated_ids = output_ids[0][input_len:]
|
|
text = bundle.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
|
if not text:
|
|
return []
|
|
|
|
return [text[i:i + PIECE_CHUNK_SIZE] for i in range(0, len(text), PIECE_CHUNK_SIZE)]
|
|
|
|
|
|
def worker_loop(
|
|
work_queue: queue.Queue[WorkItem | None],
|
|
pending: Dict[int, PendingChatCompletionRecord],
|
|
pending_lock: threading.Lock,
|
|
loop: asyncio.AbstractEventLoop,
|
|
model_bundle: ModelBundle,
|
|
) -> None:
|
|
while True:
|
|
item = work_queue.get()
|
|
if item is None:
|
|
return
|
|
|
|
record = item.record
|
|
cancelled = False
|
|
try:
|
|
if record.is_cancelled():
|
|
cancelled = True
|
|
else:
|
|
for piece in generate_llm_pieces(model_bundle, item.messages):
|
|
loop.call_soon_threadsafe(
|
|
record.response_queue.put_nowait,
|
|
MessagePiece(piece=piece, is_end=False, is_cancel=False),
|
|
)
|
|
if record.is_cancelled():
|
|
cancelled = True
|
|
break
|
|
|
|
if cancelled:
|
|
loop.call_soon_threadsafe(
|
|
record.response_queue.put_nowait,
|
|
MessagePiece(is_end=True, is_cancel=True),
|
|
)
|
|
else:
|
|
loop.call_soon_threadsafe(
|
|
record.response_queue.put_nowait,
|
|
MessagePiece(is_end=True, is_cancel=False),
|
|
)
|
|
except BaseException as exc:
|
|
print(f"[worker] generation failed: {exc}", flush=True)
|
|
loop.call_soon_threadsafe(
|
|
record.response_queue.put_nowait,
|
|
MessagePiece(is_end=True, is_cancel=True),
|
|
)
|
|
finally:
|
|
with pending_lock:
|
|
pending.pop(item.request_id, None)
|
|
work_queue.task_done()
|
|
|
|
|
|
def build_response(request_id: int, piece: MessagePiece) -> bytes:
|
|
if piece.is_cancel:
|
|
return Response.build(
|
|
{
|
|
"request_id": request_id,
|
|
"kind": "cancel",
|
|
"payload": {},
|
|
}
|
|
)
|
|
if piece.is_end:
|
|
return Response.build(
|
|
{
|
|
"request_id": request_id,
|
|
"kind": "end",
|
|
"payload": {},
|
|
}
|
|
)
|
|
|
|
return Response.build(
|
|
{
|
|
"request_id": request_id,
|
|
"kind": "chat",
|
|
"payload": {"piece": piece.piece},
|
|
}
|
|
)
|
|
|
|
|
|
async def forward_pieces(
|
|
request_id: int,
|
|
response_queue: asyncio.Queue[MessagePiece],
|
|
transport: SecretStreamSocket,
|
|
send_lock: asyncio.Lock,
|
|
) -> None:
|
|
while True:
|
|
piece = await response_queue.get()
|
|
response_bytes = build_response(request_id, piece)
|
|
async with send_lock:
|
|
await transport.send_frame(response_bytes)
|
|
|
|
if piece.is_cancel or piece.is_end:
|
|
return
|
|
|
|
|
|
async def handle_chat_request(
|
|
request: Request,
|
|
work_queue: queue.Queue[WorkItem | None],
|
|
pending: Dict[int, PendingChatCompletionRecord],
|
|
pending_lock: threading.Lock,
|
|
transport: SecretStreamSocket,
|
|
send_lock: asyncio.Lock,
|
|
response_tasks: set[asyncio.Task[None]],
|
|
) -> None:
|
|
request_id = request.request_id
|
|
print(f"[request] chat request_id={request_id}", flush=True)
|
|
response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
|
|
record = PendingChatCompletionRecord(
|
|
response_queue=response_queue,
|
|
)
|
|
|
|
with pending_lock:
|
|
if request_id in pending:
|
|
raise ProtocolError(
|
|
f"Duplicate request_id {request_id} received on this connection"
|
|
)
|
|
pending[request_id] = record
|
|
|
|
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
|
|
|
|
task = asyncio.create_task(
|
|
forward_pieces(request_id, response_queue, transport, send_lock)
|
|
)
|
|
response_tasks.add(task)
|
|
task.add_done_callback(response_tasks.discard)
|
|
|
|
|
|
async def handle_client(
|
|
reader: asyncio.StreamReader,
|
|
writer: asyncio.StreamWriter,
|
|
shared_secret: str,
|
|
work_queue: queue.Queue[WorkItem | None],
|
|
pending: Dict[int, PendingChatCompletionRecord],
|
|
pending_lock: threading.Lock,
|
|
) -> None:
|
|
transport = await wrap_connection_socket(reader, writer, shared_secret)
|
|
send_lock = asyncio.Lock()
|
|
response_tasks: set[asyncio.Task[None]] = set()
|
|
owned_request_ids: set[int] = set()
|
|
|
|
try:
|
|
while True:
|
|
frame = await transport.recv_frame()
|
|
if frame is None:
|
|
return
|
|
|
|
request = Request.parse(frame)
|
|
request_id = request.request_id
|
|
|
|
if request.kind == "chat":
|
|
await handle_chat_request(
|
|
request,
|
|
work_queue,
|
|
pending,
|
|
pending_lock,
|
|
transport,
|
|
send_lock,
|
|
response_tasks,
|
|
)
|
|
owned_request_ids.add(request_id)
|
|
elif request.kind == "cancel":
|
|
print(f"[request] cancel request_id={request_id}", flush=True)
|
|
with pending_lock:
|
|
record = pending.get(request_id)
|
|
if record is not None:
|
|
record.mark_cancelled()
|
|
else:
|
|
raise ProtocolError(f"Unknown request kind {request.kind!r}")
|
|
except (ProtocolError, EOFError, asyncio.IncompleteReadError):
|
|
return
|
|
finally:
|
|
with pending_lock:
|
|
for request_id in owned_request_ids:
|
|
record = pending.get(request_id)
|
|
if record is not None:
|
|
record.mark_cancelled()
|
|
|
|
for task in response_tasks:
|
|
task.cancel()
|
|
if response_tasks:
|
|
await asyncio.gather(*response_tasks, return_exceptions=True)
|
|
|
|
await transport.close()
|
|
|
|
|
|
async def run_server(model_bundle: ModelBundle) -> None:
|
|
pending: Dict[int, PendingChatCompletionRecord] = {}
|
|
pending_lock = threading.Lock()
|
|
work_queue: queue.Queue[WorkItem | None] = queue.Queue()
|
|
|
|
loop = asyncio.get_running_loop()
|
|
worker = threading.Thread(
|
|
target=worker_loop,
|
|
args=(work_queue, pending, pending_lock, loop, model_bundle),
|
|
daemon=True,
|
|
)
|
|
worker.start()
|
|
|
|
async def client_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
|
await handle_client(
|
|
reader,
|
|
writer,
|
|
SHARED_SECRET,
|
|
work_queue,
|
|
pending,
|
|
pending_lock,
|
|
)
|
|
|
|
server = await asyncio.start_server(client_handler, SERVER_HOST, SERVER_PORT)
|
|
addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
|
|
print(f"[listening] {addr}")
|
|
|
|
async with server:
|
|
await server.serve_forever()
|
|
|
|
|
|
def main() -> None:
|
|
model_bundle = load_local_model()
|
|
print("[model] loaded", flush=True)
|
|
asyncio.run(run_server(model_bundle))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|