Agent control flow

Understanding AI agent basics

We also recently published a longer tutorial on the basics of building an AI agent; see https://minimal-agent.com.

Understanding the default agent

  • This guide shows the control flow of the default agent.
  • After this, you're ready to remix and extend mini.

The following diagram shows the control flow of the mini agent:

flowchart TD
    subgraph run["<b><code>Agent.run(task)</code></b>"]
        direction TB
        A["<b><code>Initialize messages</code></b>"] --> B
        B["<b><code>Agent.step</code></b>"] --> C{"<b><code>Exception?</code></b>"}
        C -->|Yes| D["<b><code>Agent.add_messages</code></b><br/>(also re-raises exceptions that don't inherit from InterruptAgentFlow)"]
        C -->|No| E{"<b><code>messages[-1].role == exit?</code></b>"}
        D --> E
        E -->|No| B
        E -->|Yes| F["<b><code>Return result</code></b>"]
    end

    subgraph step["<b><code>Agent.step()</code></b><br/>Single iteration"]
        direction TB
        S1["<b><code>Agent.query</code></b>"] --> S2["<b><code>Agent.execute_actions</code></b>"]
    end

    subgraph query["<b><code>Agent.query()</code></b><br/>Also checks for step and cost limits"]
        direction TB
        Q3["<b><code>Model.query</code></b>"] --> Q4["<b><code>Agent.add_messages</code></b>"]
    end

    subgraph execute_actions["<b><code>Agent.execute_actions(message)</code></b>"]
        direction TB
        E2["<b><code>Environment.execute</code></b><br/>Also raises the Submitted exception if we're done"] --> E3["<b><code>Model.format_observation_messages</code></b>"]
        E3 --> E4["<b><code>Agent.add_messages</code></b>"]
    end

    B -.-> step
    S1 -.-> query
    S2 -.-> execute_actions

And here is the code that implements it:

Default agent class
"""Basic agent class. See https://mini-swe-agent.com/latest/advanced/control_flow/ for visual explanation
or https://minimal-agent.com for a tutorial on the basic building principles.
"""

import json
import logging
import traceback
from pathlib import Path

from jinja2 import StrictUndefined, Template
from pydantic import BaseModel

from minisweagent import Environment, Model, __version__
from minisweagent.exceptions import InterruptAgentFlow, LimitsExceeded
from minisweagent.utils.serialize import recursive_merge


class AgentConfig(BaseModel):
    """Check the config files in minisweagent/config for example settings."""

    system_template: str
    """Template for the system message (the first message)."""
    instance_template: str
    """Template for the first user message specifying the task (the second message overall)."""
    step_limit: int = 0
    """Maximum number of steps the agent can take."""
    cost_limit: float = 3.0
    """Stop agent after exceeding (!) this cost."""
    output_path: Path | None = None
    """Save the trajectory to this path."""


class DefaultAgent:
    def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
        """See the `AgentConfig` class for permitted keyword arguments."""
        self.config = config_class(**kwargs)
        self.messages: list[dict] = []
        self.model = model
        self.env = env
        self.extra_template_vars = {}
        self.logger = logging.getLogger("agent")
        self.cost = 0.0
        self.n_calls = 0

    def get_template_vars(self, **kwargs) -> dict:
        return recursive_merge(
            self.config.model_dump(),
            self.env.get_template_vars(),
            self.model.get_template_vars(),
            {"n_model_calls": self.n_calls, "model_cost": self.cost},
            self.extra_template_vars,
            kwargs,
        )

    def _render_template(self, template: str) -> str:
        return Template(template, undefined=StrictUndefined).render(**self.get_template_vars())

    def add_messages(self, *messages: dict) -> list[dict]:
        self.logger.debug(messages)  # set log level to debug to see
        self.messages.extend(messages)
        return list(messages)

    def handle_uncaught_exception(self, e: Exception) -> list[dict]:
        return self.add_messages(
            self.model.format_message(
                role="exit",
                content=str(e),
                extra={
                    "exit_status": type(e).__name__,
                    "submission": "",
                    "exception_str": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
        )

    def run(self, task: str = "", **kwargs) -> dict:
        """Run step() until agent is finished. Returns dictionary with exit_status, submission keys."""
        self.extra_template_vars |= {"task": task, **kwargs}
        self.messages = []
        self.add_messages(
            self.model.format_message(role="system", content=self._render_template(self.config.system_template)),
            self.model.format_message(role="user", content=self._render_template(self.config.instance_template)),
        )
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                self.add_messages(*e.messages)
            except Exception as e:
                self.handle_uncaught_exception(e)
                raise
            finally:
                self.save(self.config.output_path)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})

    def step(self) -> list[dict]:
        """Query the LM, execute actions."""
        return self.execute_actions(self.query())

    def query(self) -> dict:
        """Query the model and return model messages. Override to add hooks."""
        if 0 < self.config.step_limit <= self.n_calls or 0 < self.config.cost_limit <= self.cost:
            raise LimitsExceeded(
                {
                    "role": "exit",
                    "content": "LimitsExceeded",
                    "extra": {"exit_status": "LimitsExceeded", "submission": ""},
                }
            )
        self.n_calls += 1
        message = self.model.query(self.messages)
        self.cost += message.get("extra", {}).get("cost", 0.0)
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        """Execute actions in message, add observation messages, return them."""
        outputs = [self.env.execute(action) for action in message.get("extra", {}).get("actions", [])]
        return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))

    def serialize(self, *extra_dicts) -> dict:
        """Serialize agent state to a json-compatible nested dictionary for saving."""
        last_message = self.messages[-1] if self.messages else {}
        last_extra = last_message.get("extra", {})
        agent_data = {
            "info": {
                "model_stats": {
                    "instance_cost": self.cost,
                    "api_calls": self.n_calls,
                },
                "config": {
                    "agent": self.config.model_dump(mode="json"),
                    "agent_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
                "mini_version": __version__,
                "exit_status": last_extra.get("exit_status", ""),
                "submission": last_extra.get("submission", ""),
            },
            "messages": self.messages,
            "trajectory_format": "mini-swe-agent-1.1",
        }
        return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

    def save(self, path: Path | None, *extra_dicts) -> dict:
        """Save the trajectory of the agent to a file if path is given. Returns full serialized data.
        You can pass additional dictionaries with extra data to be (recursively) merged into the output data.
        """
        data = self.serialize(*extra_dicts)
        if path:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(data, indent=2))
        return data

Essentially, DefaultAgent.run calls DefaultAgent.step in a loop until the agent has finished its task.

The step method is the core of the agent:

def step(self) -> list[dict]:
    return self.execute_actions(self.query())

It does the following:

  1. Queries the model for a response based on the current messages (DefaultAgent.query, calling Model.query)
  2. Executes all actions in the response (DefaultAgent.execute_actions, calling Environment.execute for each action)
  3. Formats the observation messages via Model.format_observation_messages
  4. Adds the observations to the messages
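The four steps above can be traced with a toy version of the agent. This is only a sketch: ToyModel, ToyEnv, and ToyAgent are hypothetical stand-ins rather than minisweagent classes, and the toy format_observation_messages drops the template-variables argument that the real call passes.

```python
# Toy stand-ins for illustration only; none of these classes come from minisweagent.
class ToyModel:
    def query(self, messages: list[dict]) -> dict:
        # Pretend the LM always asks to run a single command.
        return {"role": "assistant", "content": "run ls", "extra": {"actions": ["ls"]}}

    def format_observation_messages(self, message: dict, outputs: list[dict]) -> list[dict]:
        # Simplified: one user message per action output.
        return [{"role": "user", "content": out["output"]} for out in outputs]


class ToyEnv:
    def execute(self, action: str) -> dict:
        return {"output": f"ran {action}"}


class ToyAgent:
    def __init__(self, model: ToyModel, env: ToyEnv):
        self.model, self.env = model, env
        self.messages: list[dict] = []

    def add_messages(self, *messages: dict) -> list[dict]:
        self.messages.extend(messages)
        return list(messages)

    def query(self) -> dict:
        message = self.model.query(self.messages)  # step 1: query the model
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        actions = message.get("extra", {}).get("actions", [])
        outputs = [self.env.execute(action) for action in actions]  # step 2: execute
        observations = self.model.format_observation_messages(message, outputs)  # step 3: format
        return self.add_messages(*observations)  # step 4: add to the messages

    def step(self) -> list[dict]:
        return self.execute_actions(self.query())


agent = ToyAgent(ToyModel(), ToyEnv())
observations = agent.step()
```

After one step, agent.messages holds the assistant message followed by the observation message, which is what the next call to query would send to the model.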

Here's query:

def query(self) -> dict:
    # ... limit checks ...
    message = self.model.query(self.messages)
    self.add_messages(message)
    return message
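The elided limit check is a single chained comparison in the full listing. The sketch below isolates that condition in a hypothetical helper, limits_exceeded (not part of the agent), to show that a limit of 0 disables the corresponding check.

```python
def limits_exceeded(step_limit: int, cost_limit: float, n_calls: int, cost: float) -> bool:
    """Mirror of the condition used in DefaultAgent.query (helper name is hypothetical).

    A limit only applies when it is greater than 0; it triggers once the
    counter has reached (or passed) the limit.
    """
    return 0 < step_limit <= n_calls or 0 < cost_limit <= cost


# step_limit=0 means "no step limit", so 100 calls are fine while cost stays below 3.0.
unlimited_steps = limits_exceeded(step_limit=0, cost_limit=3.0, n_calls=100, cost=1.0)

# With step_limit=5, the check fires before the sixth model call.
step_cap_hit = limits_exceeded(step_limit=5, cost_limit=3.0, n_calls=5, cost=0.0)

# The cost check fires once the accumulated cost has reached the limit.
cost_cap_hit = limits_exceeded(step_limit=0, cost_limit=3.0, n_calls=1, cost=3.0)
```

Because the check runs before the model call, the call that pushes the cost over the limit still completes; the agent stops on the following query.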

And execute_actions:

def execute_actions(self, message: dict) -> list[dict]:
    outputs = [self.env.execute(action) for action in message.get("extra", {}).get("actions", [])]
    return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))

The interesting part is how error conditions and the finish condition are handled: both are signaled with exceptions that inherit from InterruptAgentFlow. Each of these exceptions carries messages that get added to the trajectory.

  • Submitted is raised when the agent has finished its task. For example, the environment checks if the command output starts with a magic string:

    # In Environment.execute
    def _check_finished(self, output: dict):
        lines = output.get("output", "").lstrip().splitlines(keepends=True)
        if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT":
            raise Submitted({"role": "exit", "content": ..., "extra": {...}})

  • LimitsExceeded is raised when we hit a cost or step limit.
  • FormatError is raised when the output from the LM is not in the expected format.
  • TimeoutError is raised when the action took too long to execute.
  • UserInterruption is raised when the user interrupts the agent.
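To make the "exceptions carry messages" idea concrete, here is a minimal, self-contained sketch. The class bodies are assumptions inferred from how e.messages is used in run; they are not copied from minisweagent.exceptions. The content and extra values of the exit message are also assumed, since the snippet above elides them.

```python
class InterruptAgentFlow(Exception):
    """Assumed shape: keeps the messages that run() appends to the trajectory."""

    def __init__(self, *messages: dict):
        super().__init__()
        self.messages = messages


class Submitted(InterruptAgentFlow):
    """Signals that the task is finished."""


MAGIC = "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT"


def check_finished(output: dict) -> None:
    """Raise Submitted if the first line of the command output is the magic string."""
    lines = output.get("output", "").lstrip().splitlines(keepends=True)
    if lines and lines[0].strip() == MAGIC:
        # Assumption: everything after the magic line is treated as the submission.
        submission = "".join(lines[1:])
        raise Submitted(
            {
                "role": "exit",
                "content": submission,
                "extra": {"exit_status": "Submitted", "submission": submission},
            }
        )


# Ordinary output passes through without raising.
check_finished({"output": "still working\n"})

# Output that starts with the magic string ends the run via an exception.
try:
    check_finished({"output": f"{MAGIC}\nmy patch\n"})
except Submitted as e:
    exit_message = e.messages[0]
```

The message travels inside the exception, so the code that detects the finish condition does not need a reference to the agent or its message list.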

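The DefaultAgent.run method catches these exceptions and adds the messages they carry to the messages list. The loop continues until a message with role="exit" is added. Exceptions that don't inherit from InterruptAgentFlow are recorded as an exit message by handle_uncaught_exception and then re-raised; the simplified loop below omits that branch and the finally block that saves the trajectory.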

while True:
    try:
        self.step()
    except InterruptAgentFlow as e:
        self.add_messages(*e.messages)
    if self.messages[-1].get("role") == "exit":
        break
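The loop above can be exercised end to end with a scripted agent. In this sketch, ScriptedAgent and the exception class bodies are hypothetical; only the shape of the loop is taken from the snippet. It shows that an interrupt carrying a non-exit message (here a format error) lets the loop continue, while one carrying a role="exit" message ends it.

```python
class InterruptAgentFlow(Exception):
    """Assumed shape: carries messages to append to the trajectory."""

    def __init__(self, *messages: dict):
        super().__init__()
        self.messages = messages


class FormatError(InterruptAgentFlow):
    pass


class Submitted(InterruptAgentFlow):
    pass


class ScriptedAgent:
    """Hypothetical agent whose step() replays a fixed list of outcomes."""

    def __init__(self, outcomes: list):
        self.outcomes = iter(outcomes)
        self.messages: list[dict] = []

    def add_messages(self, *messages: dict) -> list[dict]:
        self.messages.extend(messages)
        return list(messages)

    def step(self) -> None:
        outcome = next(self.outcomes)
        if isinstance(outcome, Exception):
            raise outcome  # simulate an interrupt raised somewhere inside the step
        self.add_messages(outcome)

    def run(self) -> dict:
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                self.add_messages(*e.messages)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})


agent = ScriptedAgent(
    [
        {"role": "user", "content": "observation 1"},
        FormatError({"role": "user", "content": "Please use the expected format."}),
        Submitted({"role": "exit", "content": "done", "extra": {"exit_status": "Submitted", "submission": "done"}}),
    ]
)
result = agent.run()
```

The format error becomes an ordinary user message that the model would see on its next query, whereas the submission produces the exit message whose extra dictionary is returned.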

Using exceptions for control flow is much simpler than passing around flags and state, especially when extending or subclassing the agent.
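As a rough illustration of that point, the sketch below adds a new stop condition in a subclass without touching run. BaseAgent, CappedAgent, and StopAfterThree are hypothetical names, and BaseAgent is a stripped-down stand-in rather than DefaultAgent.

```python
class InterruptAgentFlow(Exception):
    """Assumed shape: carries messages to append to the trajectory."""

    def __init__(self, *messages: dict):
        super().__init__()
        self.messages = messages


class BaseAgent:
    """Stripped-down stand-in for an agent with an exception-driven run loop."""

    def __init__(self):
        self.messages: list[dict] = []

    def step(self) -> None:
        # Stand-in for query + execute_actions: just record a fake observation.
        self.messages.append({"role": "user", "content": f"observation {len(self.messages)}"})

    def run(self) -> dict:
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                self.messages.extend(e.messages)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})


class StopAfterThree(InterruptAgentFlow):
    """Hypothetical custom stop condition."""


class CappedAgent(BaseAgent):
    def step(self) -> None:
        # The subclass only raises; run() already knows how to turn this into an exit.
        if len(self.messages) >= 3:
            raise StopAfterThree(
                {"role": "exit", "content": "cap reached", "extra": {"exit_status": "StopAfterThree"}}
            )
        super().step()


capped = CappedAgent()
result = capped.run()
```

The subclass needs no new return values or flags: raising an exception that carries an exit message is enough for the inherited loop to stop.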