LiteLLM Model
LiteLLM Model class
Full source code
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal
import litellm
from tenacity import (
before_sleep_log,
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential,
)
from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control
# Module-level logger; also passed to the retry decorator's `before_sleep_log`
# below so retry warnings are emitted under this logger name.
logger = logging.getLogger("litellm_model")
@dataclass
class LitellmModelConfig:
    """Configuration for `LitellmModel`."""

    # Model identifier passed verbatim as `model=` to `litellm.completion`.
    model_name: str
    # Extra keyword arguments merged into every `litellm.completion` call
    # (per-call kwargs override these).
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    # Optional path to a JSON file whose contents are passed to
    # `litellm.utils.register_model` (custom model metadata for cost calculation).
    # NOTE: the default is read from the environment once, at import time.
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
class LitellmModel:
    """Wrapper around `litellm.completion` that retries transient failures and
    tracks the number of calls and the accumulated cost."""

    def __init__(self, *, config_class: type = LitellmModelConfig, **kwargs):
        """Create the model; all **kwargs are forwarded to `config_class`.

        If the configured registry path points to an existing file, its JSON
        contents are registered with litellm (used later for cost calculation).
        """
        self.config = config_class(**kwargs)
        self.cost = 0.0  # accumulated cost over all queries (USD, per litellm)
        self.n_calls = 0  # number of successfully costed queries
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

    @retry(
        # Exponential backoff; max attempts configurable via env var
        # (read once at import time, default 10).
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        # Do NOT retry errors that will not resolve by themselves (bad params,
        # missing model, permissions, context overflow, auth) or a user interrupt.
        retry=retry_if_not_exception_type(
            (
                litellm.exceptions.UnsupportedParamsError,
                litellm.exceptions.NotFoundError,
                litellm.exceptions.PermissionDeniedError,
                litellm.exceptions.ContextWindowExceededError,
                litellm.exceptions.APIError,
                litellm.exceptions.AuthenticationError,
                KeyboardInterrupt,
            )
        ),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Single completion attempt; per-call kwargs override `config.model_kwargs`."""
        try:
            return litellm.completion(
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
            )
        except litellm.exceptions.AuthenticationError as e:
            # Append a setup hint before re-raising; auth errors are excluded
            # from the retry policy above, so this runs at most once per query.
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise e

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Send `messages` to the model and return the response.

        Returns a dict with the message text under "content" and the full
        dumped litellm response under "extra" -> "response". Updates
        `self.cost`/`self.n_calls` and the global model stats. Re-raises if
        litellm cannot compute a cost for this model (e.g. unknown model name).
        """
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)
        try:
            cost = litellm.cost_calculator.completion_cost(response)
        except Exception as e:
            # Cost lookup failure usually means the model is missing from the
            # litellm registry; point the user at the fix before re-raising.
            logger.critical(
                f"Error calculating cost for model {self.config.model_name}: {e}. "
                "Please check the 'Updating the model registry' section in the documentation at "
                "https://klieret.short.gy/litellm-model-registry Still stuck? Please open a github issue for help!"
            )
            raise
        self.n_calls += 1
        # Sanity check on litellm's cost calculation (note: stripped under -O).
        assert cost >= 0.0, f"Cost is negative: {cost}"
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",  # type: ignore
            "extra": {
                "response": response.model_dump(),
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        """Return all config fields plus `n_model_calls`/`model_cost`, e.g. for prompt templating."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
Guides
- Setting up most models is covered in the quickstart guide.
- If you want to use local models, please check the documentation's guide on running local models.
minisweagent.models.litellm_model
logger
module-attribute
logger = getLogger('litellm_model')
LitellmModelConfig
dataclass
LitellmModelConfig(
model_name: str,
model_kwargs: dict[str, Any] = dict(),
litellm_model_registry: Path | str | None = getenv(
"LITELLM_MODEL_REGISTRY_PATH"
),
set_cache_control: Literal["default_end"] | None = None,
)
model_name
instance-attribute
model_name: str
model_kwargs
class-attribute
instance-attribute
model_kwargs: dict[str, Any] = field(default_factory=dict)
litellm_model_registry
class-attribute
instance-attribute
litellm_model_registry: Path | str | None = getenv(
"LITELLM_MODEL_REGISTRY_PATH"
)
set_cache_control
class-attribute
instance-attribute
set_cache_control: Literal['default_end'] | None = None
Set explicit cache control markers, for example for Anthropic models
LitellmModel
LitellmModel(
*, config_class: type = LitellmModelConfig, **kwargs
)
Source code in src/minisweagent/models/litellm_model.py
(source lines 33–38; the collapsed code listing was lost in extraction)
config
instance-attribute
config = config_class(**kwargs)
cost
instance-attribute
cost = 0.0
n_calls
instance-attribute
n_calls = 0
query
query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/litellm_model.py
(source lines 65–87; the collapsed code listing was lost in extraction)
get_template_vars
get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/litellm_model.py
(source lines 89–90; the collapsed code listing was lost in extraction)