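"""Streaming chat-completion inference server.

Loads a causal LM from the local Hugging Face cache, accepts framed
requests over a SecretStreamSocket connection keyed by a shared secret,
and streams generated tokens back to clients piece by piece. Inference
runs on a dedicated worker thread so the asyncio event loop stays
responsive.
"""
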
from __future__ import annotations

import asyncio
import queue
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, NoReturn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from api import Request, Response
from config import Config, read_config
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket

MAX_NEW_TOKENS = 256

@dataclass
class MessagePiece:
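    """A single streamed chunk of a response, or an end/cancel marker."""
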
    piece: str = ""
    is_end: bool = False
    is_cancel: bool = False

@dataclass
class PendingChatCompletionRecord:
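    """State for one in-flight chat request.

    The cancellation flag is shared between the asyncio event loop and
    the worker thread, so reads and writes go through a lock.
    """
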
    response_queue: asyncio.Queue[MessagePiece]
    was_cancelled: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def mark_cancelled(self) -> None:
        with self._lock:
            self.was_cancelled = True

    def is_cancelled(self) -> bool:
        with self._lock:
            return self.was_cancelled

@dataclass
class WorkItem:
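    """A queued chat request handed from the event loop to the worker."""
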
    request_id: int
    messages: list
    record: PendingChatCompletionRecord

@dataclass
class ModelBundle:
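    """A tokenizer/model pair loaded from the local cache."""
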
    tokenizer: Any
    model: Any

def fail(message: str) -> NoReturn:
    """Print an error message and exit with a non-zero status."""
    print(message, flush=True)
    raise SystemExit(1)

def build_context(messages: list) -> list[dict[str, str]]:
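    """Normalize messages (attribute-style objects or dicts) into role/content dicts."""
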
    context: list[dict[str, str]] = []
    for message in messages:
        if hasattr(message, "role") and hasattr(message, "content"):
            role = message.role
            content = message.content
        else:
            role = message.get("role", "")
            content = message.get("content", "")
        context.append({"role": role, "content": content})
    return context

def load_local_model(model_id: str) -> ModelBundle:
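    """Load the tokenizer and model strictly from the local Hugging Face cache."""
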
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            local_files_only=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        fail(
            f"Model {model_id!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
    return ModelBundle(
        tokenizer=tokenizer,
        model=model,
    )

def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
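    """Greedily decode up to MAX_NEW_TOKENS tokens, yielding decoded text pieces.

    Runs a manual decoding loop: the first forward pass consumes the full
    prompt, then each step feeds back only the newly chosen token while
    reusing the KV cache (past_key_values). Stops early on an EOS token.
    """
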
    context = build_context(messages)
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    past_key_values = None

    # Collect every EOS token id; configs may hold a single id or a list of ids.
    eos_token_id = bundle.model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = bundle.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_ids = set()
    elif isinstance(eos_token_id, (list, tuple, set)):
        eos_token_ids = set(int(x) for x in eos_token_id)
    else:
        eos_token_ids = {int(eos_token_id)}

    for _ in range(MAX_NEW_TOKENS):
        with torch.inference_mode():
            outputs = bundle.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1)
        token_id = int(next_token_id.item())
        print("[debug] argmax token selected")
        if token_id in eos_token_ids:
            break
        token_text = bundle.tokenizer.decode([token_id], skip_special_tokens=True)
        if token_text:
            yield token_text
        # Feed only the new token next step; the KV cache covers the rest.
        past_key_values = outputs.past_key_values
        input_ids = next_token_id.unsqueeze(1)
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=1,
            )

def worker_loop(
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    loop: asyncio.AbstractEventLoop,
    model_bundle: ModelBundle,
) -> None:
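    """Consume work items on a dedicated thread until a None sentinel arrives.

    Pieces are handed back to the asyncio loop via call_soon_threadsafe so
    that queue operations happen on the loop thread.
    """
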
    while True:
        item = work_queue.get()
        if item is None:
            return
        record = item.record
        cancelled = False
        try:
            if record.is_cancelled():
                cancelled = True
            else:
                print("[debug] starting llm inference for a chat completion request")
                for piece in generate_llm_pieces(model_bundle, item.messages):
                    loop.call_soon_threadsafe(
                        record.response_queue.put_nowait,
                        MessagePiece(piece=piece, is_end=False, is_cancel=False),
                    )
                    if record.is_cancelled():
                        cancelled = True
                        break
            if cancelled:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=True),
                )
            else:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=False),
                )
        except Exception as exc:
            # Report failures as a cancel so the client is not left hanging.
            print(f"[worker] generation failed: {exc}", flush=True)
            loop.call_soon_threadsafe(
                record.response_queue.put_nowait,
                MessagePiece(is_end=True, is_cancel=True),
            )
        finally:
            with pending_lock:
                pending.pop(item.request_id, None)
            work_queue.task_done()

def build_response(request_id: int, piece: MessagePiece) -> bytes:
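    """Serialize a MessagePiece into a cancel, end, or chat response frame."""
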
    if piece.is_cancel:
        return Response.build(
            {
                "request_id": request_id,
                "kind": "cancel",
                "payload": {},
            }
        )
    if piece.is_end:
        return Response.build(
            {
                "request_id": request_id,
                "kind": "end",
                "payload": {},
            }
        )
    return Response.build(
        {
            "request_id": request_id,
            "kind": "chat",
            "payload": {"piece": piece.piece},
        }
    )

async def forward_pieces(
    request_id: int,
    response_queue: asyncio.Queue[MessagePiece],
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
) -> None:
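    """Drain a response queue, sending each piece to the client as a frame."""
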
    while True:
        piece = await response_queue.get()
        response_bytes = build_response(request_id, piece)
        async with send_lock:
            await transport.send_frame(response_bytes)
        if piece.is_cancel or piece.is_end:
            return

async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
    owned_request_ids: set[int],
) -> None:
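    """Register a chat request, enqueue it, and start streaming its pieces."""
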
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )
    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
        owned_request_ids.add(request_id)
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    response_tasks.add(task)
    task.add_done_callback(response_tasks.discard)

async def handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
) -> None:
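    """Serve one client connection, dispatching chat and cancel requests.

    On disconnect, cancels every request this connection owns and tears
    down its forwarding tasks before closing the transport.
    """
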
    transport = await wrap_connection_socket(reader, writer, shared_secret)
    send_lock = asyncio.Lock()
    response_tasks: set[asyncio.Task[None]] = set()
    owned_request_ids: set[int] = set()
    try:
        while True:
            frame = await transport.recv_frame()
            if frame is None:
                return
            request = Request.parse(frame)
            request_id = request.request_id
            if request.kind == "chat":
                await handle_chat_request(
                    request,
                    work_queue,
                    pending,
                    pending_lock,
                    transport,
                    send_lock,
                    response_tasks,
                    owned_request_ids,
                )
            elif request.kind == "cancel":
                print(f"[request] cancel request_id={request_id}", flush=True)
                with pending_lock:
                    record = pending.get(request_id)
                if record is not None:
                    record.mark_cancelled()
            else:
                raise ProtocolError(f"Unknown request kind {request.kind!r}")
    except (ProtocolError, EOFError, asyncio.IncompleteReadError):
        return
    finally:
        with pending_lock:
            for request_id in owned_request_ids:
                record = pending.get(request_id)
                if record is not None:
                    record.mark_cancelled()
        for task in response_tasks:
            task.cancel()
        if response_tasks:
            await asyncio.gather(*response_tasks, return_exceptions=True)
        await transport.close()

async def run_server(model_bundle: ModelBundle, config: Config) -> None:
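    """Start the inference worker thread and serve connections forever."""
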
    pending: Dict[int, PendingChatCompletionRecord] = {}
    pending_lock = threading.Lock()
    work_queue: queue.Queue[WorkItem | None] = queue.Queue()
    loop = asyncio.get_running_loop()
    worker = threading.Thread(
        target=worker_loop,
        args=(work_queue, pending, pending_lock, loop, model_bundle),
        daemon=True,
    )
    worker.start()

    async def client_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await handle_client(
            reader,
            writer,
            config.secret,
            work_queue,
            pending,
            pending_lock,
        )

    server = await asyncio.start_server(
        client_handler,
        config.listening_addr,
        config.listening_port,
    )
    addr = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
    print(f"[listening] {addr}")
    async with server:
        await server.serve_forever()

def main() -> None:
    config = read_config()
    model_bundle = load_local_model(config.model_id)
    print(f"[model] loaded {config.model_id}", flush=True)
    asyncio.run(run_server(model_bundle, config))

if __name__ == "__main__":
    main()