Test Models
Test Models class
Full source code
import logging
import time
from dataclasses import asdict, dataclass
from typing import Any
from minisweagent.models import GLOBAL_MODEL_STATS
@dataclass()
class DeterministicModelConfig:
    """Settings for ``DeterministicModel``.

    Attributes:
        outputs: Scripted responses, consumed one per step by successive queries.
        model_name: Identifier for the model (defaults to ``"deterministic"``).
        cost_per_call: Amount added to the running cost for every non-directive
            output that a query returns.
    """

    outputs: list[str]  # required: there is no default script
    model_name: str = "deterministic"
    cost_per_call: float = 1.0
class DeterministicModel:
    """Test double that replays a fixed script of outputs.

    Each ``query`` consumes entries of ``config.outputs`` in order. Entries that
    contain a directive marker are handled and skipped rather than returned:

    * ``/sleep`` -- prints ``SLEEPING`` and sleeps for the number of seconds
      written after the marker;
    * ``/warning`` -- logs the text after the marker with ``logging.warning``.

    The first non-directive entry is returned as ``{"content": output}``. Only
    that entry increments ``n_calls``, adds ``config.cost_per_call`` to ``cost``
    and is reported via ``GLOBAL_MODEL_STATS.add``.
    """

    def __init__(self, **kwargs):
        """
        Initialize with a list of outputs to return in sequence.

        All keyword arguments are forwarded to ``DeterministicModelConfig``
        (``outputs`` is required; ``model_name`` and ``cost_per_call`` are optional).
        """
        self.config = DeterministicModelConfig(**kwargs)
        # Index of the most recently consumed output; -1 means nothing consumed yet.
        self.current_index = -1
        self.cost = 0.0
        self.n_calls = 0

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted non-directive output.

        ``messages`` and ``kwargs`` are accepted for interface compatibility
        and are not used.

        Returns:
            ``{"content": output}`` for the next non-directive entry.

        Raises:
            IndexError: if the script runs out of outputs.
        """
        # Loop instead of recursing on directives, so a long run of consecutive
        # /sleep or /warning entries cannot exceed the recursion limit.
        while True:
            self.current_index += 1
            try:
                output = self.config.outputs[self.current_index]
            except IndexError:
                raise IndexError(
                    f"DeterministicModel ran out of outputs "
                    f"({len(self.config.outputs)} scripted entries consumed)"
                ) from None
            if "/sleep" in output:
                print("SLEEPING")
                time.sleep(float(output.split("/sleep")[1]))
                continue
            if "/warning" in output:
                logging.warning(output.split("/warning")[1])
                continue
            self.n_calls += 1
            self.cost += self.config.cost_per_call
            GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
            return {"content": output}

    def get_template_vars(self) -> dict[str, Any]:
        """Return the config fields merged with the call counter and accumulated cost."""
        return asdict(self.config) | {"n_model_calls": self.n_calls, "model_cost": self.cost}
minisweagent.models.test_models
DeterministicModelConfig
dataclass
DeterministicModelConfig(
outputs: list[str],
model_name: str = "deterministic",
cost_per_call: float = 1.0,
)
outputs
instance-attribute
outputs: list[str]
model_name
class-attribute
instance-attribute
model_name: str = 'deterministic'
cost_per_call
class-attribute
instance-attribute
cost_per_call: float = 1.0
DeterministicModel
DeterministicModel(**kwargs)
Initialize with a list of outputs to return in sequence.
Source code in src/minisweagent/models/test_models.py
Lines 17–24.
current_index
instance-attribute
current_index = -1
cost
instance-attribute
cost = 0.0
n_calls
instance-attribute
n_calls = 0
query
query(messages: list[dict[str, str]], **kwargs) -> dict
Source code in src/minisweagent/models/test_models.py
Lines 26–39.
get_template_vars
get_template_vars() -> dict[str, Any]
Source code in src/minisweagent/models/test_models.py
Lines 41–42.