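"""Minimal llama-cpp-python inference helpers.

Loads GGUF models with llama_cpp.Llama and runs two styles of streamed
inference (raw token generation and llama-3 chat completion), timing the
model load, the time to first token, and the average per-token time.
"""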
import sys
from time import perf_counter
from typing import Optional, List

from llama_cpp import Llama

from gguf_inspection import *


def utter(s: str):
    """Print a piece of generated text immediately, without a trailing newline."""
    print(s, end="", flush=True)


def load_llama_model(model_path):
    """Load a GGUF model for raw-token generation (CPU only, no GPU layers)."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_gpu_layers=0,
        verbose=False)
    return llm


def load_chat_llama_model(model_path):
    """Load a GGUF model with the llama-3 chat template applied."""
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=8,
        n_gpu_layers=0,
        verbose=False,
        chat_format="llama-3")
    return llm


def tokenize_prompt(llm: Llama, prompt: str) -> List[int]:
    """Encode the prompt as UTF-8 and tokenize it with the model's tokenizer."""
    return llm.tokenize(prompt.encode("utf-8"))


def create_llm_token_generator(llm, prompt_tokens):
    """Return a lazy token generator over the prompt with fixed sampling settings."""
    return llm.generate(
        prompt_tokens,
        temp=0.7,
        top_k=40,
        top_p=0.95,
        repeat_penalty=1.0)


def tokens_to_str(pieces: list[bytes]) -> str:
    """Join detokenized byte pieces and decode them, replacing invalid UTF-8."""
    return b"".join(pieces).decode("utf-8", errors="replace")


def serialize_optional_float(val: Optional[float]) -> str:
    """Format a float with six decimal places, or return "N/A" if it is None."""
    return f"{val:.6f}" if val is not None else "N/A"


def create_chat_messages(system_prompt: str, user_prompt: str) -> List[dict[str, str]]:
    """Build a minimal system + user message list for chat completion."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def load_and_generate_llm_inference(model_path: str, prompt: str) -> str:
    """Load a model, stream up to 128 raw tokens for the prompt, and report timings."""
    load_start_time = perf_counter()
    llm = load_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time

    print("[Model loaded]")

    # print_useful_metadata_from_llama(model_path, llm)

    prompt_tokens = tokenize_prompt(llm, prompt)

    token_stream = create_llm_token_generator(llm, prompt_tokens)
    max_new_tokens: int = 128
    eos_token = llm.token_eos()

    generation_start_time = perf_counter()
    first_token_generation_time: Optional[float] = None
    pieces: list[bytes] = []
    try:
        for i, token in enumerate(token_stream):
            if i == 0:
                first_token_generation_time = perf_counter() - generation_start_time
            if i >= max_new_tokens:
                break
            if token == eos_token:
                break

            piece = llm.detokenize([token])
            pieces.append(piece)

            if len(piece) == 0:
                # Tokens with no byte representation are shown by their id.
                utter(f"<{token}>")
            else:
                # A single token may hold an incomplete UTF-8 sequence, so
                # decode with errors="replace" instead of raising.
                utter(piece.decode("utf-8", errors="replace"))
    except KeyboardInterrupt:
        print("\n[stopped]")
    generation_end_time = perf_counter()

    average_token_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_token_generation_time = (generation_end_time - generation_start_time) / len(pieces)

    # All timings are in seconds (perf_counter).
    print("")
    print(f"Loading time: {load_time:.6f}")
    print(f"First token generation time: {serialize_optional_float(first_token_generation_time)}")
    print(f"Average token generation time: {serialize_optional_float(average_token_generation_time)}")

    return tokens_to_str(pieces)


def load_and_generate_chat_inference(model_path: str, system_prompt: str, user_prompt: str) -> str:
    """Load a chat model, stream a chat completion for the prompts, and report timings."""
    load_start_time = perf_counter()
    llm = load_chat_llama_model(model_path)
    load_end_time = perf_counter()
    load_time = load_end_time - load_start_time

    print("[Model loaded]")

    # print_useful_metadata_from_llama(model_path, llm)

    messages = create_chat_messages(system_prompt, user_prompt)
    max_new_tokens: int = 128

    generation_start_time = perf_counter()
    first_token_generation_time: Optional[float] = None
    pieces: list[str] = []
    try:
        chunk_stream = llm.create_chat_completion(
            messages=messages,
            temperature=0.7,
            top_k=40,
            top_p=0.95,
            repeat_penalty=1.0,
            max_tokens=max_new_tokens,
            stream=True,
        )
        for chunk in chunk_stream:
            delta = chunk["choices"][0]["delta"]
            piece = delta.get("content")

            # Role-only or empty deltas carry no text to display.
            if not piece:
                continue

            if first_token_generation_time is None:
                first_token_generation_time = perf_counter() - generation_start_time

            pieces.append(piece)
            sys.stdout.write(piece)
            sys.stdout.flush()
    except KeyboardInterrupt:
        print("\n[stopped]")
    generation_end_time = perf_counter()

    average_chunk_generation_time: Optional[float] = None
    if len(pieces) > 0:
        average_chunk_generation_time = (generation_end_time - generation_start_time) / len(pieces)

    # All timings are in seconds (perf_counter).
    print("")
    print(f"Loading time: {load_time:.6f}")
    print(f"First token generation time: {serialize_optional_float(first_token_generation_time)}")
    print(f"Average streamed chunk generation time: {serialize_optional_float(average_chunk_generation_time)}")

    return "".join(pieces)
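

if __name__ == "__main__":
    # Example driver: a minimal usage sketch assuming a local GGUF model.
    # The path below is a placeholder, not a file shipped with this module;
    # point it at a real model before running.
    example_model_path = "models/example-model.gguf"  # hypothetical path

    print("=== Raw token generation ===")
    load_and_generate_llm_inference(example_model_path, "The capital of France is")

    print()
    print("=== Chat completion ===")
    load_and_generate_chat_inference(
        example_model_path,
        "You are a concise assistant.",
        "Explain what a GGUF file is in one sentence.")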