DefaultAgent
DefaultAgent class
Full source code
"""Basic agent class. See https://mini-swe-agent.com/latest/advanced/control_flow/ for visual explanation
or https://minimal-agent.com for a tutorial on the basic building principles.
"""
import json
import logging
import traceback
from pathlib import Path
from jinja2 import StrictUndefined, Template
from pydantic import BaseModel
from minisweagent import Environment, Model, __version__
from minisweagent.exceptions import InterruptAgentFlow, LimitsExceeded
from minisweagent.utils.serialize import recursive_merge
class AgentConfig(BaseModel):
    """Check the config files in minisweagent/config for example settings."""

    system_template: str
    """Jinja2 template for the system message (the first message)."""
    instance_template: str
    """Jinja2 template for the first user message specifying the task (the second message overall)."""
    step_limit: int = 0
    """Maximum number of model calls (steps) the agent can make. Values <= 0 (the default is 0) mean no limit."""
    cost_limit: float = 3.0
    """Stop the agent once its accumulated model cost reaches this value. The check runs before each
    model query, so the final query may push the cost past (!) the limit. Values <= 0 disable the limit."""
    output_path: Path | None = None
    """Save the trajectory to this path (re-written after every step). ``None`` disables saving."""
class DefaultAgent:
    """Minimal agent loop: query the model, execute the actions it proposes, repeat.

    The trajectory is kept in ``self.messages``. ``run`` stops as soon as the last appended
    message has role ``"exit"`` (for example the message carried by ``LimitsExceeded`` in
    ``query``, or the one appended by ``handle_uncaught_exception``).
    """

    def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
        """See the `AgentConfig` class for permitted keyword arguments.

        Args:
            model: Model used for querying and for formatting messages/observations.
            env: Environment in which the model's actions are executed.
            config_class: Class used to build ``self.config`` from ``kwargs``.
            **kwargs: Forwarded to ``config_class``.
        """
        self.config = config_class(**kwargs)
        self.messages: list[dict] = []
        self.model = model
        self.env = env
        # Extra template variables; ``run`` adds ``task`` and its keyword arguments here.
        self.extra_template_vars = {}
        self.logger = logging.getLogger("agent")
        # Running totals checked against ``step_limit`` / ``cost_limit`` in ``query``.
        self.cost = 0.0
        self.n_calls = 0

    def get_template_vars(self, **kwargs) -> dict:
        """Collect the variables available to templates.

        Merges, via ``recursive_merge`` and in this order: the agent config, the environment's
        and the model's template variables, call statistics (``n_model_calls``, ``model_cost``),
        ``extra_template_vars`` and finally ``kwargs``. Presumably later sources win on key
        conflicts -- depends on ``recursive_merge``.
        """
        return recursive_merge(
            self.config.model_dump(),
            self.env.get_template_vars(),
            self.model.get_template_vars(),
            {"n_model_calls": self.n_calls, "model_cost": self.cost},
            self.extra_template_vars,
            kwargs,
        )

    def _render_template(self, template: str) -> str:
        """Render a Jinja2 template with ``get_template_vars()``.

        ``StrictUndefined`` makes a reference to a missing variable raise instead of rendering empty.
        """
        return Template(template, undefined=StrictUndefined).render(**self.get_template_vars())

    def add_messages(self, *messages: dict) -> list[dict]:
        """Append ``messages`` to the trajectory and return them as a list."""
        self.logger.debug(messages)  # set log level to debug to see
        self.messages.extend(messages)
        return list(messages)

    def handle_uncaught_exception(self, e: Exception) -> list[dict]:
        """Append an ``exit`` message describing the unexpected exception ``e``.

        The traceback is formatted from ``e`` itself, not from the exception currently being
        handled, so it is correct even when this method is called outside an ``except`` block
        (where ``traceback.format_exc()`` would produce ``"NoneType: None"``).

        Returns:
            The list of messages that were added.
        """
        return self.add_messages(
            self.model.format_message(
                role="exit",
                content=str(e),
                extra={
                    "exit_status": type(e).__name__,
                    "submission": "",
                    "exception_str": str(e),
                    "traceback": "".join(traceback.format_exception(e)),
                },
            )
        )

    def run(self, task: str = "", **kwargs) -> dict:
        """Run step() until agent is finished. Returns dictionary with exit_status, submission keys.

        Args:
            task: Task description, exposed to the templates as ``task``.
            **kwargs: Additional template variables.

        Returns:
            The ``extra`` dict of the final (``exit``) message.

        Note:
            ``messages`` is reset, but ``cost``, ``n_calls`` and ``extra_template_vars`` are not,
            so limits apply cumulatively if ``run`` is called repeatedly on the same instance.
        """
        self.extra_template_vars |= {"task": task, **kwargs}
        self.messages = []
        self.add_messages(
            self.model.format_message(role="system", content=self._render_template(self.config.system_template)),
            self.model.format_message(role="user", content=self._render_template(self.config.instance_template)),
        )
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                # Control-flow exceptions carry messages to append (possibly an ``exit`` message).
                self.add_messages(*e.messages)
            except Exception as e:
                # Record the failure as an ``exit`` message, then let it propagate.
                self.handle_uncaught_exception(e)
                raise
            finally:
                # Persist the trajectory after every step, even when an exception propagates.
                self.save(self.config.output_path)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})

    def step(self) -> list[dict]:
        """Query the LM, execute actions.

        Returns:
            The observation messages added by ``execute_actions``.
        """
        return self.execute_actions(self.query())

    def query(self) -> dict:
        """Query the model and return the resulting message. Override to add hooks.

        Raises:
            LimitsExceeded: if a positive ``step_limit`` has been reached by ``n_calls`` or a
                positive ``cost_limit`` has been reached by ``cost``. The check runs before the
                query, so the last permitted query may push ``cost`` past ``cost_limit``.
        """
        if 0 < self.config.step_limit <= self.n_calls or 0 < self.config.cost_limit <= self.cost:
            raise LimitsExceeded(
                {
                    "role": "exit",
                    "content": "LimitsExceeded",
                    "extra": {"exit_status": "LimitsExceeded", "submission": ""},
                }
            )
        self.n_calls += 1
        message = self.model.query(self.messages)
        self.cost += message.get("extra", {}).get("cost", 0.0)
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        """Execute actions in message, add observation messages, return them.

        Actions are read from ``message["extra"]["actions"]`` (none if absent) and executed in
        order in ``self.env``; the outputs are turned into messages by the model.
        """
        outputs = [self.env.execute(action) for action in message.get("extra", {}).get("actions", [])]
        return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))

    def serialize(self, *extra_dicts) -> dict:
        """Serialize agent state to a json-compatible nested dictionary for saving.

        ``exit_status`` and ``submission`` are taken from the last message's ``extra`` dict
        (empty strings if absent). Model data, environment data and ``extra_dicts`` are merged
        on top via ``recursive_merge``.
        """
        last_message = self.messages[-1] if self.messages else {}
        last_extra = last_message.get("extra", {})
        agent_data = {
            "info": {
                "model_stats": {
                    "instance_cost": self.cost,
                    "api_calls": self.n_calls,
                },
                "config": {
                    "agent": self.config.model_dump(mode="json"),
                    "agent_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
                "mini_version": __version__,
                "exit_status": last_extra.get("exit_status", ""),
                "submission": last_extra.get("submission", ""),
            },
            "messages": self.messages,
            "trajectory_format": "mini-swe-agent-1.1",
        }
        return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

    def save(self, path: Path | None, *extra_dicts) -> dict:
        """Save the trajectory of the agent to a file if path is given. Returns full serialized data.

        You can pass additional dictionaries with extra data to be (recursively) merged into the output data.
        Parent directories of ``path`` are created as needed; the file is overwritten.
        """
        data = self.serialize(*extra_dicts)
        if path:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(data, indent=2))
        return data
Understanding the control flow
Check out the control flow guide for a visual explanation of the agent's control flow.
minisweagent.agents.default.AgentConfig
Bases: BaseModel
Check the config files in minisweagent/config for example settings.
system_template
instance-attribute
system_template: str
Template for the system message (the first message).
instance_template
instance-attribute
instance_template: str
Template for the first user message specifying the task (the second message overall).
step_limit
class-attribute
instance-attribute
step_limit: int = 0
Maximum number of model calls (steps) the agent can make. Values ≤ 0 (the default is 0) mean no limit.
cost_limit
class-attribute
instance-attribute
cost_limit: float = 3.0
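Stop the agent once its accumulated model cost reaches this value. The check runs before each model query, so the final query may push the cost past (!) the limit. Values ≤ 0 disable the limit.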
output_path
class-attribute
instance-attribute
output_path: Path | None = None
Save the trajectory to this path.
minisweagent.agents.default.DefaultAgent
DefaultAgent(
model: Model,
env: Environment,
*,
config_class: type = AgentConfig,
**kwargs,
)
See the AgentConfig class for permitted keyword arguments.
Source code in src/minisweagent/agents/default.py
(lines 34–43)
config
instance-attribute
config = config_class(**kwargs)
messages
instance-attribute
messages: list[dict] = []
model
instance-attribute
model = model
env
instance-attribute
env = env
extra_template_vars
instance-attribute
extra_template_vars = {}
logger
instance-attribute
logger = getLogger('agent')
cost
instance-attribute
cost = 0.0
n_calls
instance-attribute
n_calls = 0
get_template_vars
get_template_vars(**kwargs) -> dict
Source code in src/minisweagent/agents/default.py
(lines 45–53)
add_messages
add_messages(*messages: dict) -> list[dict]
Source code in src/minisweagent/agents/default.py
(lines 58–61)
handle_uncaught_exception
handle_uncaught_exception(e: Exception) -> list[dict]
Source code in src/minisweagent/agents/default.py
(lines 63–75)
run
run(task: str = '', **kwargs) -> dict
Run step() until agent is finished. Returns dictionary with exit_status, submission keys.
Source code in src/minisweagent/agents/default.py
(lines 77–97)
step
step() -> list[dict]
Query the LM, execute actions.
Source code in src/minisweagent/agents/default.py
(lines 99–101)
query
query() -> dict
Query the model and return model messages. Override to add hooks.
Source code in src/minisweagent/agents/default.py
(lines 103–117)
execute_actions
execute_actions(message: dict) -> list[dict]
Execute actions in message, add observation messages, return them.
Source code in src/minisweagent/agents/default.py
(lines 119–122)
serialize
serialize(*extra_dicts) -> dict
Serialize agent state to a json-compatible nested dictionary for saving.
Source code in src/minisweagent/agents/default.py
(lines 124–145)
save
save(path: Path | None, *extra_dicts) -> dict
Save the trajectory of the agent to a file if path is given. Returns full serialized data. You can pass additional dictionaries with extra data to be (recursively) merged into the output data.
Source code in src/minisweagent/agents/default.py
(lines 147–155)
minisweagent.exceptions.InterruptAgentFlow
InterruptAgentFlow(*messages: dict)
Bases: Exception
Raised to interrupt the agent flow and add messages.
Source code in src/minisweagent/exceptions.py
(lines 4–6)
messages
instance-attribute
messages = messages
minisweagent.exceptions.Submitted
Submitted(*messages: dict)
Bases: InterruptAgentFlow
Raised when the agent has completed its task.
Source code in src/minisweagent/exceptions.py
(lines 4–6)
minisweagent.exceptions.LimitsExceeded
LimitsExceeded(*messages: dict)
Bases: InterruptAgentFlow
Raised when the agent has exceeded its cost or step limit.
Source code in src/minisweagent/exceptions.py
(lines 4–6)
minisweagent.exceptions.FormatError
FormatError(*messages: dict)
Bases: InterruptAgentFlow
Raised when the LM's output is not in the expected format.
Source code in src/minisweagent/exceptions.py
(lines 4–6)
minisweagent.exceptions.UserInterruption
UserInterruption(*messages: dict)
Bases: InterruptAgentFlow
Raised when the user interrupts the agent.
Source code in src/minisweagent/exceptions.py
(lines 4–6)