# LiteLLM Model

Documentation for the `LitellmModel` class, the LiteLLM-backed model wrapper.

Full source code
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("litellm_model")


@dataclass
class LitellmModelConfig:
    """Configuration for `LitellmModel`: the model name plus optional extras."""

    # Model identifier as understood by litellm, e.g. "gpt-4o" or "anthropic/claude-...".
    model_name: str
    # Extra keyword arguments merged into every litellm.completion() call.
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    # Optional path to a JSON file with custom model pricing/metadata, registered
    # with litellm at LitellmModel construction time.
    # NOTE: the default is read from the environment at *import* time, so changing
    # LITELLM_MODEL_REGISTRY_PATH after this module is imported has no effect.
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""


class LitellmModel:
    """Model wrapper around ``litellm.completion`` with retries and cost tracking.

    Tracks the cumulative API cost and number of successful calls on the
    instance, and reports every call's cost to ``GLOBAL_MODEL_STATS``.
    """

    def __init__(self, *, config_class: type = LitellmModelConfig, **kwargs):
        """Build the configuration and optionally register a custom model registry.

        Args:
            config_class: Dataclass type used to hold the configuration.
            **kwargs: Forwarded verbatim to ``config_class``.
        """
        self.config = config_class(**kwargs)
        self.cost = 0.0  # cumulative cost (as computed by litellm) over this instance's calls
        self.n_calls = 0  # number of successfully costed queries
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            # Register custom pricing/metadata so litellm can compute costs for
            # models it does not know about out of the box.
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                # Non-transient failures: retrying cannot help, so fail fast.
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Single completion attempt; transient errors are retried by tenacity.

        Per-call ``kwargs`` override ``config.model_kwargs`` on key collisions.
        """
        try:
            return litellm.completion(
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
            )
        except litellm.exceptions.AuthenticationError as e:
            # Make the failure actionable before it propagates.
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise  # bare raise re-raises the active exception with its traceback intact

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model, compute the call's cost, and update statistics.

        Args:
            messages: Chat messages in OpenAI format; cache-control markers are
                added first if ``config.set_cache_control`` is set.
            **kwargs: Extra arguments forwarded to ``litellm.completion``.

        Returns:
            Dict with ``"content"`` (assistant message text, ``""`` if the model
            returned ``None``) and ``"extra"`` (the full serialized response).

        Raises:
            RuntimeError: If litellm reports a negative cost.
            Exception: Re-raises whatever the cost calculator raised, after
                logging a pointer to the model-registry documentation.
        """
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)
        try:
            cost = litellm.cost_calculator.completion_cost(response)
        except Exception as e:
            logger.critical(
                f"Error calculating cost for model {self.config.model_name}: {e}. "
                "Please check the 'Updating the model registry' section in the documentation at "
                "https://klieret.short.gy/litellm-model-registry Still stuck? Please open a github issue for help!"
            )
            raise
        if cost < 0.0:
            # Explicit check instead of `assert`, which would be stripped under `python -O`.
            raise RuntimeError(f"Cost is negative: {cost}")
        # Only count the call once the cost has been validated, so failed calls
        # do not inflate the statistics.
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",  # type: ignore
            "extra": {
                "response": response.model_dump(),
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        """Return config fields plus call/cost counters, for prompt templating."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}

Guides

  • Setting up most models is covered in the quickstart guide.
  • If you want to use local models, please check the local models guide.

minisweagent.models.litellm_model

logger module-attribute

logger = getLogger('litellm_model')

LitellmModelConfig dataclass

LitellmModelConfig(
    model_name: str,
    model_kwargs: dict[str, Any] = dict(),
    litellm_model_registry: Path | str | None = getenv(
        "LITELLM_MODEL_REGISTRY_PATH"
    ),
    set_cache_control: Literal["default_end"] | None = None,
)

model_name instance-attribute

model_name: str

model_kwargs class-attribute instance-attribute

model_kwargs: dict[str, Any] = field(default_factory=dict)

litellm_model_registry class-attribute instance-attribute

litellm_model_registry: Path | str | None = getenv(
    "LITELLM_MODEL_REGISTRY_PATH"
)

set_cache_control class-attribute instance-attribute

set_cache_control: Literal['default_end'] | None = None

Set explicit cache control markers, for example for Anthropic models

LitellmModel

LitellmModel(
    *, config_class: type = LitellmModelConfig, **kwargs
)
Source code in src/minisweagent/models/litellm_model.py
33
34
35
36
37
38
def __init__(self, *, config_class: type = LitellmModelConfig, **kwargs):
    self.config = config_class(**kwargs)
    self.cost = 0.0
    self.n_calls = 0
    if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
        litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

config instance-attribute

config = config_class(**kwargs)

cost instance-attribute

cost = 0.0

n_calls instance-attribute

n_calls = 0

query

query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/litellm_model.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
    if self.config.set_cache_control:
        messages = set_cache_control(messages, mode=self.config.set_cache_control)
    response = self._query(messages, **kwargs)
    try:
        cost = litellm.cost_calculator.completion_cost(response)
    except Exception as e:
        logger.critical(
            f"Error calculating cost for model {self.config.model_name}: {e}. "
            "Please check the 'Updating the model registry' section in the documentation at "
            "https://klieret.short.gy/litellm-model-registry Still stuck? Please open a github issue for help!"
        )
        raise
    self.n_calls += 1
    assert cost >= 0.0, f"Cost is negative: {cost}"
    self.cost += cost
    GLOBAL_MODEL_STATS.add(cost)
    return {
        "content": response.choices[0].message.content or "",  # type: ignore
        "extra": {
            "response": response.model_dump(),
        },
    }

get_template_vars

get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/litellm_model.py
89
90
def get_template_vars(self) -> dict[str, Any]:
    return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}