# vibevibing/llm_inference.py
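"""Minimal llama-cpp-python helpers: load a GGUF model, stream tokens for a raw
prompt or a llama-3 chat, and report simple load and latency timings."""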

from time import perf_counter
from typing import Optional

from llama_cpp import Llama

from gguf_inspection import print_useful_metadata_from_llama  # noqa: F401 (used by the commented-out metadata dumps below)


def utter(s: str) -> None:
    """Print a piece of generated text immediately, without a trailing newline."""
    print(s, end="", flush=True)


def load_llama_model(model_path: str) -> Llama:
    """Load a GGUF model for raw completion."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,       # context window in tokens
        n_threads=8,      # CPU threads
        n_gpu_layers=0,   # 0 keeps inference entirely on the CPU
        verbose=False,
    )
    return llm


def load_chat_llama_model(model_path: str) -> Llama:
    """Load a GGUF model with the llama-3 chat template applied."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_gpu_layers=0,
        verbose=False,
        chat_format="llama-3",  # only difference from load_llama_model
    )
    return llm


def tokenize_prompt(llm: Llama, prompt: str) -> list[int]:
    # Llama.tokenize expects bytes; it prepends the BOS token by default.
    return llm.tokenize(prompt.encode("utf-8"))


def create_llm_token_generator(llm: Llama, prompt_tokens: list[int]):
    # llm.generate yields token ids one at a time.
    return llm.generate(
        prompt_tokens,
        temp=0.7,
        top_k=40,
        top_p=0.95,
        repeat_penalty=1.0,  # 1.0 disables the repetition penalty
    )


def tokens_to_str(pieces: list[bytes]) -> str:
    return b"".join(pieces).decode("utf-8", errors="replace")


def serialize_optional_float(val: Optional[float]) -> str:
    return f"{val:.6f}" if val is not None else "N/A"


def create_chat_messages(system_prompt: str, user_prompt: str) -> list[dict[str, str]]:
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def load_and_generate_llm_inference(model_path: str, prompt: str) -> str:
    load_start_time = perf_counter()
    llm = load_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time
    print("[Model loaded]")
    # print_useful_metadata_from_llama(model_path, llm)
    prompt_tokens = tokenize_prompt(llm, prompt)
    token_stream = create_llm_token_generator(llm, prompt_tokens)
    max_new_tokens: int = 128
    eos_token = llm.token_eos()
    generation_start_time = perf_counter()
    first_token_generation_time: Optional[float] = None
    pieces: list[bytes] = []
    try:
        for i, token in enumerate(token_stream):
            if i == 0:
                first_token_generation_time = perf_counter() - generation_start_time
            if i >= max_new_tokens:
                break
            if token == eos_token:
                break
            piece = llm.detokenize([token])
            pieces.append(piece)
            if len(piece) == 0:
                # Token with no byte representation; show its id instead.
                utter(f"<{token}>")
            else:
                # errors="replace" keeps streaming alive when a multi-byte
                # UTF-8 sequence is split across tokens.
                utter(piece.decode("utf-8", errors="replace"))
    except KeyboardInterrupt:
        print("\n[stopped]")
    generation_end_time = perf_counter()
    average_token_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_token_generation_time = (generation_end_time - generation_start_time) / len(pieces)
    print("")
    print(f"Loading time: {load_time:.6f}")
    print(f"First token generation time: {serialize_optional_float(first_token_generation_time)}")
    print(f"Average token generation time: {serialize_optional_float(average_token_generation_time)}")
    return tokens_to_str(pieces)
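

# Possible refinement (an assumption, not in the original code): decoding each
# token's bytes in isolation can still garble output when a single UTF-8 code
# point is split across two tokens. The standard library's incremental decoder
# buffers incomplete byte sequences between calls:
#
#     import codecs
#     decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
#     text = decoder.decode(piece)  # call once per token inside the loop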


def load_and_generate_chat_inference(model_path: str, system_prompt: str, user_prompt: str) -> str:
    load_start_time = perf_counter()
    llm = load_chat_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time
    print("[Model loaded]")
    # print_useful_metadata_from_llama(model_path, llm)
    messages = create_chat_messages(system_prompt, user_prompt)
    max_new_tokens: int = 128
    generation_start_time = perf_counter()
    first_chunk_generation_time: Optional[float] = None
    pieces: list[str] = []
    try:
        chunk_stream = llm.create_chat_completion(
            messages=messages,
            temperature=0.7,
            top_k=40,
            top_p=0.95,
            repeat_penalty=1.0,
            max_tokens=max_new_tokens,
            stream=True,
        )
        for chunk in chunk_stream:
            delta = chunk["choices"][0]["delta"]
            piece = delta.get("content")
            if not piece:
                # The first delta carries only the role; later ones may be empty.
                continue
            if first_chunk_generation_time is None:
                first_chunk_generation_time = perf_counter() - generation_start_time
            pieces.append(piece)
            utter(piece)
    except KeyboardInterrupt:
        print("\n[stopped]")
    generation_end_time = perf_counter()
    average_chunk_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_chunk_generation_time = (generation_end_time - generation_start_time) / len(pieces)
    print("")
    print(f"Loading time: {load_time:.6f}")
    print(f"First chunk generation time: {serialize_optional_float(first_chunk_generation_time)}")
    print(f"Average streamed chunk generation time: {serialize_optional_float(average_chunk_generation_time)}")
    return "".join(pieces)