"""Helpers for discovering and downloading GGUF model files from the Hugging Face Hub."""
from __future__ import annotations

import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import enable_progress_bars
from tqdm.auto import tqdm as auto_tqdm


# Sort keys accepted by `search_gguf_repositories` (forwarded to
# HfApi.list_models as its `sort` argument).
GGUF_SORT_FIELD = Literal[
    "created_at",
    "downloads",
    "last_modified",
    "likes",
    "trending_score",
]


@dataclass(frozen=True, slots=True)
class GGUFRepository:
    """A Hugging Face model repository tagged as hosting GGUF files.

    Metadata fields are None when the Hub response did not include them.
    """

    repo_id: str  # Repository identifier on the Hub.
    downloads: int | None = None  # Download count reported by the Hub.
    likes: int | None = None  # Like count reported by the Hub.
    gated: str | bool | None = None  # Gating status; the Hub may report a bool or a string.


@dataclass(frozen=True, slots=True)
class GGUFFileEntry:
    """One .gguf file inside a Hub model repository."""

    repo_id: str  # Repository the file belongs to.
    filename: str  # Path of the file inside the repo (repo-relative).
    size_bytes: int | None = None  # File size, when the Hub reported it.
    blob_id: str | None = None  # Hub blob identifier, when reported.

    @property
    def name(self) -> str:
        """Base name of the file (the final component of `filename`)."""
        return Path(self.filename).name


@dataclass(frozen=True, slots=True)
class GGUFDownloadStatus:
    """Dry-run status for a single GGUF file, as produced by `get_gguf_file_status`.

    Any field the dry-run response did not report is left as None.
    """

    repo_id: str  # Repository the file lives in.
    filename: str  # Path of the file inside the repo.
    revision: str | None = None  # Requested git revision, if any.
    commit_hash: str | None = None  # Commit the dry-run resolved to.
    size_bytes: int | None = None  # File size reported by the dry-run.
    is_cached: bool | None = None  # Dry-run's is_cached flag (already in local cache).
    will_download: bool | None = None  # Dry-run's will_download flag (a real call would fetch).


class _StdoutTqdm(auto_tqdm):
    """tqdm progress bar preconfigured for visible stdout output.

    Fills in stdout-friendly defaults (stream, enabled, dynamic width,
    leave the bar on screen) while letting explicit caller kwargs win.
    """

    def __init__(self, *args, **kwargs):
        defaults = {
            "file": sys.stdout,
            "disable": False,
            "dynamic_ncols": True,
            "leave": True,
        }
        for option, value in defaults.items():
            kwargs.setdefault(option, value)
        super().__init__(*args, **kwargs)


def _format_size(size_bytes: int | None) -> str:
|
|
if size_bytes is None:
|
|
return "unknown size"
|
|
|
|
size = float(size_bytes)
|
|
for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
|
|
if size < 1024.0 or unit == "TiB":
|
|
if unit == "B":
|
|
return f"{int(size)} {unit}"
|
|
return f"{size:.1f} {unit}"
|
|
size /= 1024.0
|
|
|
|
return f"{size_bytes} B"
|
|
|
|
|
|
def _build_api(token: str | None = None) -> HfApi:
    """Create an HfApi client, optionally authenticated with *token*."""
    client = HfApi(token=token)
    return client


def search_gguf_repositories(
    query: str | None = None,
    *,
    author: str | None = None,
    limit: int = 20,
    sort: GGUF_SORT_FIELD = "downloads",
    token: str | None = None,
) -> list[GGUFRepository]:
    """Search Hugging Face model repos tagged as GGUF.

    *query* narrows by free text, *author* by repo owner; *sort* picks the
    Hub ordering and *limit* caps the number of results. Returns one
    `GGUFRepository` per match, with whatever metadata the Hub expanded.
    """
    api = _build_api(token)
    matches = api.list_models(
        filter="gguf",
        search=query or None,
        author=author,
        sort=sort,
        limit=limit,
        expand=["downloads", "likes", "gated"],
    )

    repositories: list[GGUFRepository] = []
    for model in matches:
        # Expanded metadata attributes may be absent; fall back to None.
        repositories.append(
            GGUFRepository(
                repo_id=model.id,
                downloads=getattr(model, "downloads", None),
                likes=getattr(model, "likes", None),
                gated=getattr(model, "gated", None),
            )
        )
    return repositories


