"""Stream a one-shot chat completion from a locally cached Hugging Face model."""
from __future__ import annotations

import sys
from threading import Thread
from typing import Any, NoReturn

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
|
|
|
|
# Hugging Face repo id of the model expected to be present in the local cache.
MODEL_ID = "zai-org/GLM-4.7-Flash"
# Hard cap on the number of tokens generated for the single reply.
MAX_NEW_TOKENS = 256

# Fixed two-turn conversation: system instructions plus one user question.
SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Answer clearly and keep the response concise."
)
USER_PROMPT = "Briefly introduce yourself and say what model you are."
|
|
|
|
|
|
def fail(message: str) -> NoReturn:
    """Print *message* to stderr and terminate the process with exit code 1.

    Annotated ``NoReturn`` (was ``None``) so type checkers know control never
    continues past a call — callers such as ``load_local_model`` fall through
    to code that would use unbound names if ``fail`` could return.
    """
    print(message, file=sys.stderr, flush=True)
    raise SystemExit(1)
|
|
|
|
|
|
def build_context() -> list[dict[str, str]]:
    """Assemble the fixed two-message chat history sent to the model."""
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": USER_PROMPT}
    return [system_turn, user_turn]
|
|
|
|
|
|
def load_local_model() -> tuple[Any, Any, Any, Any]:
    """Load tokenizer and model strictly from the local Hugging Face cache.

    Returns ``(torch, tokenizer, model, TextIteratorStreamer)`` so the caller
    can use the heavy imports without importing them itself.  Exits the
    process via ``fail()`` when the cached files are missing or incomplete.
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        # local_files_only=True raises OSError when the snapshot is absent.
        fail(
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
    return torch, tok, mdl, TextIteratorStreamer
|
|
|
|
|
|
def main() -> None:
    """Run one greedy chat completion and stream the reply to stdout.

    The generation happens on a worker thread feeding a TextIteratorStreamer;
    the main thread consumes the streamer and writes pieces as they arrive.
    """
    torch, tokenizer, model, text_iterator_streamer_cls = load_local_model()
    context = build_context()

    print("[Model loaded]", flush=True)

    inputs = tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor to wherever device_map placed the model.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    streamer = text_iterator_streamer_cls(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Single-slot mailbox used to carry an exception across the thread boundary.
    generation_exception: list[BaseException] = []

    def run_generation() -> None:
        # Runs on the worker thread; generate() pushes text into `streamer`
        # and calls streamer.end() itself on normal completion.
        try:
            with torch.inference_mode():
                model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=False,
                )
        except BaseException as exc:
            generation_exception.append(exc)
            # BUGFIX: when generate() raises it never signals end-of-stream,
            # so the consumer loop below would block forever on the streamer's
            # queue.  Signal termination explicitly to unblock main().
            streamer.end()

    generation_thread = Thread(target=run_generation, daemon=True)
    generation_thread.start()

    for piece in streamer:
        sys.stdout.write(piece)
        sys.stdout.flush()

    generation_thread.join()

    if generation_exception:
        fail(f"Generation failed: {generation_exception[0]}")

    # Only add a trailing newline on interactive terminals; keep piped
    # output byte-exact.
    if sys.stdout.isatty():
        sys.stdout.write("\n")
        sys.stdout.flush()
|
|
|
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|