"""Encrypted chat-completion server.

An asyncio front end accepts framed requests over SecretStreamSocket, and a
single worker thread runs greedy decoding with a locally cached transformers
model, streaming pieces back per request.
"""

from __future__ import annotations

import asyncio
import queue
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, NoReturn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from api import Request, Response
from config import Config, read_config
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket

# Hard cap on the number of tokens generated for a single request.
MAX_NEW_TOKENS = 256


@dataclass
class MessagePiece:
    """One streamed unit: a text piece, or an end/cancel marker."""

    piece: str = ""
    is_end: bool = False
    is_cancel: bool = False


@dataclass
class PendingChatCompletionRecord:
    """Shared between the asyncio loop and the worker thread, hence the lock."""

    response_queue: asyncio.Queue[MessagePiece]
    was_cancelled: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def mark_cancelled(self) -> None:
        with self._lock:
            self.was_cancelled = True

    def is_cancelled(self) -> bool:
        with self._lock:
            return self.was_cancelled


@dataclass
class WorkItem:
    request_id: int
    messages: list
    record: PendingChatCompletionRecord


@dataclass
class ModelBundle:
    tokenizer: Any
    model: Any


def fail(message: str) -> NoReturn:
    print(message, flush=True)
    raise SystemExit(1)


def build_context(messages: list) -> list[dict[str, str]]:
    """Normalize attribute-style or dict-style messages to plain role/content dicts."""
    context: list[dict[str, str]] = []
    for message in messages:
        if hasattr(message, "role") and hasattr(message, "content"):
            role = message.role
            content = message.content
        else:
            role = message.get("role", "")
            content = message.get("content", "")
        context.append({"role": role, "content": content})
    return context

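
# A minimal usage sketch for build_context. The _Msg dataclass below is
# hypothetical (real messages arrive via api.Request.payload); it only shows
# that attribute-style and dict-style messages normalize to the same shape:
#
#     @dataclass
#     class _Msg:
#         role: str
#         content: str
#
#     build_context([_Msg("system", "Be terse."), {"role": "user", "content": "hi"}])
#     # -> [{"role": "system", "content": "Be terse."},
#     #     {"role": "user", "content": "hi"}]
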

def load_local_model(model_id: str) -> ModelBundle:
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            local_files_only=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        fail(
            f"Model {model_id!r} is not fully available in the Hugging Face cache. "
            "Run B2.py first. "
            f"Original error: {exc}"
        )
    return ModelBundle(
        tokenizer=tokenizer,
        model=model,
    )


def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Greedy-decode up to MAX_NEW_TOKENS tokens, yielding decoded text pieces.

    Runs the model one token at a time with a KV cache, so each step only
    feeds the newly generated token.
    """
    context = build_context(messages)
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    past_key_values = None

    # Collect every EOS id the model or tokenizer declares (may be one or many).
    eos_token_id = bundle.model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = bundle.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_ids: set[int] = set()
    elif isinstance(eos_token_id, (list, tuple, set)):
        eos_token_ids = {int(x) for x in eos_token_id}
    else:
        eos_token_ids = {int(eos_token_id)}

    for _ in range(MAX_NEW_TOKENS):
        with torch.inference_mode():
            outputs = bundle.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1)
        token_id = int(next_token_id.item())
        print(f"[debug] argmax token id={token_id}")
        if token_id in eos_token_ids:
            break
        token_text = bundle.tokenizer.decode([token_id], skip_special_tokens=True)
        if token_text:
            yield token_text
        # Feed only the new token next step; the KV cache holds the rest.
        past_key_values = outputs.past_key_values
        input_ids = next_token_id.unsqueeze(1)
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=1,
            )


def worker_loop(
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    loop: asyncio.AbstractEventLoop,
    model_bundle: ModelBundle,
) -> None:
    """Run in a plain thread; hand results to the event loop via call_soon_threadsafe."""
    while True:
        item = work_queue.get()
        if item is None:
            return
        record = item.record
        cancelled = False
        try:
            if record.is_cancelled():
                cancelled = True
            else:
                print("[debug] starting llm inference for a chat completion request")
                for piece in generate_llm_pieces(model_bundle, item.messages):
                    loop.call_soon_threadsafe(
                        record.response_queue.put_nowait,
                        MessagePiece(piece=piece, is_end=False, is_cancel=False),
                    )
                    if record.is_cancelled():
                        cancelled = True
                        break
            if cancelled:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=True),
                )
            else:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=False),
                )
        except BaseException as exc:
            # Report generation failures to the client as a cancel frame.
            print(f"[worker] generation failed: {exc}", flush=True)
            loop.call_soon_threadsafe(
                record.response_queue.put_nowait,
                MessagePiece(is_end=True, is_cancel=True),
            )
        finally:
            with pending_lock:
                pending.pop(item.request_id, None)
            work_queue.task_done()


def build_response(request_id: int, piece: MessagePiece) -> bytes:
    if piece.is_cancel:
        return Response.build(
            {
                "request_id": request_id,
                "kind": "cancel",
                "payload": {},
            }
        )
    if piece.is_end:
        return Response.build(
            {
                "request_id": request_id,
                "kind": "end",
                "payload": {},
            }
        )
    return Response.build(
        {
            "request_id": request_id,
            "kind": "chat",
            "payload": {"piece": piece.piece},
        }
    )


async def forward_pieces(
    request_id: int,
    response_queue: asyncio.Queue[MessagePiece],
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
) -> None:
    """Drain one request's queue onto the shared transport until a terminal piece."""
    while True:
        piece = await response_queue.get()
        response_bytes = build_response(request_id, piece)
        async with send_lock:
            await transport.send_frame(response_bytes)
        if piece.is_cancel or piece.is_end:
            return

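
# Per-request wire contract, as implied by build_response and forward_pieces
# above (the byte layout itself is owned by api.Response and is not assumed
# here): a client sees zero or more "chat" frames, each carrying one decoded
# piece, then exactly one terminal frame, "end" on normal completion or
# "cancel" after a client-side cancel or a worker failure. One stream might
# look like:
#
#     {"request_id": 7, "kind": "chat", "payload": {"piece": "Hel"}}
#     {"request_id": 7, "kind": "chat", "payload": {"piece": "lo"}}
#     {"request_id": 7, "kind": "end", "payload": {}}
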

async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
    owned_request_ids: set[int],
) -> None:
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )
    # Register the record and connection ownership atomically, so a duplicate
    # id is rejected before any work is queued.
    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
        owned_request_ids.add(request_id)
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    response_tasks.add(task)
    task.add_done_callback(response_tasks.discard)


async def handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
) -> None:
    transport = await wrap_connection_socket(reader, writer, shared_secret)
    send_lock = asyncio.Lock()
    response_tasks: set[asyncio.Task[None]] = set()
    owned_request_ids: set[int] = set()
    try:
        while True:
            frame = await transport.recv_frame()
            if frame is None:
                return
            request = Request.parse(frame)
            request_id = request.request_id
            if request.kind == "chat":
                # handle_chat_request records ownership itself, under the lock.
                await handle_chat_request(
                    request,
                    work_queue,
                    pending,
                    pending_lock,
                    transport,
                    send_lock,
                    response_tasks,
                    owned_request_ids,
                )
            elif request.kind == "cancel":
                print(f"[request] cancel request_id={request_id}", flush=True)
                with pending_lock:
                    record = pending.get(request_id)
                if record is not None:
                    record.mark_cancelled()
            else:
                raise ProtocolError(f"Unknown request kind {request.kind!r}")
    except (ProtocolError, EOFError, asyncio.IncompleteReadError):
        return
    finally:
        # Cancel anything this connection still owns, then tear down the
        # forwarding tasks and the transport.
        with pending_lock:
            for request_id in owned_request_ids:
                record = pending.get(request_id)
                if record is not None:
                    record.mark_cancelled()
        for task in response_tasks:
            task.cancel()
        if response_tasks:
            await asyncio.gather(*response_tasks, return_exceptions=True)
        await transport.close()


async def run_server(model_bundle: ModelBundle, config: Config) -> None:
    pending: Dict[int, PendingChatCompletionRecord] = {}
    pending_lock = threading.Lock()
    work_queue: queue.Queue[WorkItem | None] = queue.Queue()
    loop = asyncio.get_running_loop()
    worker = threading.Thread(
        target=worker_loop,
        args=(work_queue, pending, pending_lock, loop, model_bundle),
        daemon=True,
    )
    worker.start()

    async def client_handler(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        await handle_client(
            reader,
            writer,
            config.secret,
            work_queue,
            pending,
            pending_lock,
        )

    server = await asyncio.start_server(
        client_handler,
        config.listening_addr,
        config.listening_port,
    )
    addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
    print(f"[listening] {addr}", flush=True)
    async with server:
        await server.serve_forever()


def main() -> None:
    config = read_config()
    model_bundle = load_local_model(config.model_id)
    print(f"[model] loaded {config.model_id}", flush=True)
    asyncio.run(run_server(model_bundle, config))


if __name__ == "__main__":
    main()
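
# Operational sketch (assuming config.py exposes the fields referenced above:
# model_id, secret, listening_addr, listening_port):
#
#   1. Run B2.py once so the weights for config.model_id are present in the
#      local Hugging Face cache; this server loads with local_files_only=True
#      and exits if anything is missing.
#   2. Start this module directly. It binds (listening_addr, listening_port)
#      and serves framed requests over the SecretStreamSocket handshake keyed
#      by config.secret; "chat" requests stream pieces back, and "cancel"
#      requests stop an in-flight generation at the next token boundary.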