Skip to content

LiteLLM Model

LiteLLM Model class

Full source code
import json
import logging
import os
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("litellm_model")


@dataclass
class LitellmModelConfig:
    model_name: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""


class LitellmModel:
    """Chat model that sends requests through ``litellm.completion``.

    Tracks the number of completed calls and their accumulated cost, both per
    instance (``n_calls``, ``cost``) and in ``GLOBAL_MODEL_STATS``.
    """

    def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
        """Build the config from ``kwargs`` and register custom models if configured.

        Args:
            config_class: Callable that turns ``kwargs`` into the config object.
            **kwargs: Forwarded unchanged to ``config_class``.
        """
        self.config = config_class(**kwargs)
        self.cost = 0.0
        self.n_calls = 0
        if self.config.litellm_model_registry:
            registry_path = Path(self.config.litellm_model_registry)
            if registry_path.is_file():
                # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
                litellm.utils.register_model(json.loads(registry_path.read_text(encoding="utf-8")))

    # Every exception is retried except the ones listed below, for which a
    # retry cannot help (bad parameters, unknown model, missing permissions,
    # context overflow, authentication, user interrupt). The attempt limit is
    # read from MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT when the class is defined.
    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Issue one ``litellm.completion`` request (retried by the decorator).

        Args:
            messages: Messages to send to the model.
            **kwargs: Per-call completion arguments; they override
                ``config.model_kwargs`` on key collisions.

        Returns:
            The response object returned by ``litellm.completion``.

        Raises:
            litellm.exceptions.AuthenticationError: Re-raised with a hint on
                how to store the API key appended to its message.
        """
        try:
            return litellm.completion(
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
            )
        except litellm.exceptions.AuthenticationError as e:
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            # Bare raise re-raises the same exception without adding another traceback entry.
            raise

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model and account for the cost of the call.

        Args:
            messages: Conversation so far; only the ``role`` and ``content``
                of each message are sent to the model.
            **kwargs: Per-call completion arguments forwarded to ``_query``.

        Returns:
            A dict with the reply text under ``"content"`` (empty string if the
            response has none) and the dumped response under
            ``"extra"["response"]``.

        Raises:
            RuntimeError: If the cost cannot be computed or is not positive and
                ``config.cost_tracking`` is not ``"ignore_errors"``.
        """
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
        try:
            cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
            if cost <= 0.0:
                # Routed through the handler below so a non-positive cost is
                # treated like any other cost calculation failure.
                raise ValueError(f"Cost must be > 0.0, got {cost}")
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name}: {e}, perhaps it's not registered? "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    " Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        # Counters are only updated once the cost has been settled.
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",  # type: ignore
            "extra": {
                "response": response.model_dump(),
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        """Return the config fields plus ``n_model_calls`` and ``model_cost``."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}

Guides

  • Setting up most models is covered in the quickstart guide.
  • If you want to use local models, please check this guide.

minisweagent.models.litellm_model

logger module-attribute

logger = getLogger('litellm_model')

LitellmModelConfig dataclass

LitellmModelConfig(
    model_name: str,
    model_kwargs: dict[str, Any] = dict(),
    litellm_model_registry: Path | str | None = getenv(
        "LITELLM_MODEL_REGISTRY_PATH"
    ),
    set_cache_control: Literal["default_end"] | None = None,
    cost_tracking: Literal[
        "default", "ignore_errors"
    ] = getenv("MSWEA_COST_TRACKING", "default"),
)

model_name instance-attribute

model_name: str

model_kwargs class-attribute instance-attribute

model_kwargs: dict[str, Any] = field(default_factory=dict)

litellm_model_registry class-attribute instance-attribute

litellm_model_registry: Path | str | None = getenv(
    "LITELLM_MODEL_REGISTRY_PATH"
)

set_cache_control class-attribute instance-attribute

set_cache_control: Literal['default_end'] | None = None

Set explicit cache control markers, for example for Anthropic models

cost_tracking class-attribute instance-attribute

cost_tracking: Literal["default", "ignore_errors"] = getenv(
    "MSWEA_COST_TRACKING", "default"
)

Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)

LitellmModel

LitellmModel(
    *, config_class: Callable = LitellmModelConfig, **kwargs
)
Source code in src/minisweagent/models/litellm_model.py
36
37
38
39
40
41
def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
    """Build the config from ``kwargs`` and register custom models if configured.

    Args:
        config_class: Callable that turns ``kwargs`` into the config object.
        **kwargs: Forwarded unchanged to ``config_class``.
    """
    self.config = config_class(**kwargs)
    self.cost = 0.0
    self.n_calls = 0
    if self.config.litellm_model_registry:
        registry_path = Path(self.config.litellm_model_registry)
        if registry_path.is_file():
            # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
            litellm.utils.register_model(json.loads(registry_path.read_text(encoding="utf-8")))

config instance-attribute

config = config_class(**kwargs)

cost instance-attribute

cost = 0.0

n_calls instance-attribute

n_calls = 0

query

query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/litellm_model.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
    """Send ``messages`` to the model, record the cost and return the reply.

    Only the ``role`` and ``content`` of each message are forwarded. The
    result holds the reply text under ``"content"`` (empty string if the
    response has none) and the dumped response under ``"extra"["response"]``.

    Raises:
        RuntimeError: If the cost cannot be computed or is not positive and
            ``config.cost_tracking`` is not ``"ignore_errors"``.
    """
    if self.config.set_cache_control:
        messages = set_cache_control(messages, mode=self.config.set_cache_control)

    # Reduce every message to the two keys that are sent to the model.
    payload = []
    for entry in messages:
        payload.append({"role": entry["role"], "content": entry["content"]})
    response = self._query(payload, **kwargs)

    try:
        cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
        if cost <= 0.0:
            raise ValueError(f"Cost must be > 0.0, got {cost}")
    except Exception as err:
        if self.config.cost_tracking != "ignore_errors":
            failure = (
                f"Error calculating cost for model {self.config.model_name}: {err}, perhaps it's not registered? "
                "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
                "Alternatively check the 'Cost tracking' section in the documentation at "
                "https://klieret.short.gy/mini-local-models. "
                " Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
            )
            logger.critical(failure)
            raise RuntimeError(failure) from err
        # Errors are ignored: book the call at zero cost.
        cost = 0.0

    self.n_calls += 1
    self.cost += cost
    GLOBAL_MODEL_STATS.add(cost)

    reply_text = response.choices[0].message.content or ""  # type: ignore
    extra = {"response": response.model_dump()}
    return {"content": reply_text, "extra": extra}

get_template_vars

get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/litellm_model.py
 99
100
def get_template_vars(self) -> dict[str, Any]:
    """Return the config fields plus ``n_model_calls`` and ``model_cost``."""
    template_vars = asdict(self.config)
    template_vars.update(n_model_calls=self.n_calls, model_cost=self.cost)
    return template_vars