"""Stream a short chat reply from the locally cached ``MODEL_ID`` model to stdout."""

from __future__ import annotations

import sys
from threading import Thread
from typing import Any, NoReturn

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

MODEL_ID = "zai-org/GLM-4.7-Flash"
MAX_NEW_TOKENS = 256
SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Answer clearly and keep the response concise."
)
USER_PROMPT = "Briefly introduce yourself and say what model you are."


def fail(message: str) -> NoReturn:
    """Print *message* to stderr and terminate the script with exit status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    print(message, file=sys.stderr, flush=True)
    raise SystemExit(1)


def build_context() -> list[dict[str, str]]:
    """Return the chat messages (system prompt, then user prompt) to send."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": USER_PROMPT},
    ]


def load_local_model() -> tuple[Any, Any, Any, Any]:
    """Load the tokenizer and model for ``MODEL_ID`` without network access.

    Both loaders are called with ``local_files_only=True``; the model is
    requested in bfloat16 with ``device_map="auto"``.

    Returns:
        A ``(torch, tokenizer, model, TextIteratorStreamer)`` tuple: the torch
        module, the loaded tokenizer, the loaded model and the streamer class.

    Raises:
        SystemExit: via ``fail`` when loading raises any ``Exception``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            local_files_only=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        # An OSError here is reported as an incomplete local cache.
        fail(
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. \nOriginal error: {exc}"
        )
    except Exception as exc:
        fail(f"Failed to load {MODEL_ID!r}: {exc}")

    return torch, tokenizer, model, TextIteratorStreamer


def main() -> None:
    """Generate a reply to the fixed prompts and stream it to stdout.

    Generation runs in a background thread while the calling thread prints
    each decoded piece as soon as the streamer yields it.

    Raises:
        SystemExit: via ``fail`` when the model cannot be loaded or when
            generation raises.
    """
    torch, tokenizer, model, text_iterator_streamer_cls = load_local_model()
    context = build_context()

    print("[Model loaded]", flush=True)

    inputs = tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    streamer = text_iterator_streamer_cls(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Filled by the worker thread so the failure can be reported from here.
    generation_exception: list[BaseException] = []

    def run_generation() -> None:
        """Run ``model.generate`` and record any exception it raises."""
        try:
            with torch.inference_mode():
                model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=False,
                )
        except BaseException as exc:
            generation_exception.append(exc)
            # A failed generate() never signals end-of-stream, so the consumer
            # loop below would block forever. Close the stream explicitly.
            streamer.end()

    generation_thread = Thread(target=run_generation, daemon=True)
    generation_thread.start()

    for piece in streamer:
        sys.stdout.write(piece)
        sys.stdout.flush()

    generation_thread.join()

    if generation_exception:
        fail(f"Generation failed: {generation_exception[0]}")

    # Finish the line only for interactive terminals; piped output is left
    # exactly as generated.
    if sys.stdout.isatty():
        sys.stdout.write("\n")
        sys.stdout.flush()


if __name__ == "__main__":
    main()