# NOTE: removed extraction artifact (file-listing header: "383 lines / 11 KiB / Python").
from __future__ import annotations

import asyncio
import queue
import sys
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, NoReturn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from api import Request, Response
from config import Config, read_config
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket

# Hard cap on the number of tokens generated for a single chat completion.
MAX_NEW_TOKENS = 256
@dataclass
class MessagePiece:
    """One streamed unit of a chat completion response.

    A stream is a sequence of text pieces followed by exactly one terminal
    piece (``is_end=True``); the terminal piece carries ``is_cancel=True``
    when the request was cancelled instead of completing normally.
    """

    piece: str = ""  # decoded text fragment; empty for terminal pieces
    is_end: bool = False  # True on the final piece of a stream
    is_cancel: bool = False  # True when the stream ended due to cancellation
@dataclass
class PendingChatCompletionRecord:
    """Bookkeeping for one in-flight chat completion.

    Shared between the asyncio event loop (which consumes ``response_queue``)
    and the worker thread (which produces into it via
    ``loop.call_soon_threadsafe``), so the cancellation flag is guarded by a
    plain ``threading.Lock``.
    """

    # Consumed on the event loop; the worker thread never touches it directly.
    response_queue: asyncio.Queue[MessagePiece]
    # Set once a cancel request arrives or the owning connection goes away.
    was_cancelled: bool = False
    # Protects was_cancelled across the event-loop and worker threads.
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def mark_cancelled(self) -> None:
        """Flag this request as cancelled (thread-safe, idempotent)."""
        with self._lock:
            self.was_cancelled = True

    def is_cancelled(self) -> bool:
        """Return True if this request has been cancelled (thread-safe)."""
        with self._lock:
            return self.was_cancelled
@dataclass
class WorkItem:
    """A queued chat job handed from the event loop to the worker thread."""

    # Request id chosen by the client; duplicates are rejected per connection.
    request_id: int
    # Chat history; entries are dicts or objects with .role/.content.
    messages: list
    # Shared record used for streaming results back and observing cancellation.
    record: PendingChatCompletionRecord
@dataclass
class ModelBundle:
    """Tokenizer + model pair loaded from the local Hugging Face cache."""

    tokenizer: Any  # transformers tokenizer (from AutoTokenizer)
    model: Any  # transformers causal LM (from AutoModelForCausalLM)
def fail(message: str) -> NoReturn:
    """Print *message* and terminate the process with exit code 1.

    Used for unrecoverable startup errors (e.g. a missing model cache).

    Args:
        message: Human-readable description of the fatal condition.

    Raises:
        SystemExit: always, with code 1.
    """
    # Diagnostics belong on stderr so they are not mixed into piped stdout.
    print(message, file=sys.stderr, flush=True)
    raise SystemExit(1)
def build_context(messages: list) -> list[dict[str, str]]:
    """Normalize heterogeneous chat messages into plain role/content dicts.

    Each entry in *messages* may be an object exposing ``role``/``content``
    attributes or a mapping with ``"role"``/``"content"`` keys; missing
    mapping keys default to the empty string.
    """

    def _normalize(message) -> dict[str, str]:
        # Attribute-style message objects take precedence; anything else is
        # treated as a mapping.
        if hasattr(message, "role") and hasattr(message, "content"):
            return {"role": message.role, "content": message.content}
        return {
            "role": message.get("role", ""),
            "content": message.get("content", ""),
        }

    return [_normalize(message) for message in messages]
