Portkey Model
Portkey Model class
Full source code
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Literal
import litellm
from pydantic import BaseModel
from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.actions_toolcall import (
BASH_TOOL,
format_toolcall_observation_messages,
parse_toolcall_actions,
)
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
from minisweagent.models.utils.cache_control import set_cache_control
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
from minisweagent.models.utils.retry import retry
logger = logging.getLogger("portkey_model")
# Fail fast at import time with an actionable message if the optional
# portkey-ai dependency is not installed.
try:
    from portkey_ai import Portkey
except ImportError:
    raise ImportError(
        "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
    )
class PortkeyModelConfig(BaseModel):
    """Pydantic-validated configuration for `PortkeyModel`."""

    # Model name as known to the Portkey gateway (required).
    model_name: str
    # Extra keyword arguments merged into every chat.completions.create call.
    model_kwargs: dict[str, Any] = {}
    provider: str = ""
    """The LLM provider to use (e.g., 'openai', 'anthropic', 'google').
    If not specified, will be auto-detected from model_name.
    Required by Portkey when not using a virtual key.
    """
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    litellm_model_name_override: str = ""
    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
    doesn't match the Portkey model name.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
    format_error_template: str = "{{ error }}"
    """Template used when the LM's output is not in the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
class PortkeyModel:
    """Chat model accessed through the Portkey AI gateway.

    Actions are exchanged via tool calls (`BASH_TOOL`); costs are computed with
    litellm's pricing tables, which is why the config exposes litellm overrides.
    """

    # Exceptions that abort the retry loop in `query` instead of being retried.
    abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]

    def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
        """Build the config from `kwargs` and create the Portkey client.

        Raises:
            ValueError: if the PORTKEY_API_KEY environment variable is not set.
        """
        self.config = config_class(**kwargs)
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            # Teach litellm about extra models so cost calculation can price them.
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))
        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )
        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            client_kwargs["virtual_key"] = virtual_key
        elif self.config.provider:
            # If no virtual key but provider is specified, pass it
            client_kwargs["provider"] = self.config.provider
        self.client = Portkey(**client_kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Issue a single (non-retried) chat completion request.

        Per-call `kwargs` take precedence over `config.model_kwargs`.
        """
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            tools=[BASH_TOOL],
            **(self.config.model_kwargs | kwargs),
        )

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Strip local bookkeeping and apply provider-specific message fixups."""
        # "extra" holds local metadata (cost, parsed actions, raw response) that
        # must not be sent to the API.
        prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
        prepared = _reorder_anthropic_thinking_blocks(prepared)
        return set_cache_control(prepared, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model with retries and return the assistant message.

        The returned message dict carries an "extra" key with parsed actions,
        the raw response, the computed cost, and a timestamp.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = response.choices[0].message.model_dump()
        message["extra"] = {
            "actions": self._parse_actions(response),
            "response": response.model_dump(),
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _parse_actions(self, response) -> list[dict]:
        """Parse tool calls from the response. Raises FormatError if unknown tool."""
        tool_calls = response.choices[0].message.tool_calls or []
        return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Build a message dict, expanding multimodal content if a regex is configured."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose all config fields as template variables."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-serializable description of this model's configuration."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }

    def _calculate_cost(self, response) -> dict[str, float]:
        """Compute the cost of `response` in USD via litellm.

        Returns:
            {"cost": float} — 0.0 when cost_tracking is "ignore_errors" and
            litellm fails to price the model.

        Raises:
            RuntimeError: when litellm cannot compute a cost and cost_tracking
            is not "ignore_errors".
        """
        response_for_cost_calc = response.model_copy()
        if self.config.litellm_model_name_override:
            if response_for_cost_calc.model:
                response_for_cost_calc.model = self.config.litellm_model_name_override
        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
        if prompt_tokens is None:
            logger.warning(
                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            prompt_tokens = 0
        total_tokens = response_for_cost_calc.usage.total_tokens
        completion_tokens = response_for_cost_calc.usage.completion_tokens
        if completion_tokens is None:
            logger.warning(
                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            completion_tokens = 0
        if total_tokens is None:
            # Bug fix: previously a None total_tokens raised TypeError in the
            # subtraction below — and TypeError is in abort_exceptions, so the
            # whole query aborted instead of degrading to a warning.
            logger.warning(
                f"Total tokens are None for model {self.config.model_name}. "
                f"Falling back to prompt + completion tokens. Full response: {response_for_cost_calc.model_dump()}"
            )
            total_tokens = prompt_tokens + completion_tokens
        if total_tokens - prompt_tokens - completion_tokens != 0:
            # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?)
            logger.warning(
                f"WARNING: Total tokens - prompt tokens - completion tokens != 0: {response_for_cost_calc.model_dump()}."
                " This is probably a portkey bug or incompatibility with litellm cost tracking. "
                "Setting prompt tokens based on total tokens and completion tokens. You might want to double check your costs. "
                f"Full response: {response_for_cost_calc.model_dump()}"
            )
            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
        try:
            cost = litellm.cost_calculator.completion_cost(
                response_for_cost_calc, model=self.config.litellm_model_name_override or None
            )
            assert cost >= 0.0, f"Cost is negative: {cost}"
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        return {"cost": cost}
Guide
Setting up Portkey models is covered in the quickstart guide.
minisweagent.models.portkey_model
logger
module-attribute
logger = getLogger('portkey_model')
PortkeyModelConfig
Bases: BaseModel
model_name
instance-attribute
model_name: str
model_kwargs
class-attribute
instance-attribute
model_kwargs: dict[str, Any] = {}
provider
class-attribute
instance-attribute
provider: str = ''
The LLM provider to use (e.g., 'openai', 'anthropic', 'google'). If not specified, will be auto-detected from model_name. Required by Portkey when not using a virtual key.
litellm_model_registry
class-attribute
instance-attribute
litellm_model_registry: Path | str | None = getenv(
"LITELLM_MODEL_REGISTRY_PATH"
)
We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry. Note that this might change if we get better support for Portkey and change how we calculate costs.
litellm_model_name_override
class-attribute
instance-attribute
litellm_model_name_override: str = ''
We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it doesn't match the Portkey model name. Note that this might change if we get better support for Portkey and change how we calculate costs.
set_cache_control
class-attribute
instance-attribute
set_cache_control: Literal['default_end'] | None = None
Set explicit cache control markers, for example for Anthropic models
cost_tracking
class-attribute
instance-attribute
cost_tracking: Literal["default", "ignore_errors"] = getenv(
"MSWEA_COST_TRACKING", "default"
)
Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)
format_error_template
class-attribute
instance-attribute
format_error_template: str = '{{ error }}'
Template used when the LM's output is not in the expected format.
observation_template
class-attribute
instance-attribute
observation_template: str = "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
Template used to render the observation after executing an action.
multimodal_regex
class-attribute
instance-attribute
multimodal_regex: str = ''
Regex to extract multimodal content. Empty string disables multimodal processing.
PortkeyModel
PortkeyModel(
*, config_class: type = PortkeyModelConfig, **kwargs
)
Source code in `src/minisweagent/models/portkey_model.py`, lines 67–88.
abort_exceptions
class-attribute
instance-attribute
abort_exceptions: list[type[Exception]] = [
KeyboardInterrupt,
TypeError,
ValueError,
]
config
instance-attribute
config = config_class(**kwargs)
client
instance-attribute
client = Portkey(**client_kwargs)
query
query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in `src/minisweagent/models/portkey_model.py`, lines 103–116.
format_message
format_message(**kwargs) -> dict
Source code in `src/minisweagent/models/portkey_model.py`, lines 123–124.
format_observation_messages
format_observation_messages(
message: dict,
outputs: list[dict],
template_vars: dict | None = None,
) -> list[dict]
Format execution outputs into tool result messages.
Source code in `src/minisweagent/models/portkey_model.py`, lines 126–137.
get_template_vars
get_template_vars(**kwargs) -> dict[str, Any]
Source code in `src/minisweagent/models/portkey_model.py`, lines 139–140.
serialize
serialize() -> dict
Source code in `src/minisweagent/models/portkey_model.py`, lines 142–150.