def list_repo_gguf_files(
    repo_id: str,
    *,
    revision: str | None = None,
    token: str | None = None,
) -> list[GGUFFileEntry]:
    """List GGUF files available inside a specific Hugging Face model repo.

    Matches files whose name ends in ".gguf" (case-insensitive) and
    returns them sorted case-insensitively by repo-relative path.
    """
    api = _build_api(token)
    info = api.model_info(repo_id=repo_id, revision=revision, files_metadata=True)

    entries: list[GGUFFileEntry] = []
    for sibling in getattr(info, "siblings", None) or ():
        if not sibling.rfilename.lower().endswith(".gguf"):
            continue
        entries.append(
            GGUFFileEntry(
                repo_id=repo_id,
                filename=sibling.rfilename,
                size_bytes=getattr(sibling, "size", None),
                blob_id=getattr(sibling, "blob_id", None),
            )
        )
    return sorted(entries, key=lambda entry: entry.filename.lower())


def get_gguf_file_status(
    repo_id: str,
    filename: str,
    *,
    revision: str | None = None,
    cache_dir: str | Path | None = None,
    local_dir: str | Path | None = None,
    token: str | None = None,
) -> GGUFDownloadStatus:
    """Ask the Hub whether a GGUF file is already cached and whether it would download.

    Performs a dry-run `hf_hub_download` (no bytes fetched) and repackages
    the interesting attributes into a `GGUFDownloadStatus`.
    """
    # NOTE(review): dry_run support on hf_hub_download requires a recent
    # huggingface_hub -- confirm against the project's pinned version.
    info = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
        local_dir=local_dir,
        token=token,
        dry_run=True,
    )

    # The size attribute name varies; take the first one that is present.
    # (Checked with `is not None`, not truthiness, so a 0-byte size survives.)
    candidates = (getattr(info, "file_size", None), getattr(info, "size", None))
    size_bytes = next((value for value in candidates if value is not None), None)

    return GGUFDownloadStatus(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        commit_hash=getattr(info, "commit_hash", None),
        size_bytes=size_bytes,
        is_cached=getattr(info, "is_cached", None),
        will_download=getattr(info, "will_download", None),
    )


def ensure_gguf_downloaded(
    repo_id: str,
    filename: str,
    *,
    revision: str | None = None,
    cache_dir: str | Path | None = None,
    local_dir: str | Path | None = None,
    token: str | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
) -> Path:
    """
    Download a GGUF file if needed and return the local filesystem path.

    With no local_dir provided, the returned path points into the Hugging Face cache.

    The download parameters mirror `hf_hub_download`: *revision* pins a git
    revision, *cache_dir*/*local_dir* control where the file lands,
    *force_download* re-fetches even when cached, and *local_files_only*
    skips the network and resolves only from the local cache. Raises
    whatever `hf_hub_download` raises on failure.
    """
    # Fix: the status messages previously printed a literal "(unknown)"
    # instead of the requested filename.
    target = f"{repo_id}/{filename}"

    if local_files_only:
        # No network allowed: just report the cache lookup.
        print(f"Looking for cached GGUF: {target}")
    else:
        # Dry-run first so the message (and size) printed before the real
        # download is accurate.
        status = get_gguf_file_status(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            cache_dir=cache_dir,
            local_dir=local_dir,
            token=token,
        )

        if force_download:
            print(f"Re-downloading GGUF: {target} ({_format_size(status.size_bytes)})")
        elif status.is_cached and not status.will_download:
            print(f"Using cached GGUF: {target}")
        elif status.will_download:
            print(f"Downloading GGUF: {target} ({_format_size(status.size_bytes)})")
        else:
            print(f"Resolving GGUF: {target}")

    # Make sure hub progress bars are globally enabled for the real transfer.
    enable_progress_bars()

    resolved_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
        local_dir=local_dir,
        token=token,
        force_download=force_download,
        local_files_only=local_files_only,
        # NOTE(review): assumes the installed huggingface_hub accepts
        # tqdm_class on hf_hub_download -- confirm against the pinned version.
        tqdm_class=_StdoutTqdm,
    )
    local_path = Path(resolved_path)
    print(f"GGUF ready: {local_path}")
    return local_path


# Explicit public API of this module (ASCII-sorted).
__all__ = [
    "GGUFDownloadStatus",
    "GGUFFileEntry",
    "GGUFRepository",
    "GGUF_SORT_FIELD",
    "ensure_gguf_downloaded",
    "get_gguf_file_status",
    "list_repo_gguf_files",
    "search_gguf_repositories",
]