def load_local_model(model_id: str) -> ModelBundle:
    """Load *model_id* strictly from the local Hugging Face cache.

    ``local_files_only=True`` forbids any network download, so a model that
    was never prefetched fails fast; in that case the process exits via
    ``fail`` with an actionable message.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            local_files_only=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            local_files_only=True,
            # NOTE(review): `dtype=` assumes a recent transformers release;
            # older versions spell this `torch_dtype=` — confirm the pin.
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        # fail() raises SystemExit, so the return below is unreachable here.
        fail(
            f"Model {model_id!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )

    return ModelBundle(
        tokenizer=tokenizer,
        model=model,
    )
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Greedily generate up to MAX_NEW_TOKENS for *messages*, yielding text.

    Runs a manual decode loop — one forward pass per token, with a KV cache —
    so the caller can stream each piece out (and check for cancellation)
    between steps.  Only non-empty decoded fragments are yielded; generation
    stops at an EOS token or the MAX_NEW_TOKENS cap.

    Fix: removed the per-token ``print("[debug] get argmaxed")`` left over
    from debugging, which wrote one stdout line for every generated token.

    Args:
        bundle: Loaded tokenizer/model pair.
        messages: Chat history accepted by ``build_context``.

    Yields:
        Decoded text fragments, one per generated token.
    """
    context = build_context(messages)
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor onto the same device as the model weights.
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}

    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")
    past_key_values = None

    # Normalize the EOS id(s) into a set: configs may hold a single id, a
    # collection of ids, or nothing (fall back to the tokenizer's id).
    eos_token_id = bundle.model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = bundle.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_ids = set()
    elif isinstance(eos_token_id, (list, tuple, set)):
        eos_token_ids = set(int(x) for x in eos_token_id)
    else:
        eos_token_ids = {int(eos_token_id)}

    for _ in range(MAX_NEW_TOKENS):
        with torch.inference_mode():
            outputs = bundle.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )

        # Greedy decoding: argmax over the final position's logits.
        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1)
        token_id = int(next_token_id.item())

        if token_id in eos_token_ids:
            break

        token_text = bundle.tokenizer.decode([token_id], skip_special_tokens=True)
        if token_text:
            yield token_text

        # Feed only the new token back in; the KV cache carries the prefix.
        past_key_values = outputs.past_key_values
        input_ids = next_token_id.unsqueeze(1)
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=1,
            )
def worker_loop(
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    loop: asyncio.AbstractEventLoop,
    model_bundle: ModelBundle,
) -> None:
    """Single inference worker: drain *work_queue* and stream results.

    Runs on a dedicated thread.  All results cross back to the asyncio
    world via ``loop.call_soon_threadsafe`` + ``Queue.put_nowait`` — the
    worker never touches the asyncio queue directly.  A ``None`` work item
    is the shutdown sentinel.
    """
    while True:
        item = work_queue.get()
        if item is None:
            # Shutdown sentinel: exit the thread.
            return

        record = item.record
        cancelled = False
        try:
            if record.is_cancelled():
                # Cancelled while still queued: skip inference entirely.
                cancelled = True
            else:
                print("[debug] starting llm inference for a chat completion request")
                for piece in generate_llm_pieces(model_bundle, item.messages):
                    loop.call_soon_threadsafe(
                        record.response_queue.put_nowait,
                        MessagePiece(piece=piece, is_end=False, is_cancel=False),
                    )
                    # Poll for cancellation between tokens so a cancel takes
                    # effect mid-generation.
                    if record.is_cancelled():
                        cancelled = True
                        break

            # Exactly one terminal piece per request, flagged with how the
            # stream ended.
            if cancelled:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=True),
                )
            else:
                loop.call_soon_threadsafe(
                    record.response_queue.put_nowait,
                    MessagePiece(is_end=True, is_cancel=False),
                )
        except BaseException as exc:
            # Broad on purpose: any generation failure must still terminate
            # the client's stream (reported as a cancel) and keep the worker
            # thread alive for subsequent requests.
            print(f"[worker] generation failed: {exc}", flush=True)
            loop.call_soon_threadsafe(
                record.response_queue.put_nowait,
                MessagePiece(is_end=True, is_cancel=True),
            )
        finally:
            # Drop the request from the pending map whatever happened, and
            # balance the queue.get() for Queue.join() users.
            with pending_lock:
                pending.pop(item.request_id, None)
            work_queue.task_done()
def build_response(request_id: int, piece: MessagePiece) -> bytes:
    """Serialize *piece* into a wire-format response frame.

    Cancellation takes precedence over plain end-of-stream; ordinary text
    pieces become "chat" frames carrying the fragment in their payload.
    """
    if piece.is_cancel:
        kind, payload = "cancel", {}
    elif piece.is_end:
        kind, payload = "end", {}
    else:
        kind, payload = "chat", {"piece": piece.piece}

    return Response.build(
        {
            "request_id": request_id,
            "kind": kind,
            "payload": payload,
        }
    )
async def forward_pieces(
    request_id: int,
    response_queue: asyncio.Queue[MessagePiece],
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
) -> None:
    """Drain *response_queue*, sending each piece as a frame on *transport*.

    Returns after forwarding the first terminal piece (end or cancel).  The
    send lock serializes frames from concurrent streams sharing one socket.
    """
    finished = False
    while not finished:
        piece = await response_queue.get()
        finished = piece.is_cancel or piece.is_end
        frame = build_response(request_id, piece)
        async with send_lock:
            await transport.send_frame(frame)
