109 lines
2.7 KiB
Python

from __future__ import annotations

import queue
import sys
from threading import Thread
from typing import Any

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
MODEL_ID = "zai-org/GLM-4.7-Flash"
MAX_NEW_TOKENS = 256
SYSTEM_PROMPT = (
"You are a helpful AI assistant. Answer clearly and keep the response concise."
)
USER_PROMPT = "Briefly introduce yourself and say what model you are."
def fail(message: str) -> None:
print(message, file=sys.stderr, flush=True)
raise SystemExit(1)
def build_context() -> list[dict[str, str]]:
return [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT},
]
def load_local_model() -> tuple[Any, Any, Any, Any]:
try:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
local_files_only=True,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
local_files_only=True,
dtype=torch.bfloat16,
device_map="auto",
)
except OSError as exc:
fail(
f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
f"Run B2.py first. Original error: {exc}"
)
return torch, tokenizer, model, TextIteratorStreamer
def main() -> None:
torch, tokenizer, model, text_iterator_streamer_cls = load_local_model()
context = build_context()
print("[Model loaded]", flush=True)
inputs = tokenizer.apply_chat_template(
context,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
streamer = text_iterator_streamer_cls(
tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
generation_exception: list[BaseException] = []
def run_generation() -> None:
try:
with torch.inference_mode():
model.generate(
**inputs,
streamer=streamer,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
)
except BaseException as exc:
generation_exception.append(exc)
generation_thread = Thread(target=run_generation, daemon=True)
generation_thread.start()
for piece in streamer:
sys.stdout.write(piece)
sys.stdout.flush()
generation_thread.join()
if generation_exception:
fail(f"Generation failed: {generation_exception[0]}")
if sys.stdout.isatty():
sys.stdout.write("\n")
sys.stdout.flush()
if __name__ == "__main__":
main()