import sys
from time import perf_counter
from typing import Optional

from llama_cpp import Llama

# Used by the commented-out metadata dumps below.
from gguf_inspection import print_useful_metadata_from_llama


def utter(s: str) -> None:
    """Print a piece of streamed output without a trailing newline."""
    print(s, end="", flush=True)


def load_llama_model(model_path: str) -> Llama:
    """Load a GGUF model for raw (non-chat) completion, CPU-only."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_gpu_layers=0,
        verbose=False)
    return llm


def load_chat_llama_model(model_path: str) -> Llama:
    """Load a GGUF model with the llama-3 chat template applied."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_gpu_layers=0,
        verbose=False,
        chat_format="llama-3")
    return llm


def tokenize_prompt(llm: Llama, prompt: str) -> list[int]:
    return llm.tokenize(prompt.encode("utf-8"))


def create_llm_token_generator(llm: Llama, prompt_tokens: list[int]):
    """Return a generator that yields sampled token ids one at a time."""
    return llm.generate(
        prompt_tokens,
        temp=0.7,
        top_k=40,
        top_p=0.95,
        repeat_penalty=1.0)


def tokens_to_str(pieces: list[bytes]) -> str:
    # Join raw byte pieces before decoding, so multi-byte UTF-8 sequences
    # split across tokens are reassembled correctly.
    return b"".join(pieces).decode("utf-8", errors="replace")


def serialize_optional_float(val: Optional[float]) -> str:
    return f"{val:.6f}" if val is not None else "N/A"


def create_chat_messages(system_prompt: str, user_prompt: str) -> list[dict[str, str]]:
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def load_and_generate_llm_inference(model_path: str, prompt: str) -> str:
    """Run a raw completion, streaming tokens and reporting timing stats."""
    load_start_time = perf_counter()
    llm = load_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time
    print("[Model loaded]")
    # print_useful_metadata_from_llama(model_path, llm)

    prompt_tokens = tokenize_prompt(llm, prompt)
    token_stream = create_llm_token_generator(llm, prompt_tokens)

    max_new_tokens: int = 128
    eos_token = llm.token_eos()

    generation_start_time = perf_counter()
    first_token_generation_time: Optional[float] = None
    pieces: list[bytes] = []
    try:
        for i, token in enumerate(token_stream):
            if i == 0:
                # Time to first token also includes prompt processing.
                first_token_generation_time = perf_counter() - generation_start_time
            if i >= max_new_tokens:
                break
            if token == eos_token:
                break
            piece = llm.detokenize([token])
            pieces.append(piece)
            if len(piece) == 0:
                utter(f"<{token}>")
            else:
                # A single token may carry an incomplete UTF-8 sequence, so
                # decode leniently here; tokens_to_str() rejoins the raw
                # bytes for the final, correctly decoded result.
                utter(piece.decode("utf-8", errors="replace"))
    except KeyboardInterrupt:
        print("\n[stopped]")

    generation_end_time = perf_counter()
    average_token_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_token_generation_time = (generation_end_time - generation_start_time) / len(pieces)

    print()
    print(f"Loading time: {load_time:.6f}")
    print(f"First token generation time: {serialize_optional_float(first_token_generation_time)}")
    print(f"Average token generation time: {serialize_optional_float(average_token_generation_time)}")
    return tokens_to_str(pieces)


def load_and_generate_chat_inference(model_path: str, system_prompt: str, user_prompt: str) -> str:
    """Run a chat completion, streaming content chunks and reporting timing stats."""
    load_start_time = perf_counter()
    llm = load_chat_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time
    print("[Model loaded]")
    # print_useful_metadata_from_llama(model_path, llm)

    messages = create_chat_messages(system_prompt, user_prompt)

    max_new_tokens: int = 128

    generation_start_time = perf_counter()
    first_token_generation_time: Optional[float] = None
    pieces: list[str] = []
    try:
        chunk_stream = llm.create_chat_completion(
            messages=messages,
            temperature=0.7,
            top_k=40,
            top_p=0.95,
            repeat_penalty=1.0,
            max_tokens=max_new_tokens,
            stream=True,
        )
        for chunk in chunk_stream:
            delta = chunk["choices"][0]["delta"]
            # The first delta carries only the role; skip chunks with no content.
            piece = delta.get("content")
            if not piece:
                continue
            if first_token_generation_time is None:
                first_token_generation_time = perf_counter() - generation_start_time
            pieces.append(piece)
            sys.stdout.write(piece)
            sys.stdout.flush()
    except KeyboardInterrupt:
        print("\n[stopped]")

    generation_end_time = perf_counter()
    average_chunk_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_chunk_generation_time = (generation_end_time - generation_start_time) / len(pieces)

    print()
    print(f"Loading time: {load_time:.6f}")
    print(f"First token generation time: {serialize_optional_float(first_token_generation_time)}")
    print(f"Average streamed chunk generation time: {serialize_optional_float(average_chunk_generation_time)}")
    return "".join(pieces)