async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
    owned_request_ids: set[int] | None = None,
) -> None:
    """Register a chat request, enqueue it for the worker, stream replies.

    Creates the pending record, hands the job to the inference worker, and
    spawns a ``forward_pieces`` task that relays streamed pieces back over
    *transport*.  When *owned_request_ids* is provided, the new id is
    recorded there so the connection handler can cancel it on disconnect.

    BUG FIX: ``owned_request_ids`` is now optional (default ``None``).
    The call site in ``handle_client`` omitted it, so every chat request
    previously died with a ``TypeError`` before any work was queued.

    Raises:
        ProtocolError: if *request.request_id* is already pending on this
            connection.
    """
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )

    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
        if owned_request_ids is not None:
            owned_request_ids.add(request_id)

    # Copy the messages so the worker thread never shares mutable state
    # with the request object.
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))

    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    response_tasks.add(task)
    # Discard-on-done keeps the task set from accumulating finished streams.
    task.add_done_callback(response_tasks.discard)
async def handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    shared_secret: str,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
) -> None:
    """Serve one encrypted client connection until EOF or a protocol error.

    Dispatches "chat" and "cancel" frames; on disconnect it marks every
    request this connection started as cancelled, tears down its forwarding
    tasks, and closes the transport.

    BUG FIX: ``owned_request_ids`` is now passed to ``handle_chat_request``
    (the call previously omitted it, raising ``TypeError`` on every chat
    request).  The callee records the id, so the redundant caller-side add
    was dropped.
    """
    transport = await wrap_connection_socket(reader, writer, shared_secret)
    send_lock = asyncio.Lock()
    response_tasks: set[asyncio.Task[None]] = set()
    owned_request_ids: set[int] = set()

    try:
        while True:
            frame = await transport.recv_frame()
            if frame is None:
                # Clean EOF from the peer.
                return

            request = Request.parse(frame)
            request_id = request.request_id

            if request.kind == "chat":
                await handle_chat_request(
                    request,
                    work_queue,
                    pending,
                    pending_lock,
                    transport,
                    send_lock,
                    response_tasks,
                    owned_request_ids,
                )
            elif request.kind == "cancel":
                print(f"[request] cancel request_id={request_id}", flush=True)
                with pending_lock:
                    record = pending.get(request_id)
                # Unknown ids are ignored: the request may already have
                # completed and been removed from `pending`.
                if record is not None:
                    record.mark_cancelled()
            else:
                raise ProtocolError(f"Unknown request kind {request.kind!r}")
    except (ProtocolError, EOFError, asyncio.IncompleteReadError):
        # Malformed peer or truncated stream: drop the connection quietly.
        return
    finally:
        # Stop the worker from generating for a client that is gone.
        with pending_lock:
            for request_id in owned_request_ids:
                record = pending.get(request_id)
                if record is not None:
                    record.mark_cancelled()

        for task in response_tasks:
            task.cancel()
        if response_tasks:
            await asyncio.gather(*response_tasks, return_exceptions=True)

        await transport.close()
async def run_server(model_bundle: ModelBundle, config: Config) -> None:
    """Start the inference worker thread and accept client connections.

    Shared state (the pending-request map, its lock, and the work queue)
    is created here and threaded through every connection handler and the
    worker.  Serves forever on the configured address/port.
    """
    pending: Dict[int, PendingChatCompletionRecord] = {}
    pending_lock = threading.Lock()
    work_queue: queue.Queue[WorkItem | None] = queue.Queue()
    event_loop = asyncio.get_running_loop()

    # Daemon thread so a hard exit is not blocked by in-flight generation.
    inference_thread = threading.Thread(
        target=worker_loop,
        args=(work_queue, pending, pending_lock, event_loop, model_bundle),
        daemon=True,
    )
    inference_thread.start()

    async def on_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        # Bind the shared state into each per-connection handler.
        await handle_client(
            reader,
            writer,
            config.secret,
            work_queue,
            pending,
            pending_lock,
        )

    server = await asyncio.start_server(
        on_client,
        config.listening_addr,
        config.listening_port,
    )
    bound = ", ".join(str(sock.getsockname()) for sock in server.sockets or [])
    print(f"[listening] {bound}")

    async with server:
        await server.serve_forever()
def main() -> None:
    """Entry point: read config, load the model, then serve forever."""
    config = read_config()
    bundle = load_local_model(config.model_id)
    print(f"[model] loaded {config.model_id}", flush=True)
    asyncio.run(run_server(bundle, config))
# Script entry point.
if __name__ == "__main__":
    main()