
API Reference

This page documents the public Python API of assgen, generated from the source-code docstrings with mkdocstrings. Each entry shows its type signature. Where the docstring provides them, it also shows parameter descriptions and return/raises information, all derived directly from the code, so no separate documentation has to be maintained.


Catalog

Job-type → HuggingFace model ID mapping. Users can override entries in ~/.config/assgen/models.yaml.

catalog

Catalog loader — merges the built-in catalog.yaml with any user overrides.

The catalog maps every game-dev job type (e.g. visual.model.create) to a HuggingFace model ID and associated metadata. The built-in defaults live in catalog.yaml alongside this module; users can override any entry by adding the same key to ~/.config/assgen/models.yaml under the catalog: key.

Example

>>> from assgen.catalog import load_catalog, get_model_for_job
>>> catalog = load_catalog()
>>> entry = get_model_for_job("visual.model.create")
>>> entry["model_id"]
'stabilityai/TripoSR'

load_catalog cached

load_catalog() -> dict[str, dict[str, Any]]

Load and return the merged job-type → model catalog.

Reads the built-in catalog.yaml first, then overlays any entries found in ~/.config/assgen/models.yaml. The result is cached; call load_catalog.cache_clear() after modifying the user catalog.

Returns:

Type Description
dict[str, dict[str, Any]]

A dict mapping job-type strings to catalog entry dicts. Each entry contains at least model_id (str | None) and name (str).

Example

>>> catalog = load_catalog()
>>> "visual.model.create" in catalog
True

Source code in src/assgen/catalog.py
@lru_cache(maxsize=1)
def load_catalog() -> dict[str, dict[str, Any]]:
    """Load and return the merged job-type → model catalog.

    Reads the built-in ``catalog.yaml`` first, then overlays any entries
    found in ``~/.config/assgen/models.yaml``.  The result is cached; call
    ``load_catalog.cache_clear()`` after modifying the user catalog.

    Returns:
        A dict mapping job-type strings to catalog entry dicts.  Each entry
        contains at least ``model_id`` (str | None) and ``name`` (str).

    Example:
        >>> catalog = load_catalog()
        >>> "visual.model.create" in catalog
        True
    """
    with _BUILTIN_CATALOG.open() as f:
        data = yaml.safe_load(f) or {}
    catalog: dict[str, Any] = data.get("catalog", {})

    user_path = get_config_dir() / "models.yaml"
    if user_path.exists():
        with user_path.open() as f:
            user_data = yaml.safe_load(f) or {}
        catalog.update(user_data.get("catalog", {}))

    return catalog
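The overlay in load_catalog uses dict.update. A user entry therefore replaces the built-in entry for that job type as a whole; individual fields are not merged. The result is also cached, which is why load_catalog.cache_clear() is needed after the user file changes.

The sketch below reproduces both behaviours. Plain dicts stand in for the parsed YAML files, and the sample entries (other than the stabilityai/TripoSR ID taken from the example above) are hypothetical.

```python
from functools import lru_cache
from typing import Any

# Hypothetical stand-ins for the parsed catalog.yaml and ~/.config/assgen/models.yaml.
BUILTIN: dict[str, Any] = {
    "catalog": {
        "visual.model.create": {"model_id": "stabilityai/TripoSR", "name": "TripoSR", "task": "image-to-3d"},
    }
}
USER: dict[str, Any] = {"catalog": {}}


@lru_cache(maxsize=1)
def load_catalog_sketch() -> dict[str, dict[str, Any]]:
    # Copy the built-in section so the module-level sample dict is not mutated.
    catalog = dict(BUILTIN.get("catalog", {}))
    # Shallow overlay, as in load_catalog: a user key replaces the whole entry.
    catalog.update(USER.get("catalog", {}))
    return catalog


print(load_catalog_sketch()["visual.model.create"]["task"])  # image-to-3d

# The "user file" now overrides only model_id for this job type...
USER["catalog"]["visual.model.create"] = {"model_id": "example-org/custom-3d"}
# ...but the cached result stays stale until the cache is cleared.
print(load_catalog_sketch()["visual.model.create"]["model_id"])  # stabilityai/TripoSR
load_catalog_sketch.cache_clear()
entry = load_catalog_sketch()["visual.model.create"]
print(entry["model_id"], "task" in entry)  # example-org/custom-3d False
```

Because the whole entry is replaced, a user override should repeat every field it still needs, such as name and task.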

get_model_for_job

get_model_for_job(job_type: str) -> dict[str, Any] | None

Return the catalog entry for job_type, or None if unknown.

Parameters:

Name Type Description Default
job_type str

Dot-separated task identifier, e.g. "visual.model.create".

required

Returns:

Type Description
dict[str, Any] | None

A dict with keys model_id, name, task, and optional notes; or None if the job type is not in the catalog.

Example

>>> get_model_for_job("audio.sfx.generate")["name"]
'AudioGen Medium'

Source code in src/assgen/catalog.py
def get_model_for_job(job_type: str) -> dict[str, Any] | None:
    """Return the catalog entry for *job_type*, or ``None`` if unknown.

    Args:
        job_type: Dot-separated task identifier, e.g. ``"visual.model.create"``.

    Returns:
        A dict with keys ``model_id``, ``name``, ``task``, and optional
        ``notes``; or ``None`` if the job type is not in the catalog.

    Example:
        >>> get_model_for_job("audio.sfx.generate")["name"]
        'AudioGen Medium'
    """
    return load_catalog().get(job_type)
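get_model_for_job returns None for an unknown job type. Indexing the result directly, as the docstring example does, therefore raises TypeError on a typo.

A defensive caller might look like the sketch below. The describe_job helper and the sample catalog (including the made-up model IDs) are hypothetical. The lookup is reduced to a plain dict.get, mirroring the one-line implementation above.

```python
from typing import Any

# Hypothetical catalog contents; the real data comes from load_catalog().
CATALOG: dict[str, dict[str, Any]] = {
    "audio.sfx.generate": {"model_id": "example-org/sfx-model", "name": "AudioGen Medium"},
    "pipeline.convert": {"model_id": None, "name": "Format conversion"},
}


def describe_job(job_type: str) -> str:
    # Same lookup as get_model_for_job: None for unknown job types.
    entry = CATALOG.get(job_type)
    if entry is None:
        return f"unknown job type: {job_type!r}"
    # model_id may be None for job types that need no model.
    return f"{job_type} -> {entry['name']} ({entry.get('model_id') or 'no model'})"


print(describe_job("audio.sfx.generate"))  # audio.sfx.generate -> AudioGen Medium (example-org/sfx-model)
print(describe_job("pipeline.convert"))    # pipeline.convert -> Format conversion (no model)
print(describe_job("audio.sfx.genrate"))   # unknown job type: 'audio.sfx.genrate'
```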

all_job_types

all_job_types() -> list[str]

Return a sorted list of every job type in the catalog.

Returns:

Type Description
list[str]

Alphabetically sorted list of job-type strings.

Source code in src/assgen/catalog.py
def all_job_types() -> list[str]:
    """Return a sorted list of every job type in the catalog.

    Returns:
        Alphabetically sorted list of job-type strings.
    """
    return sorted(load_catalog().keys())

all_model_ids

all_model_ids() -> list[str]

Return a deduplicated list of every HF model ID referenced in the catalog.

Returns:

Type Description
list[str]

List of org/repo model ID strings, in catalog order, without duplicates (multiple job types often share the same base model).

Source code in src/assgen/catalog.py
def all_model_ids() -> list[str]:
    """Return a deduplicated list of every HF model ID referenced in the catalog.

    Returns:
        List of ``org/repo`` model ID strings, in catalog order, without
        duplicates (multiple job types often share the same base model).
    """
    seen: set[str] = set()
    result: list[str] = []
    for entry in load_catalog().values():
        mid = entry.get("model_id")
        if mid and mid not in seen:
            seen.add(mid)
            result.append(mid)
    return result
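The loop in all_model_ids is an order-preserving de-duplication. Since Python 3.7, dicts keep insertion order, so dict.fromkeys expresses the same idea more compactly.

The sketch below checks that both forms agree on a small hypothetical catalog. In it, two job types share one base model and one entry has no model.

```python
from typing import Any

# Hypothetical catalog: two job types share a model, one needs no model.
catalog: dict[str, dict[str, Any]] = {
    "visual.model.create": {"model_id": "org/shape-model"},
    "visual.model.refine": {"model_id": "org/shape-model"},
    "pipeline.convert": {"model_id": None},
    "audio.sfx.generate": {"model_id": "org/sfx-model"},
}


def all_model_ids_loop(cat: dict[str, dict[str, Any]]) -> list[str]:
    # Same algorithm as the documented function: a seen-set plus a result list.
    seen: set[str] = set()
    result: list[str] = []
    for entry in cat.values():
        mid = entry.get("model_id")
        if mid and mid not in seen:
            seen.add(mid)
            result.append(mid)
    return result


def all_model_ids_fromkeys(cat: dict[str, dict[str, Any]]) -> list[str]:
    # Dict keys are unique and keep first-insertion order.
    return list(dict.fromkeys(e["model_id"] for e in cat.values() if e.get("model_id")))


print(all_model_ids_loop(catalog))  # ['org/shape-model', 'org/sfx-model']
```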

Database

SQLite persistence layer for jobs, models, and usage records.

db

SQLite database schema and helpers for assgen.

Tables

jobs: all submitted jobs and their lifecycle state
model_usage: record of which model was used for each job (analytics)
models: locally installed model metadata

Migrations are handled by a simple version table; new columns/tables are added incrementally so existing databases are upgraded automatically.

JobStatus

String constants for job lifecycle states.

Attributes:

Name Description
QUEUED

Job is waiting to be picked up by a worker.

RUNNING

A worker is actively processing the job.

COMPLETED

Job finished successfully.

FAILED

Job terminated with an error.

CANCELLED

Job was explicitly cancelled before completion.

TERMINAL

The set of states from which no transition is possible.

get_connection

get_connection(db_path: Path | None = None) -> sqlite3.Connection

Open a SQLite connection with WAL mode and foreign-key enforcement.

Parameters:

Name Type Description Default
db_path Path | None

Path to the SQLite file. Defaults to assgen.db in the platform-appropriate config directory (~/.config/assgen/assgen.db on Linux).

None

Returns:

Type Description
Connection

An open sqlite3.Connection with row_factory = sqlite3.Row, WAL journal mode, and foreign keys enabled.

Source code in src/assgen/db.py
def get_connection(db_path: Path | None = None) -> sqlite3.Connection:
    """Open a SQLite connection with WAL mode and foreign-key enforcement.

    Args:
        db_path: Path to the SQLite file.  Defaults to the platform-appropriate
            config directory (``~/.config/assgen/assgen.db`` on Linux).

    Returns:
        An open ``sqlite3.Connection`` with ``row_factory = sqlite3.Row``,
        WAL journal mode, and foreign keys enabled.
    """
    path = db_path or get_db_path()
    conn = sqlite3.connect(str(path), check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA foreign_keys=ON")
    return conn

transaction

transaction(conn: Connection) -> Generator[sqlite3.Connection, None, None]

Context manager that commits on success or rolls back on any exception.

Parameters:

Name Type Description Default
conn Connection

An open SQLite connection.

required

Yields:

Type Description
Connection

The same connection, for use inside a with block.

Raises:

Type Description
Exception

Re-raises whatever caused the rollback.

Source code in src/assgen/db.py
@contextmanager
def transaction(conn: sqlite3.Connection) -> Generator[sqlite3.Connection, None, None]:
    """Context manager that commits on success or rolls back on any exception.

    Args:
        conn: An open SQLite connection.

    Yields:
        The same connection, for use inside a ``with`` block.

    Raises:
        Exception: Re-raises whatever caused the rollback.
    """
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
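The commit-or-rollback behaviour can be demonstrated with an in-memory database. The transaction function below is copied from the listing above; the items table is hypothetical.

```python
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager


@contextmanager
def transaction(conn: sqlite3.Connection) -> Generator[sqlite3.Connection, None, None]:
    # Same body as assgen.db.transaction.
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (name TEXT PRIMARY KEY)")  # hypothetical table

with transaction(conn):
    conn.execute("INSERT INTO items VALUES ('kept')")

try:
    with transaction(conn):
        conn.execute("INSERT INTO items VALUES ('discarded')")
        raise RuntimeError("simulated failure mid-transaction")
except RuntimeError:
    pass  # transaction() rolled back, then re-raised the original exception

names = [row[0] for row in conn.execute("SELECT name FROM items")]
print(names)  # ['kept']
```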

init_db

init_db(db_path: Path | None = None) -> sqlite3.Connection

Create the schema if needed, apply any pending migrations, and return an open connection.

Source code in src/assgen/db.py
def init_db(db_path: Path | None = None) -> sqlite3.Connection:
    """Ensure the database is initialised and migrated; return an open connection."""
    conn = get_connection(db_path)
    with transaction(conn):
        conn.executescript(_SCHEMA_V1)
        row = conn.execute("SELECT version FROM schema_version LIMIT 1").fetchone()
        current = row["version"] if row else 0
        if current == 0:
            conn.execute("INSERT INTO schema_version VALUES (?)", (1,))
            current = 1
        if current < 2:
            try:
                conn.executescript(_SCHEMA_V2_ADDITIONS)
            except sqlite3.OperationalError:
                pass  # column may already exist
            conn.execute("UPDATE schema_version SET version = 2")
    return conn
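init_db follows a version-table migration pattern:

1. Read the stored version.
2. Apply each later step in order.
3. Record the new version.

The sketch below shows that pattern on its own. It uses a hypothetical MIGRATIONS list in place of _SCHEMA_V1 and _SCHEMA_V2_ADDITIONS, which this page does not show.

One caveat applies to both this sketch and init_db: sqlite3's executescript() commits any pending transaction before it runs its script. A multi-step migration built on it is therefore not applied as one atomic unit.

```python
import sqlite3

# Hypothetical migration scripts: MIGRATIONS[i] upgrades the schema to version i + 1.
MIGRATIONS: list[str] = [
    "CREATE TABLE IF NOT EXISTS jobs (id TEXT PRIMARY KEY, status TEXT);",
    "ALTER TABLE jobs ADD COLUMN tags TEXT DEFAULT '[]';",
]


def migrate(conn: sqlite3.Connection) -> int:
    """Apply pending migrations and return the resulting schema version."""
    conn.execute("CREATE TABLE IF NOT EXISTS schema_version (version INTEGER)")
    row = conn.execute("SELECT version FROM schema_version LIMIT 1").fetchone()
    if row is None:
        conn.execute("INSERT INTO schema_version VALUES (0)")
    version = row[0] if row else 0
    # Run only the steps newer than the stored version, in order.
    for script in MIGRATIONS[version:]:
        conn.executescript(script)  # executescript() commits any pending transaction first
        version += 1
        conn.execute("UPDATE schema_version SET version = ?", (version,))
    conn.commit()
    return version


conn = sqlite3.connect(":memory:")
print(migrate(conn))  # 2
print(migrate(conn))  # 2: nothing pending, so the ALTER TABLE is not re-run
print([col[1] for col in conn.execute("PRAGMA table_info(jobs)")])  # ['id', 'status', 'tags']
```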

create_job

create_job(conn: Connection, job_type: str, params: dict[str, Any], priority: int = 0, tags: list[str] | None = None) -> str

Insert a new QUEUED job and return its UUID.

Parameters:

Name Type Description Default
conn Connection

An open database connection.

required
job_type str

Dot-separated task identifier, e.g. "visual.model.create".

required
params dict[str, Any]

Arbitrary key/value pairs forwarded to the inference handler.

required
priority int

Worker priority; higher values are processed first (0–100).

0
tags list[str] | None

Optional list of string labels for filtering/grouping.

None

Returns:

Type Description
str

The newly created job's UUID string.

Source code in src/assgen/db.py
def create_job(
    conn: sqlite3.Connection,
    job_type: str,
    params: dict[str, Any],
    priority: int = 0,
    tags: list[str] | None = None,
) -> str:
    """Insert a new QUEUED job and return its UUID.

    Args:
        conn: An open database connection.
        job_type: Dot-separated task identifier, e.g. ``"visual.model.create"``.
        params: Arbitrary key/value pairs forwarded to the inference handler.
        priority: Worker priority; higher values are processed first (0–100).
        tags: Optional list of string labels for filtering/grouping.

    Returns:
        The newly created job's UUID string.
    """
    job_id = str(uuid.uuid4())
    with transaction(conn):
        conn.execute(
            """
            INSERT INTO jobs (id, job_type, status, params, priority, created_at, tags)
            VALUES (?, ?, 'QUEUED', ?, ?, ?, ?)
            """,
            (job_id, job_type, json.dumps(params), priority, _now_iso(), json.dumps(tags or [])),
        )
    return job_id

get_job

get_job(conn: Connection, job_id: str) -> dict[str, Any] | None

Look up a job by exact UUID or by an unambiguous prefix of at least 8 characters.

Parameters:

Name Type Description Default
conn Connection

An open database connection.

required
job_id str

Full UUID or a prefix of at least 8 characters. If the prefix matches more than one job, None is returned to avoid ambiguity.

required

Returns:

Type Description
dict[str, Any] | None

A dict representation of the job row, or None if no job (or more than one job) matches.

Source code in src/assgen/db.py
def get_job(conn: sqlite3.Connection, job_id: str) -> dict[str, Any] | None:
    """Look up a job by exact UUID or unambiguous 8-character prefix.

    Args:
        conn: An open database connection.
        job_id: Full UUID or a prefix of at least 8 characters.  If the prefix
            matches more than one job, ``None`` is returned to avoid ambiguity.

    Returns:
        A dict representation of the job row, or ``None`` if not found.
    """
    row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
    if row is None and len(job_id) >= 8:
        rows = conn.execute(
            "SELECT * FROM jobs WHERE id LIKE ?", (f"{job_id}%",)
        ).fetchall()
        if len(rows) == 1:
            row = rows[0]
    return _row_to_job(row) if row else None

list_jobs

list_jobs(conn: Connection, statuses: list[str] | None = None, limit: int = 50) -> list[dict[str, Any]]

Return a list of jobs, newest first.

Parameters:

Name Type Description Default
conn Connection

An open database connection.

required
statuses list[str] | None

If provided, only return jobs whose status is in this list. Use JobStatus constants, e.g. [JobStatus.QUEUED, JobStatus.RUNNING].

None
limit int

Maximum number of rows to return (default 50, max 500 via API).

50

Returns:

Type Description
list[dict[str, Any]]

A list of job dicts ordered by created_at descending.

Source code in src/assgen/db.py
def list_jobs(
    conn: sqlite3.Connection,
    statuses: list[str] | None = None,
    limit: int = 50,
) -> list[dict[str, Any]]:
    """Return a list of jobs, newest first.

    Args:
        conn: An open database connection.
        statuses: If provided, only return jobs whose status is in this list.
            Use ``JobStatus`` constants, e.g. ``[JobStatus.QUEUED, JobStatus.RUNNING]``.
        limit: Maximum number of rows to return (default 50, max 500 via API).

    Returns:
        A list of job dicts ordered by ``created_at`` descending.
    """
    if statuses:
        placeholders = ",".join("?" * len(statuses))
        rows = conn.execute(
            f"SELECT * FROM jobs WHERE status IN ({placeholders}) ORDER BY created_at DESC LIMIT ?",
            (*statuses, limit),
        ).fetchall()
    else:
        rows = conn.execute(
            "SELECT * FROM jobs ORDER BY created_at DESC LIMIT ?", (limit,)
        ).fetchall()
    return [_row_to_job(r) for r in rows]

update_job_status

update_job_status(conn: Connection, job_id: str, status: str, progress: float | None = None, progress_message: str | None = None, output: dict[str, Any] | None = None, error: str | None = None, worker_id: str | None = None) -> None

Update a job's status and optional progress/output fields.

Automatically sets started_at when transitioning to RUNNING and completed_at when transitioning to any terminal state.

Parameters:

Name Type Description Default
conn Connection

An open database connection.

required
job_id str

The job's full UUID.

required
status str

New status — use a JobStatus constant.

required
progress float | None

Completion fraction 0.0–1.0.

None
progress_message str | None

Human-readable progress description shown in the CLI bar.

None
output dict[str, Any] | None

Arbitrary result payload stored as JSON (e.g. {"path": "..."}).

None
error str | None

Error message to store when the job fails.

None
worker_id str | None

Identifier of the worker thread handling this job.

None

Source code in src/assgen/db.py
def update_job_status(
    conn: sqlite3.Connection,
    job_id: str,
    status: str,
    progress: float | None = None,
    progress_message: str | None = None,
    output: dict[str, Any] | None = None,
    error: str | None = None,
    worker_id: str | None = None,
) -> None:
    """Update a job's status and optional progress/output fields.

    Automatically sets ``started_at`` when transitioning to ``RUNNING`` and
    ``completed_at`` when transitioning to any terminal state.

    Args:
        conn: An open database connection.
        job_id: The job's full UUID.
        status: New status — use a ``JobStatus`` constant.
        progress: Completion fraction 0.0–1.0.
        progress_message: Human-readable progress description shown in the CLI bar.
        output: Arbitrary result payload stored as JSON (e.g. ``{"path": "..."}``).
        error: Error message to store when the job fails.
        worker_id: Identifier of the worker thread handling this job.
    """
    now = _now_iso()
    fields: list[str] = ["status = ?"]
    values: list[Any] = [status]

    if progress is not None:
        fields.append("progress = ?")
        values.append(progress)
    if progress_message is not None:
        fields.append("progress_message = ?")
        values.append(progress_message)
    if output is not None:
        fields.append("output = ?")
        values.append(json.dumps(output))
    if error is not None:
        fields.append("error = ?")
        values.append(error)
    if worker_id is not None:
        fields.append("worker_id = ?")
        values.append(worker_id)
    if status == JobStatus.RUNNING:
        fields.append("started_at = ?")
        values.append(now)
    if status in JobStatus.TERMINAL:
        fields.append("completed_at = ?")
        values.append(now)

    values.append(job_id)
    with transaction(conn):
        conn.execute(f"UPDATE jobs SET {', '.join(fields)} WHERE id = ?", values)
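update_job_status writes only the columns whose arguments are not None. It also sets started_at on entering RUNNING and completed_at on entering a terminal state.

The sketch below isolates that clause-building step as a pure function, so it can be checked without a database. build_update is a hypothetical helper, and the status strings and TERMINAL set are assumed to match the JobStatus constants listed earlier. The real function also JSON-encodes output, which the sketch omits.

```python
from typing import Any

# Assumed to mirror JobStatus.RUNNING and JobStatus.TERMINAL.
RUNNING = "RUNNING"
TERMINAL = frozenset({"COMPLETED", "FAILED", "CANCELLED"})


def build_update(job_id: str, status: str, now: str, **optional: Any) -> tuple[str, list[Any]]:
    fields = ["status = ?"]
    values: list[Any] = [status]
    # Only non-None optional fields become part of the SET clause.
    for column, value in optional.items():
        if value is not None:
            fields.append(f"{column} = ?")
            values.append(value)
    if status == RUNNING:
        fields.append("started_at = ?")
        values.append(now)
    if status in TERMINAL:
        fields.append("completed_at = ?")
        values.append(now)
    values.append(job_id)  # bound to the WHERE clause
    return f"UPDATE jobs SET {', '.join(fields)} WHERE id = ?", values


sql, params = build_update("job-1", "COMPLETED", "2024-01-01T00:00:00+00:00", progress=1.0, error=None)
print(sql)     # UPDATE jobs SET status = ?, progress = ?, completed_at = ? WHERE id = ?
print(params)  # ['COMPLETED', 1.0, '2024-01-01T00:00:00+00:00', 'job-1']
```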

reset_stale_running_jobs

reset_stale_running_jobs(conn: Connection) -> int

Mark any RUNNING jobs as FAILED on server startup (crash recovery).

Returns the number of jobs reset.

Source code in src/assgen/db.py
def reset_stale_running_jobs(conn: sqlite3.Connection) -> int:
    """Mark any RUNNING jobs as FAILED on server startup (crash recovery).

    Returns the number of jobs reset.
    """
    with transaction(conn):
        cur = conn.execute(
            "UPDATE jobs SET status = ?, error = ?, completed_at = ? WHERE status = ?",
            (JobStatus.FAILED, "Server restarted while job was running", _now_iso(), JobStatus.RUNNING),
        )
    return cur.rowcount

upsert_model

upsert_model(conn: Connection, model_id: str, **kwargs: Any) -> None

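Insert a new row into models if model_id is not already present; otherwise update the columns given as keyword arguments. The keyword names are used directly as SQL column names.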

Source code in src/assgen/db.py
def upsert_model(conn: sqlite3.Connection, model_id: str, **kwargs: Any) -> None:
    """Insert or update a model row."""
    existing = conn.execute(
        "SELECT model_id FROM models WHERE model_id = ?", (model_id,)
    ).fetchone()
    if existing:
        fields = ", ".join(f"{k} = ?" for k in kwargs)
        with transaction(conn):
            conn.execute(
                f"UPDATE models SET {fields} WHERE model_id = ?",
                (*kwargs.values(), model_id),
            )
    else:
        kwargs["model_id"] = model_id
        cols = ", ".join(kwargs.keys())
        placeholders = ", ".join("?" * len(kwargs))
        with transaction(conn):
            conn.execute(f"INSERT INTO models ({cols}) VALUES ({placeholders})", list(kwargs.values()))
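upsert_model runs a SELECT and then a separate UPDATE or INSERT. When called with no keyword arguments, its UPDATE branch would produce an empty SET clause and fail.

One alternative, sketched here rather than taken from assgen, is SQLite's INSERT ... ON CONFLICT DO UPDATE (available since SQLite 3.24). It performs the upsert in a single statement. The models columns and the ALLOWED_COLUMNS check are assumptions.

```python
import sqlite3
from typing import Any

# Assumed subset of the models table's columns.
ALLOWED_COLUMNS = {"local_path", "installed_at", "size_bytes", "last_used_at"}


def upsert_model(conn: sqlite3.Connection, model_id: str, **kwargs: Any) -> None:
    unknown = set(kwargs) - ALLOWED_COLUMNS
    if unknown:
        # Column names are interpolated into SQL, so reject anything unexpected.
        raise ValueError(f"unknown columns: {sorted(unknown)}")
    cols = ["model_id", *kwargs]
    placeholders = ", ".join("?" * len(cols))
    # excluded.<col> is the value the INSERT attempted to write.
    updates = ", ".join(f"{c} = excluded.{c}" for c in kwargs)
    conflict = f"DO UPDATE SET {updates}" if kwargs else "DO NOTHING"
    with conn:  # commits on success, rolls back on error
        conn.execute(
            f"INSERT INTO models ({', '.join(cols)}) VALUES ({placeholders}) "
            f"ON CONFLICT(model_id) {conflict}",
            (model_id, *kwargs.values()),
        )


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE models (model_id TEXT PRIMARY KEY, local_path TEXT, "
    "installed_at TEXT, size_bytes INTEGER, last_used_at TEXT)"
)
upsert_model(conn, "org/repo", local_path="/cache/org--repo", size_bytes=10)
upsert_model(conn, "org/repo", size_bytes=42)  # updates only size_bytes
upsert_model(conn, "org/repo")                 # no-op instead of a syntax error
print(conn.execute("SELECT local_path, size_bytes FROM models").fetchall())  # [('/cache/org--repo', 42)]
```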

record_model_usage

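record_model_usage(conn: Connection, model_id: str, job_id: str) -> None

Record that model_id was used by job_id: insert a model_usage row and set the model's last_used_at to the current time, in one transaction.
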
Source code in src/assgen/db.py
def record_model_usage(conn: sqlite3.Connection, model_id: str, job_id: str) -> None:
    with transaction(conn):
        conn.execute(
            "INSERT INTO model_usage (model_id, job_id, used_at) VALUES (?, ?, ?)",
            (model_id, job_id, _now_iso()),
        )
        conn.execute(
            "UPDATE models SET last_used_at = ? WHERE model_id = ?",
            (_now_iso(), model_id),
        )

Configuration

Platform-aware config directory resolution and YAML load/save helpers.

config

OS-agnostic configuration and path management for assgen.

The config directory is resolved via platformdirs. It follows the XDG Base Directory spec on Linux and the platform's native per-user locations on macOS and Windows.

Layout inside the config dir

client.yaml: client configuration (server URL, defaults)
server.yaml: server configuration (host, port, workers, device)
models.yaml: user model catalog overrides (merged with built-in catalog)
assgen.db: SQLite database (jobs, model usage, etc.)
server.pid: local server PID + URL (runtime, not committed)
outputs/: default output directory for generated assets

get_config_dir

get_config_dir() -> Path

Return (and create) the OS-appropriate config directory.

Source code in src/assgen/config.py
def get_config_dir() -> Path:
    """Return (and create) the OS-appropriate config directory."""
    path = Path(user_config_dir(APP_NAME, APP_AUTHOR))
    path.mkdir(parents=True, exist_ok=True)
    return path

get_db_path

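get_db_path() -> Path

Return the path of the SQLite database file, assgen.db inside the config directory. Only the config directory is created here (via get_config_dir()), not the database file.
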
Source code in src/assgen/config.py
def get_db_path() -> Path:
    return get_config_dir() / "assgen.db"

get_models_cache_dir

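get_models_cache_dir() -> Path

Return the models/ cache directory inside the user data directory returned by get_data_dir(), creating it if needed.
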
Source code in src/assgen/config.py
def get_models_cache_dir() -> Path:
    d = get_data_dir() / "models"
    d.mkdir(parents=True, exist_ok=True)
    return d

load_server_config

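load_server_config() -> dict[str, Any]

Return the server configuration, built up in layers where later layers win:

- the built-in defaults;
- server.yaml;
- the ASSGEN_HOST, ASSGEN_PORT (converted to int) and ASSGEN_DEVICE environment variables, when set.
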
Source code in src/assgen/config.py
def load_server_config() -> dict[str, Any]:
    path = get_config_dir() / "server.yaml"
    data = {**_SERVER_DEFAULTS, **_load_yaml(path)}
    # Allow env-var overrides
    if host := os.environ.get("ASSGEN_HOST"):
        data["host"] = host
    if port := os.environ.get("ASSGEN_PORT"):
        data["port"] = int(port)
    if device := os.environ.get("ASSGEN_DEVICE"):
        data["device"] = device
    return data
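The precedence of load_server_config can be reproduced with the YAML contents and the environment passed in explicitly, which makes it testable. In the sketch below, merge_server_config and the default values are hypothetical, since _SERVER_DEFAULTS is not listed on this page.

```python
from collections.abc import Mapping
from typing import Any

# Hypothetical defaults; the real values live in assgen.config._SERVER_DEFAULTS.
SERVER_DEFAULTS: dict[str, Any] = {"host": "127.0.0.1", "port": 8000, "device": "auto"}


def merge_server_config(file_data: Mapping[str, Any], env: Mapping[str, str]) -> dict[str, Any]:
    # Later sources win: defaults < server.yaml < environment variables.
    data = {**SERVER_DEFAULTS, **file_data}
    if host := env.get("ASSGEN_HOST"):
        data["host"] = host
    if port := env.get("ASSGEN_PORT"):
        data["port"] = int(port)  # environment values are always strings
    if device := env.get("ASSGEN_DEVICE"):
        data["device"] = device
    return data


cfg = merge_server_config({"port": 9000, "device": "cuda"}, {"ASSGEN_PORT": "9100"})
print(cfg)  # {'host': '127.0.0.1', 'port': 9100, 'device': 'cuda'}
```

Because each override is guarded by a truthiness test, a variable that is set but empty (for example ASSGEN_HOST="") is ignored rather than overriding the file value.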

save_server_config

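save_server_config(updates: dict[str, Any]) -> None

Merge updates over the current server configuration (defaults plus server.yaml) and write the result back to server.yaml. Environment-variable overrides are neither applied nor saved.
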
Source code in src/assgen/config.py
def save_server_config(updates: dict[str, Any]) -> None:
    path = get_config_dir() / "server.yaml"
    data = {**_SERVER_DEFAULTS, **_load_yaml(path), **updates}
    _save_yaml(path, data)

load_client_config

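load_client_config() -> dict[str, Any]

Return the client configuration: the built-in defaults overlaid with client.yaml. ASSGEN_SERVER_URL, when set, overrides server_url.
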
Source code in src/assgen/config.py
def load_client_config() -> dict[str, Any]:
    path = get_config_dir() / "client.yaml"
    data = {**_CLIENT_DEFAULTS, **_load_yaml(path)}
    # Allow env-var override
    if url := os.environ.get("ASSGEN_SERVER_URL"):
        data["server_url"] = url
    return data

save_client_config

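save_client_config(updates: dict[str, Any]) -> None

Merge updates over the current client configuration (defaults plus client.yaml) and write the result back to client.yaml.
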
Source code in src/assgen/config.py
def save_client_config(updates: dict[str, Any]) -> None:
    path = get_config_dir() / "client.yaml"
    data = {**_CLIENT_DEFAULTS, **_load_yaml(path), **updates}
    _save_yaml(path, data)

Server — Model Manager

Downloads, caches, and tracks HuggingFace models on the server side.

model_manager

HuggingFace model manager.

Responsible for:

- Resolving which model to use for a given job type (catalog lookup)
- Downloading models from the Hub into the local cache
- Tracking installed models in the SQLite database
- Reporting model status (configured / downloading / installed)

ModelManager

ModelManager(conn: Connection, device: str = 'auto', server_cfg: dict | None = None, db_path: str | None = None)

Manage HuggingFace model downloads, caching, and status tracking.

One ModelManager is instantiated per server process and shared across all worker threads via the server_cfg stored in app.state.

Attributes:

Name Type Description
conn

SQLite connection used to persist model metadata.

device

Resolved device string — "cuda", "mps", or "cpu".

Initialise the manager.

Parameters:

Name Type Description Default
conn Connection

An open SQLite connection (must have row_factory = sqlite3.Row).

required
device str

Preferred device — "auto" detects CUDA/MPS/CPU at runtime.

'auto'
server_cfg dict | None

The loaded server configuration dict (used for allow-list enforcement).

None
db_path str | None

Filesystem path to the SQLite database file. Required for thread-safe writes from the worker thread.

None
Source code in src/assgen/server/model_manager.py
def __init__(
    self,
    conn: sqlite3.Connection,
    device: str = "auto",
    server_cfg: dict | None = None,
    db_path: str | None = None,
) -> None:
    """Initialise the manager.

    Args:
        conn: An open SQLite connection (must have ``row_factory = sqlite3.Row``).
        device: Preferred device — ``"auto"`` detects CUDA/MPS/CPU at runtime.
        server_cfg: The loaded server configuration dict (used for allow-list
            enforcement).
        db_path: Filesystem path to the SQLite database file. Required for
            thread-safe writes from the worker thread.
    """
    self.conn = conn
    self._db_path: str | None = db_path
    self.device = detect_device(device)
    self._cache_dir = get_models_cache_dir()
    self._server_cfg: dict = server_cfg or {}
    logger.info("ModelManager initialised", extra={"device": self.device})

ensure_model

ensure_model(model_id: str, progress_cb: ProgressCallback | None = None) -> Path

Download model if not already cached; return the local cache path.

Parameters:

Name Type Description Default
model_id str

HuggingFace model identifier in org/repo format.

required
progress_cb ProgressCallback | None

Optional (fraction: float, message: str) -> None callback for surfacing download/check progress to the caller (e.g. to forward to the client via assgen.db.update_job_status).

None

Returns:

Type Description
Path

Local Path to the directory containing the downloaded model.

Raises:

Type Description
ValueError

If model_id is None or blocked by the allow-list.

Exception

Re-raised from huggingface_hub.snapshot_download on network or authentication errors.

Source code in src/assgen/server/model_manager.py
def ensure_model(
    self,
    model_id: str,
    progress_cb: ProgressCallback | None = None,
) -> Path:
    """Download model if not already cached; return the local cache path.

    Args:
        model_id: HuggingFace model identifier in ``org/repo`` format.
        progress_cb: Optional ``(fraction: float, message: str) -> None``
            callback for surfacing download/check progress to the caller
            (e.g. to forward to the client via :func:`assgen.db.update_job_status`).

    Returns:
        Local ``Path`` to the directory containing the downloaded model.

    Raises:
        ValueError: If ``model_id`` is ``None`` or blocked by the allow-list.
        Exception: Re-raised from ``huggingface_hub.snapshot_download`` on
            network or authentication errors.
    """
    def _cb(frac: float, msg: str) -> None:
        if progress_cb:
            progress_cb(frac, msg)

    if model_id is None:
        raise ValueError("model_id is None — job type may not require a model")

    # Enforce allow-list before any I/O
    from assgen.server.validation import check_allow_list
    check_allow_list(model_id, self._server_cfg)

    cache_path = self._cache_dir / _safe_name(model_id)

    if cache_path.exists() and any(cache_path.iterdir()):
        logger.info("Model already cached", extra={"model_id": model_id})
        _cb(0.15, f"Model {model_id} already cached ✓")
        return cache_path

    _cb(0.05, f"Downloading {model_id} from HuggingFace Hub…")
    logger.info(
        "Downloading model from HuggingFace Hub",
        extra={"model_id": model_id, "cache_dir": str(cache_path)},
    )
    try:
        from huggingface_hub import snapshot_download
        tqdm_cls = _make_hf_tqdm_class(_cb, start_frac=0.05, end_frac=0.18)
        dl_kwargs: dict = dict(
            repo_id=model_id,
            local_dir=str(cache_path),
            local_dir_use_symlinks=False,
            ignore_patterns=["*.msgpack", "*.h5", "flax_*"],
        )
        if tqdm_cls is not None:
            dl_kwargs["tqdm_class"] = tqdm_cls
        snapshot_download(**dl_kwargs)
    except Exception as exc:
        logger.error("Model download failed", extra={"model_id": model_id, "error": str(exc)})
        raise

    _cb(0.20, f"Model {model_id} downloaded successfully ✓")

    now = datetime.now(UTC).isoformat()
    size = _dir_size(cache_path)
    # Open a fresh connection for the write — ensure_model runs in the
    # worker thread, but self.conn was created in the startup thread.
    import sqlite3 as _s
    if self._db_path:
        _write_conn = _s.connect(self._db_path)
        _write_conn.row_factory = _s.Row
        try:
            upsert_model(
                _write_conn,
                model_id=model_id,
                local_path=str(cache_path),
                installed_at=now,
                size_bytes=size,
            )
            _write_conn.commit()
        finally:
            _write_conn.close()
    else:
        upsert_model(
            self.conn,
            model_id=model_id,
            local_path=str(cache_path),
            installed_at=now,
            size_bytes=size,
        )
    logger.info(
        "Model downloaded successfully",
        extra={"model_id": model_id, "size_bytes": size},
    )
    return cache_path
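ensure_model reports download progress inside a fixed band of the job's overall progress. The band runs from 0.05 to 0.18 via _make_hf_tqdm_class, and progress jumps to 0.20 once the download finishes.

That helper is not shown on this page. Assuming it maps the downloaded fraction linearly into the band, the mapping reduces to the small function below; scale_progress is hypothetical.

```python
def scale_progress(done: int, total: int, start_frac: float = 0.05, end_frac: float = 0.18) -> float:
    """Map done/total into [start_frac, end_frac] (assumed linear mapping)."""
    if total <= 0:
        return start_frac  # unknown size: stay at the start of the band
    ratio = min(max(done / total, 0.0), 1.0)  # clamp to [0, 1]
    return start_frac + (end_frac - start_frac) * ratio


for done in (0, 50, 100):
    print(round(scale_progress(done, 100), 3))
# 0.05
# 0.115
# 0.18
```

Clamping keeps a misreported byte count from pushing the bar past the 0.20 "downloaded" step that ensure_model reports next.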

ensure_for_job_type

ensure_for_job_type(job_type: str, progress_cb: ProgressCallback | None = None) -> tuple[str | None, Path | None]

Resolve the catalog model for job_type and ensure it is cached.

Parameters:

Name Type Description Default
job_type str

Dot-separated task identifier, e.g. "visual.model.create".

required
progress_cb ProgressCallback | None

Optional progress callback forwarded to ensure_model.

None

Returns:

Type Description
tuple[str | None, Path | None]

A (model_id, local_path) tuple. Both elements are None if the job type has no associated model (e.g. pure format-conversion).

Raises:

Type Description
ValueError

If job_type is not found in the catalog.

Source code in src/assgen/server/model_manager.py
def ensure_for_job_type(
    self,
    job_type: str,
    progress_cb: ProgressCallback | None = None,
) -> tuple[str | None, Path | None]:
    """Resolve the catalog model for *job_type* and ensure it is cached.

    Args:
        job_type: Dot-separated task identifier, e.g. ``"visual.model.create"``.
        progress_cb: Optional progress callback forwarded to :meth:`ensure_model`.

    Returns:
        A ``(model_id, local_path)`` tuple.  Both elements are ``None`` if
        the job type has no associated model (e.g. pure format-conversion).

    Raises:
        ValueError: If *job_type* is not found in the catalog.
    """
    entry = get_model_for_job(job_type)
    if not entry:
        raise ValueError(f"No catalog entry for job type: {job_type!r}")
    model_id = entry.get("model_id")
    if not model_id:
        return None, None  # e.g., format-conversion tasks
    path = self.ensure_model(model_id, progress_cb=progress_cb)
    return model_id, path

list_status

list_status() -> list[dict[str, Any]]

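Return one status dict per distinct model in the catalog. Each dict gives the install state, local path, timestamps, size on disk, and the job types that use the model. Job types with no model are grouped under "(none)".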

Source code in src/assgen/server/model_manager.py
def list_status(self) -> list[dict[str, Any]]:
    """Return status of every model in the catalog."""
    catalog = load_catalog()
    seen: dict[str, dict[str, Any]] = {}
    for job_type, entry in catalog.items():
        mid = entry.get("model_id") or "(none)"
        if mid not in seen:
            row = self.conn.execute(
                "SELECT * FROM models WHERE model_id = ?", (mid,)
            ).fetchone()
            installed = bool(row and row["local_path"] and Path(row["local_path"]).exists())
            seen[mid] = {
                "model_id": mid,
                "name": entry.get("name", mid),
                "installed": installed,
                "local_path": row["local_path"] if row else None,
                "installed_at": row["installed_at"] if row else None,
                "last_used_at": row["last_used_at"] if row else None,
                "size_bytes": row["size_bytes"] if row else None,
                "job_types": [],
            }
        seen[mid]["job_types"].append(job_type)
    return list(seen.values())
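list_status returns one row per distinct model rather than per job type. The first job type seen for a model creates its entry, and later job types are appended to its job_types list. Job types without a model are grouped under the placeholder "(none)".

The grouping step is sketched alone below, without the database lookup, on a hypothetical catalog:

```python
from typing import Any

# Hypothetical catalog entries.
catalog: dict[str, dict[str, Any]] = {
    "visual.model.create": {"model_id": "org/shape-model", "name": "Shape"},
    "visual.model.refine": {"model_id": "org/shape-model", "name": "Shape"},
    "pipeline.convert": {"model_id": None, "name": "Converter"},
}


def group_by_model(cat: dict[str, dict[str, Any]]) -> list[dict[str, Any]]:
    seen: dict[str, dict[str, Any]] = {}
    for job_type, entry in cat.items():
        mid = entry.get("model_id") or "(none)"
        if mid not in seen:
            # The first job type for this model decides the displayed name.
            seen[mid] = {"model_id": mid, "name": entry.get("name", mid), "job_types": []}
        seen[mid]["job_types"].append(job_type)
    return list(seen.values())


for row in group_by_model(catalog):
    print(row["model_id"], row["job_types"])
# org/shape-model ['visual.model.create', 'visual.model.refine']
# (none) ['pipeline.convert']
```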

install_all

install_all() -> None

Download every model referenced in the catalog. A failed download is logged and skipped, so the remaining models are still attempted.

Source code in src/assgen/server/model_manager.py
def install_all(self) -> None:
    """Download every model referenced in the catalog."""
    for mid in set(
        e["model_id"]
        for e in load_catalog().values()
        if e.get("model_id")
    ):
        try:
            self.ensure_model(mid)
        except Exception as exc:
            logger.error("Failed to install model", extra={"model_id": mid, "error": str(exc)})

detect_device

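detect_device(preference: str = 'auto') -> str

Return preference unchanged unless it is "auto". For "auto", return "cuda" if available, else "mps" if available, else "cpu"; "cpu" is also the result when torch is not installed.
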
Source code in src/assgen/server/model_manager.py
def detect_device(preference: str = "auto") -> str:
    if preference != "auto":
        return preference
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            return "mps"
    except ImportError:
        pass
    return "cpu"
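The precedence in detect_device (explicit preference, then CUDA, then MPS, then CPU) can be checked without a GPU by passing backend availability in as arguments. This is a minimal sketch; resolve_device is a hypothetical helper, and the boolean flags stand in for torch.cuda.is_available() and torch.backends.mps.is_available().

```python
def resolve_device(preference: str, cuda_available: bool, mps_available: bool) -> str:
    """Same precedence as detect_device, with backend availability passed in."""
    if preference != "auto":
        return preference  # explicit choices are returned as-is, not checked
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


print(resolve_device("auto", cuda_available=False, mps_available=True))
```

One consequence worth noting: an explicit preference such as "cuda" is returned even on a machine without a CUDA device, so a misconfigured preference surfaces later, when a model is loaded, rather than here.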

Server — Validation

Allow-list enforcement and model↔task compatibility checks.

validation

Model↔task compatibility validation and server allow-list enforcement.

When a job is submitted, the server:

  1. Checks that the model_id is on the allow_list (if one is configured).
  2. Fetches the model's pipeline_tag from the HuggingFace Hub API.
  3. Verifies that the tag is compatible with the catalog task for this job type.

Steps 2 and 3 can be bypassed by setting skip_model_validation: true in server.yaml (a server-admin opt-out). That flag does not affect step 1: the allow-list is always enforced unless it is empty (empty means "allow everything").

The HF Hub API is queried via a lightweight HTTP call, so the inference extras (torch, transformers, …) are not required just for validation.

TASK_COMPATIBLE_TAGS module-attribute

TASK_COMPATIBLE_TAGS: dict[str, frozenset[str]] = {
    'text-to-image': frozenset({'text-to-image'}),
    'image-to-image': frozenset({'image-to-image', 'text-to-image'}),
    'inpainting': frozenset({'image-to-image', 'text-to-image'}),
    'image-to-3d': frozenset({'image-to-3d', 'text-to-3d'}),
    'text-to-3d': frozenset({'text-to-3d', 'image-to-3d'}),
    'image-to-3dgs': frozenset({'image-to-3d', 'text-to-3d', 'image-to-3dgs'}),
    'mesh-retopology': frozenset({'image-to-3d', 'text-to-3d', 'mesh-retopology'}),
    'uv-unwrap': frozenset({'image-to-3d', 'text-to-3d', 'uv-unwrap'}),
    'texture-generation': frozenset({'text-to-image', 'image-to-image', 'texture-generation'}),
    'texture-bake': frozenset({'texture-bake', 'text-to-image', 'image-to-image'}),
    'auto-rig': frozenset({'auto-rig', 'image-to-3d', 'object-detection', 'image-classification'}),
    'skeleton-rig': frozenset({'skeleton-rig', 'image-to-3d'}),
    'motion-retarget': frozenset({'motion-retarget', 'skeleton-rig'}),
    'text-to-animation': frozenset({'text-to-motion', 'text-to-video', 'text-to-animation', 'animation-generate'}),
    'text-to-motion': frozenset({'text-to-motion', 'text-to-video'}),
    'video-to-motion': frozenset({'video-to-motion', 'video-classification'}),
    'video-to-pose': frozenset({'video-classification', 'image-classification', 'video-to-pose', 'pose-estimation'}),
    'animation-generate': frozenset({'text-to-motion', 'animation-generate'}),
    'text-to-audio': frozenset({'text-to-audio', 'audio-to-audio', 'text-to-speech'}),
    'text-to-music': frozenset({'text-to-audio', 'audio-generation', 'music-generation'}),
    'audio-generation': frozenset({'text-to-audio', 'audio-generation'}),
    'music-generation': frozenset({'text-to-audio', 'audio-generation', 'music-generation'}),
    'audio-to-audio': frozenset({'audio-to-audio', 'text-to-audio'}),
    'automatic-speech-recognition': frozenset({'automatic-speech-recognition'}),
    'text-to-speech': frozenset({'text-to-speech', 'text-to-audio'}),
    'voice-clone': frozenset({'text-to-speech', 'voice-conversion', 'voice-clone', 'text-to-audio'}),
    'text-to-video': frozenset({'text-to-video'}),
    'image-to-video': frozenset({'image-to-video', 'text-to-video'}),
    'text-to-panorama': frozenset({'text-to-image', 'text-to-panorama', 'text-to-3d', 'image-to-image'}),
    'collision-mesh': frozenset({'collision-mesh', 'image-to-3d', 'text-to-3d'}),
    'mesh-export': frozenset({'mesh-export'}),
    'keypoint-detection': frozenset({'keypoint-detection', 'image-to-image', 'image-classification'}),
    'text-generation': frozenset({'text-generation', 'text2text-generation'}),
    'translation': frozenset({'translation', 'text2text-generation'}),
    'question-answering': frozenset({'question-answering'}),
    'feature-extraction': frozenset({'feature-extraction'}),
    'depth-estimation': frozenset({'depth-estimation'}),
    'object-detection': frozenset({'object-detection'}),
    'image-segmentation': frozenset({'image-segmentation'}),
}

Maps each catalog task to the set of HF pipeline_tag values that validate_model_task_compatibility accepts for it.
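Validation reads this mapping in one direction (task to accepted tags). Inverting it answers the admin question "which catalog tasks could a model with this pipeline_tag serve?". A minimal sketch over a three-entry excerpt copied from TASK_COMPATIBLE_TAGS; tasks_accepting is a hypothetical helper, not part of assgen.

```python
# Three entries copied verbatim from TASK_COMPATIBLE_TAGS.
TAGS_EXCERPT: dict[str, frozenset[str]] = {
    "text-to-image": frozenset({"text-to-image"}),
    "image-to-image": frozenset({"image-to-image", "text-to-image"}),
    "depth-estimation": frozenset({"depth-estimation"}),
}


def tasks_accepting(pipeline_tag: str, table: dict[str, frozenset[str]]) -> list[str]:
    """Invert the table: which catalog tasks accept a model with this pipeline_tag?"""
    return sorted(task for task, tags in table.items() if pipeline_tag in tags)


print(tasks_accepting("text-to-image", TAGS_EXCERPT))
print(tasks_accepting("image-to-image", TAGS_EXCERPT))
```

The output shows that the relation is not symmetric: a text-to-image model is accepted for the image-to-image task, but an image-to-image model is not accepted for text-to-image.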

check_allow_list

check_allow_list(model_id: str, server_cfg: dict[str, Any]) -> None

Raise ValueError if model_id is not on the configured allow list.

An empty (or absent) allow_list means all models are permitted.

Source code in src/assgen/server/validation.py
def check_allow_list(model_id: str, server_cfg: dict[str, Any]) -> None:
    """Raise ``ValueError`` if *model_id* is not on the configured allow list.

    An empty (or absent) allow_list means *all* models are permitted.
    """
    allow_list: list[str] = server_cfg.get("allow_list") or []
    if not allow_list:
        return  # open policy — everything is allowed
    if model_id not in allow_list:
        raise ValueError(
            f"Model '{model_id}' is not on the server allow_list. "
            "Ask the server administrator to add it, or set allow_list: [] to allow all models."
        )
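check_allow_list uses exact list membership (no wildcards or prefixes), and both a missing key and an empty list mean "allow everything". A server admin might want to preview which catalog models a proposed allow_list would reject before deploying it. The following is a minimal sketch applying the same rule; blocked_models is a hypothetical helper and the example/... IDs are placeholders.

```python
from typing import Any


def blocked_models(model_ids: list[str], server_cfg: dict[str, Any]) -> list[str]:
    """Return the model IDs that check_allow_list would reject under server_cfg."""
    allow_list = server_cfg.get("allow_list") or []  # None, absent and [] all mean "open"
    if not allow_list:
        return []
    # Exact string comparison, matching `model_id not in allow_list` in check_allow_list.
    return [mid for mid in model_ids if mid not in allow_list]


ids = ["stabilityai/TripoSR", "example/other-model"]
print(blocked_models(ids, {"allow_list": ["stabilityai/TripoSR"]}))
print(blocked_models(ids, {"allow_list": None}))
```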

fetch_hf_pipeline_tag

fetch_hf_pipeline_tag(model_id: str) -> str | None

Return the HF pipeline_tag for model_id, or None on failure.

Uses a lightweight HTTP call to the HF Hub REST API — no heavy ML dependencies required.

Source code in src/assgen/server/validation.py
def fetch_hf_pipeline_tag(model_id: str) -> str | None:
    """Return the HF ``pipeline_tag`` for *model_id*, or ``None`` on failure.

    Uses a lightweight HTTP call to the HF Hub REST API — no heavy
    ML dependencies required.
    """
    try:
        import httpx
        url = f"https://huggingface.co/api/models/{model_id}?fields=pipeline_tag"
        resp = httpx.get(url, timeout=10.0, follow_redirects=True)
        if resp.status_code == 200:
            tag = resp.json().get("pipeline_tag")
            logger.debug("HF pipeline_tag for %s: %s", model_id, tag)
            return tag
        logger.warning(
            "HF Hub API returned %d for model %s — skipping tag validation",
            resp.status_code, model_id,
        )
    except Exception as exc:
        logger.warning("Could not fetch HF pipeline_tag for %s: %s", model_id, exc)
    return None

validate_model_task_compatibility

validate_model_task_compatibility(model_id: str, catalog_task: str | None, server_cfg: dict[str, Any]) -> tuple[bool, str]

Check whether model_id is suitable for catalog_task.

Returns (ok, reason) where reason is a human-readable explanation when ok is False.

Validation is skipped (the function returns (True, reason), where reason explains why) when:

  • server_cfg["skip_model_validation"] is truthy
  • catalog_task is None (the job type has no associated HF task)
  • the task is not in TASK_COMPATIBLE_TAGS (unknown/custom task)
  • the HF Hub API returns no pipeline_tag for the model

Source code in src/assgen/server/validation.py
def validate_model_task_compatibility(
    model_id: str,
    catalog_task: str | None,
    server_cfg: dict[str, Any],
) -> tuple[bool, str]:
    """Check whether *model_id* is suitable for *catalog_task*.

    Returns ``(ok, reason)`` where *reason* is a human-readable explanation
    when ``ok`` is ``False``.

    Validation is skipped (returns ``(True, "skipped")`` ) when:
    - ``server_cfg["skip_model_validation"]`` is truthy
    - ``catalog_task`` is ``None`` (job type has no associated HF task)
    - The task is not in ``TASK_COMPATIBLE_TAGS`` (unknown/custom task)
    - The HF Hub API returns no pipeline_tag for the model
    """
    if server_cfg.get("skip_model_validation"):
        return True, "validation skipped (server config)"

    if not catalog_task:
        return True, "no task constraint for this job type"

    compatible = TASK_COMPATIBLE_TAGS.get(catalog_task)
    if compatible is None:
        logger.debug(
            "Task '%s' not in TASK_COMPATIBLE_TAGS — skipping validation", catalog_task
        )
        return True, f"task '{catalog_task}' has no compatibility rules defined"

    pipeline_tag = fetch_hf_pipeline_tag(model_id)
    if pipeline_tag is None:
        logger.warning(
            "Could not determine pipeline_tag for %s — allowing by default", model_id
        )
        return True, "could not fetch pipeline_tag from HF Hub — allowed by default"

    if pipeline_tag in compatible:
        return True, f"compatible ({pipeline_tag} in {sorted(compatible)})"

    return False, (
        f"Model '{model_id}' has pipeline_tag='{pipeline_tag}' which is not "
        f"compatible with task '{catalog_task}'. "
        f"Expected one of: {sorted(compatible)}. "
        "Set skip_model_validation: true in server.yaml to override."
    )

validate_job_model

validate_job_model(model_id: str, catalog_task: str | None, server_cfg: dict[str, Any]) -> None

Run all validations for a given model_id + catalog_task pair.

Raises ValueError with a descriptive message on failure.

Source code in src/assgen/server/validation.py
def validate_job_model(
    model_id: str,
    catalog_task: str | None,
    server_cfg: dict[str, Any],
) -> None:
    """Run all validations for a given *model_id* + *catalog_task* pair.

    Raises ``ValueError`` with a descriptive message on failure.
    """
    check_allow_list(model_id, server_cfg)
    ok, reason = validate_model_task_compatibility(model_id, catalog_task, server_cfg)
    if not ok:
        raise ValueError(reason)
    logger.debug("Model validation passed for %s: %s", model_id, reason)

Version

version

Version introspection for assgen.

Canonical approach
  1. importlib.metadata.version("assgen") is the installed version — the single source of truth. hatch-vcs writes this from the git tag at pip install / pip install -e . time.
  2. git describe --tags --long --dirty surfaces the current source state so you can tell whether the working tree has changed since the install.
  3. The --version / -V flag on both CLIs combines these two pieces so you always know exactly what code is running.

Version string examples

  • Production install from a tagged release wheel:

    assgen 0.1.0

  • Editable install from a clean dev checkout (16 commits after v0.0.1):

    assgen 0.0.2.dev16+gc9ee176
      source  v0.0.1-16-gc9ee176 (clean)
      python  3.12.2

  • Editable install with uncommitted changes in the working tree:

    assgen 0.0.2.dev16+gc9ee176
      source  v0.0.1-16-gc9ee176-dirty  ⚠  uncommitted changes
      python  3.12.2
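A --long describe string always has the shape TAG-DISTANCE-gHASH, optionally followed by -dirty. In the examples above, the distance (16) and abbreviated hash (c9ee176) reappear in the installed version as .dev16 and +gc9ee176. A minimal parsing sketch follows; parse_describe is a hypothetical helper, and get_version_info itself only checks the -dirty suffix.

```python
import re

# tag, commits since the tag, abbreviated hash (after the literal "g"), optional "-dirty"
_DESCRIBE_RE = re.compile(
    r"^(?P<tag>.+)-(?P<distance>\d+)-g(?P<sha>[0-9a-f]+)(?P<dirty>-dirty)?$"
)


def parse_describe(desc: str) -> dict[str, object]:
    """Split `git describe --tags --long --dirty` output into its parts."""
    m = _DESCRIBE_RE.match(desc)
    if m is None:
        raise ValueError(f"not a --long describe string: {desc!r}")
    return {
        "tag": m["tag"],
        "distance": int(m["distance"]),
        "sha": m["sha"],
        "dirty": m["dirty"] is not None,
    }


print(parse_describe("v0.0.1-16-gc9ee176-dirty"))
```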

get_version_info cached

get_version_info() -> dict[str, str | None]

Return a dict with version, git_describe, dirty, and python fields.

  • version: installed package version from importlib.metadata.
  • git_describe: output of git describe --tags --long --dirty, or None when git is unavailable (e.g. running from a wheel install).
  • dirty: True if the git working tree has uncommitted changes.
  • python: Python version string.
Source code in src/assgen/version.py
@lru_cache(maxsize=1)
def get_version_info() -> dict[str, str | None]:
    """Return a dict with ``version``, ``git_describe``, ``dirty``, and ``python`` fields.

    * ``version``      — installed package version from :mod:`importlib.metadata`.
    * ``git_describe`` — output of ``git describe --tags --long --dirty``, or *None*
                         when git is unavailable (e.g. running from a wheel install).
    * ``dirty``        — ``True`` if the git working tree has uncommitted changes.
    * ``python``       — Python version string.
    """
    version = _installed_version()
    git_desc = _git_describe()
    dirty = git_desc.endswith("-dirty") if git_desc else False
    return {
        "version": version,
        "git_describe": git_desc,
        "dirty": dirty,
        "python": sys.version.split()[0],
    }

format_version_string

format_version_string(name: str = 'assgen') -> str

Return a human-readable version string for --version output.

Source code in src/assgen/version.py
def format_version_string(name: str = "assgen") -> str:
    """Return a human-readable version string for ``--version`` output."""
    info = get_version_info()
    ver = info["version"] or "0.0.0.dev"
    lines = [f"{name} {ver}"]

    git_desc = info.get("git_describe")
    if git_desc:
        dirty_note = "  ⚠  uncommitted changes" if info["dirty"] else " (clean)"
        lines.append(f"  source  {git_desc}{dirty_note}")

    lines.append(f"  python  {info['python']}")
    return "\n".join(lines)
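format_version_string reads from the cached get_version_info, which makes its layout awkward to test directly. Passing the info dict in as an argument gives the same output deterministically. The following is a minimal sketch; render_version is a hypothetical helper whose body mirrors the function above.

```python
def render_version(info: dict[str, object], name: str = "assgen") -> str:
    """Same layout as format_version_string, but takes the info dict as an argument."""
    lines = [f"{name} {info.get('version') or '0.0.0.dev'}"]
    git_desc = info.get("git_describe")
    if git_desc:  # omitted entirely for wheel installs, where git_describe is None
        dirty_note = "  ⚠  uncommitted changes" if info.get("dirty") else " (clean)"
        lines.append(f"  source  {git_desc}{dirty_note}")
    lines.append(f"  python  {info['python']}")  # always present
    return "\n".join(lines)


print(render_version({
    "version": "0.0.2.dev16+gc9ee176",
    "git_describe": "v0.0.1-16-gc9ee176",
    "dirty": False,
    "python": "3.12.2",
}))
```

As the code shows, the python line is emitted unconditionally, so even a wheel install prints two lines: the package version and the Python version.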