Skip to content

DefaultAgent

DefaultAgent class

Full source code
"""Basic agent class. See https://mini-swe-agent.com/latest/advanced/control_flow/ for visual explanation
or https://minimal-agent.com for a tutorial on the basic building principles.
"""

import json
import logging
import traceback
from pathlib import Path

from jinja2 import StrictUndefined, Template
from pydantic import BaseModel

from minisweagent import Environment, Model, __version__
from minisweagent.exceptions import InterruptAgentFlow, LimitsExceeded
from minisweagent.utils.serialize import recursive_merge


class AgentConfig(BaseModel):
    """Configuration for `DefaultAgent`.

    Check the config files in minisweagent/config for example settings.
    """

    system_template: str
    """Jinja2 template for the system message (the first message)."""
    instance_template: str
    """Jinja2 template for the first user message specifying the task (the second message overall)."""
    step_limit: int = 0
    """Maximum number of model calls (steps) the agent may make. A value of 0 (the default) or less disables the limit."""
    cost_limit: float = 3.0
    """Stop the agent once its accumulated model cost reaches this value.

    The check runs before each model query, so the final query can push the total past the limit
    (hence "exceeding (!)"). A value of 0 or less disables the limit.
    """
    output_path: Path | None = None
    """Save the trajectory to this path after every step. If unset, no file is written."""


class DefaultAgent:
    """Basic agent: alternately query the model and execute the actions it proposes.

    The conversation lives in ``self.messages``. ``run`` keeps calling ``step`` until the last
    message has ``role == "exit"``. Exit messages come from ``InterruptAgentFlow`` exceptions
    (e.g. ``LimitsExceeded`` raised by ``query``) or from ``handle_uncaught_exception``.
    """

    def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
        """See the `AgentConfig` class for permitted keyword arguments."""
        self.config = config_class(**kwargs)
        self.messages: list[dict] = []
        self.model = model
        self.env = env
        # Additional template variables; `run` stores `task` and its extra keyword arguments here.
        self.extra_template_vars = {}
        self.logger = logging.getLogger("agent")
        # Running totals, checked against the configured limits in `query`.
        # NOTE(review): `run` does not reset these, so they accumulate if `run` is called again.
        self.cost = 0.0
        self.n_calls = 0

    def get_template_vars(self, **kwargs) -> dict:
        """Collect the variables available to templates.

        Sources, in merge order: the agent config, the environment's and the model's template
        variables, the call/cost counters, ``extra_template_vars``, and finally ``kwargs``.
        NOTE(review): precedence on key collisions is decided by ``recursive_merge``. Presumably
        later sources win; confirm there.
        """
        return recursive_merge(
            self.config.model_dump(),
            self.env.get_template_vars(),
            self.model.get_template_vars(),
            {"n_model_calls": self.n_calls, "model_cost": self.cost},
            self.extra_template_vars,
            kwargs,
        )

    def _render_template(self, template: str) -> str:
        """Render a Jinja2 template string; ``StrictUndefined`` makes unknown variables an error."""
        return Template(template, undefined=StrictUndefined).render(**self.get_template_vars())

    def add_messages(self, *messages: dict) -> list[dict]:
        """Append ``messages`` to the conversation and return them as a new list."""
        self.logger.debug(messages)  # set log level to debug to see
        self.messages.extend(messages)
        return list(messages)

    def handle_uncaught_exception(self, e: Exception) -> list[dict]:
        """Record ``e`` as an ``exit`` message and return the added message(s).

        The exit status is the exception's class name. The traceback is formatted from ``e``
        itself rather than from "the exception currently being handled". This keeps the recorded
        traceback correct even if this method is called outside (or from a different)
        ``except`` block.
        """
        return self.add_messages(
            self.model.format_message(
                role="exit",
                content=str(e),
                extra={
                    "exit_status": type(e).__name__,
                    "submission": "",
                    "exception_str": str(e),
                    "traceback": "".join(traceback.format_exception(e)),
                },
            )
        )

    def run(self, task: str = "", **kwargs) -> dict:
        """Run step() until agent is finished. Returns dictionary with exit_status, submission keys.

        ``task`` and any extra keyword arguments become template variables. The conversation is
        reset and seeded with the rendered system and instance templates. The trajectory is saved
        after every step, including when a step raises.

        Raises:
            Exception: any exception other than ``InterruptAgentFlow`` raised by ``step`` is
                recorded via ``handle_uncaught_exception`` and then re-raised.
        """
        self.extra_template_vars |= {"task": task, **kwargs}
        self.messages = []
        self.add_messages(
            self.model.format_message(role="system", content=self._render_template(self.config.system_template)),
            self.model.format_message(role="user", content=self._render_template(self.config.instance_template)),
        )
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                # Controlled interruption (limits, submission, format errors, ...): append the
                # messages it carries. If the last one is an exit message, the loop ends below.
                self.add_messages(*e.messages)
            except Exception as e:
                self.handle_uncaught_exception(e)
                raise
            finally:
                self.save(self.config.output_path)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})

    def step(self) -> list[dict]:
        """Query the LM, execute actions. Returns the observation messages that were added."""
        return self.execute_actions(self.query())

    def query(self) -> dict:
        """Query the model and return model messages. Override to add hooks.

        Raises:
            LimitsExceeded: before querying, if ``step_limit`` model calls have already been made
                or the accumulated cost has reached ``cost_limit``. A limit of 0 or less is ignored.
        """
        if 0 < self.config.step_limit <= self.n_calls or 0 < self.config.cost_limit <= self.cost:
            raise LimitsExceeded(
                {
                    "role": "exit",
                    "content": "LimitsExceeded",
                    "extra": {"exit_status": "LimitsExceeded", "submission": ""},
                }
            )
        self.n_calls += 1
        message = self.model.query(self.messages)
        # Cost, if the model reports one, is read from message["extra"]["cost"] (0.0 when absent).
        self.cost += message.get("extra", {}).get("cost", 0.0)
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        """Execute actions in message, add observation messages, return them.

        Actions are read from ``message["extra"]["actions"]`` (none if absent) and run in order
        by the environment. The model then formats the outputs into observation messages.
        """
        outputs = [self.env.execute(action) for action in message.get("extra", {}).get("actions", [])]
        return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))

    def serialize(self, *extra_dicts) -> dict:
        """Serialize agent state to a json-compatible nested dictionary for saving.

        ``exit_status`` and ``submission`` come from the ``extra`` of the last message (empty
        strings if absent). The model's and environment's serializations and ``extra_dicts``
        are merged in with ``recursive_merge``.
        """
        last_message = self.messages[-1] if self.messages else {}
        last_extra = last_message.get("extra", {})
        agent_data = {
            "info": {
                "model_stats": {
                    "instance_cost": self.cost,
                    "api_calls": self.n_calls,
                },
                "config": {
                    "agent": self.config.model_dump(mode="json"),
                    "agent_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
                "mini_version": __version__,
                "exit_status": last_extra.get("exit_status", ""),
                "submission": last_extra.get("submission", ""),
            },
            "messages": self.messages,
            "trajectory_format": "mini-swe-agent-1.1",
        }
        return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

    def save(self, path: Path | str | None, *extra_dicts) -> dict:
        """Save the trajectory of the agent to a file if path is given. Returns full serialized data.
        You can pass additional dictionaries with extra data to be (recursively) merged into the output data.

        ``path`` may be a ``Path`` or a string. A falsy value (``None`` or ``""``) skips writing.
        Missing parent directories are created.
        """
        data = self.serialize(*extra_dicts)
        if path:
            # Accept plain strings too; the original `Path`-only code crashed on `str.parent`.
            path = Path(path)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(data, indent=2))
        return data

Understanding the control flow

Check out the control flow guide for a visual explanation of the agent's control flow.

minisweagent.agents.default.AgentConfig

Bases: BaseModel

Check the config files in minisweagent/config for example settings.

system_template instance-attribute

system_template: str

Template for the system message (the first message).

instance_template instance-attribute

instance_template: str

Template for the first user message specifying the task (the second message overall).

step_limit class-attribute instance-attribute

step_limit: int = 0

Maximum number of model calls (steps) the agent may make. A value of 0 (the default) or less disables the limit.

cost_limit class-attribute instance-attribute

cost_limit: float = 3.0

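Stop the agent once its accumulated model cost reaches this value. The check runs before each model query, so the final query can push the total past the limit (hence "exceeding"). A value of 0 or less disables the limit.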

output_path class-attribute instance-attribute

output_path: Path | None = None

Save the trajectory to this path; the file is rewritten after every step. If unset, no file is written.

minisweagent.agents.default.DefaultAgent

DefaultAgent(
    model: Model,
    env: Environment,
    *,
    config_class: type = AgentConfig,
    **kwargs,
)

See the AgentConfig class for permitted keyword arguments.

Source code in src/minisweagent/agents/default.py
34
35
36
37
38
39
40
41
42
43
def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
    """Build the agent around ``model`` and ``env``.

    Every keyword argument other than ``config_class`` is forwarded to ``config_class``
    (``AgentConfig`` by default); see that class for the permitted keys.
    """
    # The config is built first; it consumes all remaining keyword arguments.
    self.config = config_class(**kwargs)
    self.messages: list[dict] = []
    self.model, self.env = model, env
    self.extra_template_vars = {}
    self.logger = logging.getLogger("agent")
    # Accumulated model cost and number of model calls.
    self.cost, self.n_calls = 0.0, 0

config instance-attribute

config = config_class(**kwargs)

messages instance-attribute

messages: list[dict] = []

model instance-attribute

model = model

env instance-attribute

env = env

extra_template_vars instance-attribute

extra_template_vars = {}

logger instance-attribute

logger = getLogger('agent')

cost instance-attribute

cost = 0.0

n_calls instance-attribute

n_calls = 0

get_template_vars

get_template_vars(**kwargs) -> dict
Source code in src/minisweagent/agents/default.py
45
46
47
48
49
50
51
52
53
def get_template_vars(self, **kwargs) -> dict:
    """Gather every template variable into one dict via ``recursive_merge``.

    Merge order: config fields, environment variables, model variables, the call/cost
    counters, ``extra_template_vars``, then the caller's ``kwargs``.
    """
    layers = (
        self.config.model_dump(),
        self.env.get_template_vars(),
        self.model.get_template_vars(),
        {"n_model_calls": self.n_calls, "model_cost": self.cost},
        self.extra_template_vars,
        kwargs,
    )
    return recursive_merge(*layers)

add_messages

add_messages(*messages: dict) -> list[dict]
Source code in src/minisweagent/agents/default.py
58
59
60
61
def add_messages(self, *messages: dict) -> list[dict]:
    """Append ``messages`` to the conversation history and hand them back as a fresh list."""
    self.logger.debug(messages)  # only visible when the logger is set to DEBUG level
    added = list(messages)
    self.messages.extend(added)
    return added

handle_uncaught_exception

handle_uncaught_exception(e: Exception) -> list[dict]
Source code in src/minisweagent/agents/default.py
63
64
65
66
67
68
69
70
71
72
73
74
75
def handle_uncaught_exception(self, e: Exception) -> list[dict]:
    return self.add_messages(
        self.model.format_message(
            role="exit",
            content=str(e),
            extra={
                "exit_status": type(e).__name__,
                "submission": "",
                "exception_str": str(e),
                "traceback": traceback.format_exc(),
            },
        )
    )

run

run(task: str = '', **kwargs) -> dict

Run step() until agent is finished. Returns dictionary with exit_status, submission keys.

Source code in src/minisweagent/agents/default.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def run(self, task: str = "", **kwargs) -> dict:
    """Drive the agent loop until an ``exit`` message is produced.

    ``task`` and any keyword arguments are added to the template variables. The conversation is
    restarted from the rendered system and instance templates, and ``step`` is called repeatedly.
    The trajectory is saved after every step. Returns the ``extra`` dict of the final message,
    which carries the ``exit_status`` and ``submission`` keys.
    """
    self.extra_template_vars |= {"task": task, **kwargs}
    self.messages = []
    system_msg = self.model.format_message(role="system", content=self._render_template(self.config.system_template))
    task_msg = self.model.format_message(role="user", content=self._render_template(self.config.instance_template))
    self.add_messages(system_msg, task_msg)

    done = False
    while not done:
        try:
            self.step()
        except InterruptAgentFlow as interruption:
            # Expected interruptions carry the messages to append (possibly an exit message).
            self.add_messages(*interruption.messages)
        except Exception as error:
            # Record the failure in the trajectory, then let it propagate to the caller.
            self.handle_uncaught_exception(error)
            raise
        finally:
            # Persist after every step, whether it succeeded, was interrupted, or failed.
            self.save(self.config.output_path)
        done = self.messages[-1].get("role") == "exit"
    return self.messages[-1].get("extra", {})

step

step() -> list[dict]

Query the LM, execute actions.

Source code in src/minisweagent/agents/default.py
 99
100
101
def step(self) -> list[dict]:
    """Perform one iteration: query the LM, then execute the actions in its reply.

    Returns whatever ``execute_actions`` returns for that reply.
    """
    reply = self.query()
    return self.execute_actions(reply)

query

query() -> dict

Query the model and return model messages. Override to add hooks.

Source code in src/minisweagent/agents/default.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def query(self) -> dict:
    """Query the model and return model messages. Override to add hooks.

    Before querying, raises ``LimitsExceeded`` (carrying an ``exit`` message) when the call
    counter has reached ``config.step_limit`` or the accumulated cost has reached
    ``config.cost_limit``. A limit of 0 or less is ignored. Otherwise it increments the call
    counter, adds any reported cost, and appends the reply to the conversation.
    """
    limits = self.config
    out_of_steps = 0 < limits.step_limit <= self.n_calls
    if out_of_steps or 0 < limits.cost_limit <= self.cost:
        raise LimitsExceeded(
            {
                "role": "exit",
                "content": "LimitsExceeded",
                "extra": {"exit_status": "LimitsExceeded", "submission": ""},
            }
        )
    self.n_calls += 1
    reply = self.model.query(self.messages)
    reported_cost = reply.get("extra", {}).get("cost", 0.0)
    self.cost += reported_cost
    self.add_messages(reply)
    return reply

execute_actions

execute_actions(message: dict) -> list[dict]

Execute actions in message, add observation messages, return them.

Source code in src/minisweagent/agents/default.py
119
120
121
122
def execute_actions(self, message: dict) -> list[dict]:
    """Run every action found in ``message["extra"]["actions"]`` and record the observations.

    Actions execute in order in the environment. Their outputs are turned into observation
    messages by the model, appended to the conversation, and returned.
    """
    actions = message.get("extra", {}).get("actions", [])
    # A plain loop, since executing each action is a side effect.
    outputs = []
    for action in actions:
        outputs.append(self.env.execute(action))
    observations = self.model.format_observation_messages(message, outputs, self.get_template_vars())
    return self.add_messages(*observations)

serialize

serialize(*extra_dicts) -> dict

Serialize agent state to a json-compatible nested dictionary for saving.

Source code in src/minisweagent/agents/default.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def serialize(self, *extra_dicts) -> dict:
    """Build a JSON-compatible snapshot of the agent for saving.

    The agent's own section holds stats, config, version, the full message list, and the exit
    status and submission taken from the last message's ``extra``. It is merged with the model's
    and environment's serializations and any ``extra_dicts`` using ``recursive_merge``.
    """
    final_extra = (self.messages[-1] if self.messages else {}).get("extra", {})
    stats = {"instance_cost": self.cost, "api_calls": self.n_calls}
    config_info = {
        "agent": self.config.model_dump(mode="json"),
        "agent_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
    }
    info = {
        "model_stats": stats,
        "config": config_info,
        "mini_version": __version__,
        "exit_status": final_extra.get("exit_status", ""),
        "submission": final_extra.get("submission", ""),
    }
    agent_data = {"info": info, "messages": self.messages, "trajectory_format": "mini-swe-agent-1.1"}
    return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

save

save(path: Path | None, *extra_dicts) -> dict

Save the trajectory of the agent to a file if path is given. Returns full serialized data. You can pass additional dictionaries with extra data to be (recursively) merged into the output data.

Source code in src/minisweagent/agents/default.py
147
148
149
150
151
152
153
154
155
def save(self, path: Path | None, *extra_dicts) -> dict:
    """Save the trajectory of the agent to a file if path is given. Returns full serialized data.
    You can pass additional dictionaries with extra data to be (recursively) merged into the output data.
    """
    data = self.serialize(*extra_dicts)
    if path:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data, indent=2))
    return data

minisweagent.exceptions.InterruptAgentFlow

InterruptAgentFlow(*messages: dict)

Bases: Exception

Raised to interrupt the agent flow and add messages.

Source code in src/minisweagent/exceptions.py
4
5
6
def __init__(self, *messages: dict):
    """Initialise the exception, keeping ``messages`` for the agent to append to its history."""
    super().__init__()
    self.messages = messages

messages instance-attribute

messages = messages

minisweagent.exceptions.Submitted

Submitted(*messages: dict)

Bases: InterruptAgentFlow

Raised when the agent has completed its task.

Source code in src/minisweagent/exceptions.py
4
5
6
def __init__(self, *messages: dict):
    """Initialise the exception, keeping ``messages`` for the agent to append to its history."""
    super().__init__()
    self.messages = messages

minisweagent.exceptions.LimitsExceeded

LimitsExceeded(*messages: dict)

Bases: InterruptAgentFlow

Raised when the agent has exceeded its cost or step limit.

Source code in src/minisweagent/exceptions.py
4
5
6
def __init__(self, *messages: dict):
    """Initialise the exception, keeping ``messages`` for the agent to append to its history."""
    super().__init__()
    self.messages = messages

minisweagent.exceptions.FormatError

FormatError(*messages: dict)

Bases: InterruptAgentFlow

Raised when the LM's output is not in the expected format.

Source code in src/minisweagent/exceptions.py
4
5
6
def __init__(self, *messages: dict):
    """Initialise the exception, keeping ``messages`` for the agent to append to its history."""
    super().__init__()
    self.messages = messages

minisweagent.exceptions.UserInterruption

UserInterruption(*messages: dict)

Bases: InterruptAgentFlow

Raised when the user interrupts the agent.

Source code in src/minisweagent/exceptions.py
4
5
6
def __init__(self, *messages: dict):
    """Initialise the exception, keeping ``messages`` for the agent to append to its history."""
    super().__init__()
    self.messages = messages