Skip to content

Test Models

Test Models class

Full source code
import logging
import time
from dataclasses import asdict, dataclass
from typing import Any

from minisweagent.models import GLOBAL_MODEL_STATS


@dataclass()
class DeterministicModelConfig:
    """Settings for ``DeterministicModel``.

    Attributes:
        outputs: Scripted replies, handed out one entry at a time in list order.
        model_name: Label for the fake model; exposed with the other fields.
        cost_per_call: Flat cost added for every counted call.
    """

    # Required (no default): the scripted replies.
    outputs: list[str]
    # Optional label for this fake model.
    model_name: str = "deterministic"
    # Optional flat per-call cost.
    cost_per_call: float = 1.0


class DeterministicModel:
    """Scripted stand-in for a real model that replays ``config.outputs`` in order.

    Each ``query`` call consumes entries of ``config.outputs``, continuing from
    where the previous call stopped. An entry containing one of these markers is
    treated as a control directive and skipped, not returned:

    * ``"/sleep"``: prints ``SLEEPING``, then sleeps ``float(<text after marker>)`` seconds.
    * ``"/warning"``: logs the text after the marker with ``logging.warning``.

    The first entry without a marker is returned. Only that entry increments
    ``n_calls`` and ``cost`` and is added to ``GLOBAL_MODEL_STATS``.
    """

    def __init__(self, **kwargs):
        """
        Initialize with a list of outputs to return in sequence.

        All keyword arguments are forwarded to ``DeterministicModelConfig``
        (``outputs`` is required; ``model_name`` and ``cost_per_call`` are optional).
        """
        self.config = DeterministicModelConfig(**kwargs)
        # Index of the most recently consumed output; the first query() moves it to 0.
        self.current_index = -1
        # Running totals, updated only when a non-directive output is returned.
        self.cost = 0.0
        self.n_calls = 0

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next non-directive output as ``{"content": output}``.

        ``messages`` and ``kwargs`` are accepted for interface compatibility
        and are not inspected.

        Raises:
            IndexError: if ``config.outputs`` runs out before a non-directive
                entry is reached.
        """
        # Loop rather than recurse, so a long run of directive entries cannot
        # hit the interpreter's recursion limit.
        while True:
            self.current_index += 1
            try:
                output = self.config.outputs[self.current_index]
            except IndexError:
                # Same exception type as before, but with a message that says why.
                raise IndexError(
                    f"DeterministicModel outputs exhausted: no entry at index {self.current_index} "
                    f"(only {len(self.config.outputs)} outputs configured)"
                ) from None
            if "/sleep" in output:
                print("SLEEPING")
                time.sleep(float(output.split("/sleep")[1]))
                continue
            if "/warning" in output:
                logging.warning(output.split("/warning")[1])
                continue
            break
        self.n_calls += 1
        self.cost += self.config.cost_per_call
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return {"content": output}

    def get_template_vars(self) -> dict[str, Any]:
        """Return the config fields plus the ``n_model_calls`` and ``model_cost`` counters."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}

minisweagent.models.test_models

DeterministicModelConfig dataclass

DeterministicModelConfig(
    outputs: list[str],
    model_name: str = "deterministic",
    cost_per_call: float = 1.0,
)

outputs instance-attribute

outputs: list[str]

model_name class-attribute instance-attribute

model_name: str = 'deterministic'

cost_per_call class-attribute instance-attribute

cost_per_call: float = 1.0

DeterministicModel

DeterministicModel(**kwargs)

Initialize with a list of outputs to return in sequence.

Source code in src/minisweagent/models/test_models.py, lines 17–24:
def __init__(self, **kwargs):
    """
    Initialize with a list of outputs to return in sequence.

    All keyword arguments are forwarded to ``DeterministicModelConfig``
    (``outputs`` is required; ``model_name`` and ``cost_per_call`` are optional).
    """
    self.config = DeterministicModelConfig(**kwargs)
    # Starts at -1 so the first query() increments it to index 0.
    self.current_index = -1
    # Running totals, updated only when query() returns a non-directive output.
    self.cost = 0.0
    self.n_calls = 0

config instance-attribute

config = DeterministicModelConfig(**kwargs)

current_index instance-attribute

current_index = -1

cost instance-attribute

cost = 0.0

n_calls instance-attribute

n_calls = 0

query

query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/test_models.py, lines 26–39:
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
    """Return the next scripted output as ``{"content": output}``.

    Entries containing ``"/sleep"`` or ``"/warning"`` are control directives.
    The method sleeps or logs using the text after the marker, then fetches
    the next entry by calling itself. Only a non-directive entry increments
    ``n_calls`` and ``cost`` and is added to ``GLOBAL_MODEL_STATS``.
    ``messages`` and ``kwargs`` are used only to forward to the recursive call.

    Raises:
        IndexError: when no entry is left in ``config.outputs`` (a list).
    """
    # Advance to the next scripted entry.
    self.current_index += 1
    output = self.config.outputs[self.current_index]
    if "/sleep" in output:
        print("SLEEPING")
        # The text after the marker is the sleep duration in seconds.
        time.sleep(float(output.split("/sleep")[1]))
        # NOTE(review): one recursion level per directive; a long run of
        # consecutive directives could hit the recursion limit.
        return self.query(messages, **kwargs)
    if "/warning" in output:
        logging.warning(output.split("/warning")[1])
        return self.query(messages, **kwargs)
    # Only real responses are counted and charged.
    self.n_calls += 1
    self.cost += self.config.cost_per_call
    GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
    return {"content": output}

get_template_vars

get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/test_models.py, lines 41–42:
def get_template_vars(self) -> dict[str, Any]:
    """Return the config fields plus the ``n_model_calls`` and ``model_cost`` counters."""
    return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}