Portkey Model
Portkey Model class
Full source code
```python
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal

import litellm
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.cache_control import set_cache_control

logger = logging.getLogger("portkey_model")

try:
    from portkey_ai import Portkey
except ImportError:
    Portkey = None


@dataclass
class PortkeyModelConfig:
    model_name: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    litellm_model_name_override: str = ""
    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
    doesn't match the Portkey model name.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""


class PortkeyModel:
    def __init__(self, **kwargs):
        if Portkey is None:
            raise ImportError(
                "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
            )
        self.config = PortkeyModelConfig(**kwargs)
        self.cost = 0.0
        self.n_calls = 0
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
        # Get API key from environment or raise error
        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )
        # Get virtual key from environment
        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
        # Initialize Portkey client
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            client_kwargs["virtual_key"] = virtual_key
        self.client = Portkey(**client_kwargs)

    @retry(
        stop=stop_after_attempt(int(os.getenv("MSWEA_MODEL_RETRY_STOP_AFTER_ATTEMPT", "10"))),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type((KeyboardInterrupt, TypeError, ValueError)),
    )
    def _query(self, messages: list[dict[str, str]], **kwargs):
        # return self.client.with_options(metadata={"request_id": request_id}).chat.completions.create(
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            **(self.config.model_kwargs | kwargs),
        )

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        if self.config.set_cache_control:
            messages = set_cache_control(messages, mode=self.config.set_cache_control)
        response = self._query(messages, **kwargs)
        response_for_cost_calc = response.model_copy()
        if self.config.litellm_model_name_override:
            if response_for_cost_calc.model:
                response_for_cost_calc.model = self.config.litellm_model_name_override
        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
        if prompt_tokens is None:
            logger.warning(
                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            prompt_tokens = 0
        total_tokens = response_for_cost_calc.usage.total_tokens
        completion_tokens = response_for_cost_calc.usage.completion_tokens
        if completion_tokens is None:
            logger.warning(
                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            completion_tokens = 0
        if total_tokens - prompt_tokens - completion_tokens != 0:
            # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?)
            logger.warning(
                f"WARNING: Total tokens - prompt tokens - completion tokens != 0: {response_for_cost_calc.model_dump()}."
                " This is probably a portkey bug or incompatibility with litellm cost tracking. "
                "Setting prompt tokens based on total tokens and completion tokens. You might want to double check your costs. "
                f"Full response: {response_for_cost_calc.model_dump()}"
            )
            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
        try:
            cost = litellm.cost_calculator.completion_cost(
                response_for_cost_calc, model=self.config.litellm_model_name_override or None
            )
        except Exception as e:
            logger.critical(
                f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
                "Please check the 'Updating the model registry' section in the documentation at "
                "https://klieret.short.gy/litellm-model-registry Still stuck? Please open a github issue for help!"
            )
            raise
        assert cost >= 0.0, f"Cost is negative: {cost}"
        self.n_calls += 1
        self.cost += cost
        GLOBAL_MODEL_STATS.add(cost)
        return {
            "content": response.choices[0].message.content or "",
            "extra": {
                "response": response.model_dump(),
                "cost": cost,
            },
        }

    def get_template_vars(self) -> dict[str, Any]:
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
```
Guide
Setting up Portkey models is covered in the quickstart guide.
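For a quick impression of the class in action, here is a minimal sketch. The model name is a placeholder for whatever your Portkey workspace exposes, and `PORTKEY_API_KEY` (plus optionally `PORTKEY_VIRTUAL_KEY`) must already be set in your environment as described above:

```python
from minisweagent.models.portkey_model import PortkeyModel

# Assumes PORTKEY_API_KEY is set; "gpt-4o" is a placeholder model name.
# Keyword arguments are forwarded to PortkeyModelConfig.
model = PortkeyModel(model_name="gpt-4o")

result = model.query([{"role": "user", "content": "Say hello in one word."}])
print(result["content"])          # assistant message text
print(result["extra"]["cost"])    # cost of this call, computed via litellm
print(model.n_calls, model.cost)  # running totals tracked on the instance
```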
minisweagent.models.portkey_model
logger
module-attribute
```python
logger = getLogger('portkey_model')
```
PortkeyModelConfig
dataclass
```python
PortkeyModelConfig(
    model_name: str,
    model_kwargs: dict[str, Any] = dict(),
    litellm_model_registry: Path | str | None = getenv(
        "LITELLM_MODEL_REGISTRY_PATH"
    ),
    litellm_model_name_override: str = "",
    set_cache_control: Literal["default_end"] | None = None,
)
```
model_name
instance-attribute
```python
model_name: str
```
model_kwargs
class-attribute
instance-attribute
```python
model_kwargs: dict[str, Any] = field(default_factory=dict)
```
litellm_model_registry
class-attribute
instance-attribute
```python
litellm_model_registry: Path | str | None = getenv(
    "LITELLM_MODEL_REGISTRY_PATH"
)
```
We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry. Note that this might change if we get better support for Portkey and change how we calculate costs.
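For illustration, a minimal registry sketch is shown below. The model name and prices are placeholders, and the field names follow litellm's model registry format (see the documentation linked from the source above for the full set of fields):

```python
import json
from pathlib import Path

# Hypothetical registry entry; model name and prices are placeholders.
registry = {
    "my-portkey-model": {
        "max_tokens": 8192,
        "input_cost_per_token": 3e-06,     # USD per prompt token
        "output_cost_per_token": 1.5e-05,  # USD per completion token
        "litellm_provider": "openai",
        "mode": "chat",
    }
}
Path("my_registry.json").write_text(json.dumps(registry, indent=2))

# Then point the model at the file via the environment:
#   export LITELLM_MODEL_REGISTRY_PATH=my_registry.json
```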
litellm_model_name_override
class-attribute
instance-attribute
```python
litellm_model_name_override: str = ''
```
We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it doesn't match the Portkey model name. Note that this might change if we get better support for Portkey and change how we calculate costs.
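For example (both names below are hypothetical): if Portkey routes requests under a workspace-specific alias that litellm cannot price, you can point litellm at a name it does know:

```python
from minisweagent.models.portkey_model import PortkeyModel

# Placeholder names: Portkey sees the alias, litellm prices the canonical name.
model = PortkeyModel(
    model_name="my-workspace-alias",
    litellm_model_name_override="gpt-4o",
)
```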
set_cache_control
class-attribute
instance-attribute
```python
set_cache_control: Literal['default_end'] | None = None
```
Set explicit cache control markers, for example for Anthropic models
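The actual rewriting lives in `minisweagent.models.utils.cache_control`; as a rough sketch of the idea only, the shapes below are illustrative and follow Anthropic's documented `cache_control` format:

```python
# Before: a plain chat message.
before = {"role": "user", "content": "Run the tests."}

# After (roughly): the content carries an Anthropic-style cache marker on the
# final block, making the prompt prefix up to that point eligible for caching.
after = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Run the tests.", "cache_control": {"type": "ephemeral"}},
    ],
}
```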
PortkeyModel
```python
PortkeyModel(**kwargs)
```
Source code in src/minisweagent/models/portkey_model.py
cost
instance-attribute
```python
cost = 0.0
```
n_calls
instance-attribute
```python
n_calls = 0
```
client
instance-attribute
```python
client = Portkey(**client_kwargs)
```
query
```python
query(messages: list[dict[str, str]], **kwargs) -> dict
```
Source code in src/minisweagent/models/portkey_model.py
get_template_vars
```python
get_template_vars() -> dict[str, Any]
```
Source code in src/minisweagent/models/portkey_model.py