LiteLLM Model
LiteLLM Model class
Full source code
import json
import logging
import os
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal
import litellm
from tenacity import (
before_sleep_log,
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential,
)
from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control
# Module logger; the fixed name "litellm_model" (rather than __name__) is what log
# configuration must target to control this module's retry and cost-tracking messages.
logger: logging.Logger = logging.getLogger(name="litellm_model")
@dataclass
class LitellmModelConfig:
model_name: str
model_kwargs: dict[str, Any] = field(default_factory=dict)
litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
set_cache_control: Literal["default_end"] | None = None
"""Set explicit cache control markers, for example for Anthropic models"""
cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
"""Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
class LitellmModel:
    """Chat model that sends completions through LiteLLM and keeps track of their cost.

    Every successful ``query`` increments ``n_calls``, adds the computed cost to
    ``cost`` and reports it to ``GLOBAL_MODEL_STATS``.
    """

    def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
        """Build the configuration and optionally register custom models with litellm.

        Args:
            config_class: Factory called with ``**kwargs`` to build ``self.config``.
            **kwargs: Configuration values forwarded to ``config_class``.
        """
        self.config = config_class(**kwargs)
        self.cost = 0.0  # running total of costs of successful queries
        self.n_calls = 0  # number of successful queries
        registry = self.config.litellm_model_registry
        if registry:
            registry_path = Path(registry)
            if registry_path.is_file():
                # The registry file's parsed JSON content is handed to litellm as-is.
                litellm.utils.register_model(json.loads(registry_path.read_text()))
            else:
                # Warn instead of skipping silently: otherwise the misconfiguration may only
                # surface later as a cost-calculation error for an unregistered model.
                logger.warning(
                    "LiteLLM model registry %s does not exist or is not a file; it will be ignored.",
                    registry_path,
                )

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(
            (
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Perform a single ``litellm.completion`` call, retried by the decorator.

        Exceptions of the types listed in the decorator propagate immediately.
        Any other exception is retried with exponential backoff (waits clamped to
        4-60 s), up to MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT attempts (default 10).
        That environment variable is read once, when the class is defined.
        ``reraise`` is not set, so exhausting the attempts raises tenacity's
        ``RetryError``.
        """
        try:
            # Per-call kwargs override config.model_kwargs (right operand of | wins).
            return litellm.completion(
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
            )
        except litellm.exceptions.AuthenticationError as e:
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Send ``messages`` to the model, account for the cost and return the reply.

        Args:
            messages: Chat messages. Only the ``role`` and ``content`` keys of each
                message are forwarded to the API.
            **kwargs: Extra completion arguments; they override ``config.model_kwargs``.

        Returns:
            ``{"content": <reply text, "" if empty>, "extra": {"response": response.model_dump()}}``

        Raises:
            RuntimeError: If the cost cannot be computed (or is not positive) and
                ``config.cost_tracking`` is not ``"ignore_errors"``.
        """
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
        try:
            cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
            if cost <= 0.0:
                # A zero/negative cost is treated like missing pricing information.
                raise ValueError(f"Cost must be > 0.0, got {cost}")
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name}: {e}, perhaps it's not registered? "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    " Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        # Statistics are only updated once the cost has been determined (or ignored).
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",  # type: ignore
            "extra": {
                "response": response.model_dump(),
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        """Return the config fields plus ``n_model_calls`` and ``model_cost`` as template variables."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
Guides
- Setting up most models is covered in the quickstart guide.
- If you want to use local models, please check the local models guide (https://klieret.short.gy/mini-local-models).
minisweagent.models.litellm_model
logger
module-attribute
logger = getLogger('litellm_model')
LitellmModelConfig
dataclass
LitellmModelConfig(
model_name: str,
model_kwargs: dict[str, Any] = dict(),
litellm_model_registry: Path | str | None = getenv(
"LITELLM_MODEL_REGISTRY_PATH"
),
set_cache_control: Literal["default_end"] | None = None,
cost_tracking: Literal[
"default", "ignore_errors"
] = getenv("MSWEA_COST_TRACKING", "default"),
)
model_name
instance-attribute
model_name: str
model_kwargs
class-attribute
instance-attribute
model_kwargs: dict[str, Any] = field(default_factory=dict)
litellm_model_registry
class-attribute
instance-attribute
litellm_model_registry: Path | str | None = getenv(
"LITELLM_MODEL_REGISTRY_PATH"
)
set_cache_control
class-attribute
instance-attribute
set_cache_control: Literal['default_end'] | None = None
Set explicit cache control markers, for example for Anthropic models
cost_tracking
class-attribute
instance-attribute
cost_tracking: Literal["default", "ignore_errors"] = getenv(
"MSWEA_COST_TRACKING", "default"
)
Cost tracking mode for this model: "default" raises an error when the cost cannot be determined; "ignore_errors" ignores errors/missing cost information and records a cost of 0.
LitellmModel
LitellmModel(
*, config_class: Callable = LitellmModelConfig, **kwargs
)
Source code in src/minisweagent/models/litellm_model.py (lines 36–41)
config
instance-attribute
config = config_class(**kwargs)
cost
instance-attribute
cost = 0.0
n_calls
instance-attribute
n_calls = 0
query
query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/litellm_model.py (lines 68–97)
get_template_vars
get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/litellm_model.py (lines 99–100)