Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ async def reset(
request_max_retries: int | None = None,
tool_result_overflow_dir: str | None = None,
read_tool: FunctionTool | None = None,
overflow_file_writer: T.Callable[[str, str], T.Awaitable[str]] | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
Expand All @@ -241,6 +242,7 @@ async def reset(
self.request_max_retries = request_max_retries
self.tool_result_overflow_dir = tool_result_overflow_dir
self.read_tool = read_tool
self._overflow_file_writer = overflow_file_writer
self._tool_result_token_counter = EstimateTokenCounter()
self.request_context_manager_config = ContextConfig(
# <=0 disables token-based guarding.
Expand Down Expand Up @@ -369,6 +371,9 @@ async def _write_tool_result_overflow_file(
tool_call_id: str,
content: str,
) -> str:
if self._overflow_file_writer is not None:
return await self._overflow_file_writer(content, tool_call_id)

if self.tool_result_overflow_dir is None:
raise ValueError("tool_result_overflow_dir is not configured")

Expand Down
13 changes: 13 additions & 0 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,18 @@ async def build_main_agent(
elif config.computer_use_runtime == "local":
_apply_local_env_tools(req, plugin_context)

overflow_file_writer = None
if (
config.computer_use_runtime == "sandbox"
and req.func_tool
and req.func_tool.get_tool("astrbot_file_read_tool")
):
from astrbot.core.computer.computer_client import make_sandbox_overflow_writer

overflow_file_writer = make_sandbox_overflow_writer(
plugin_context, event.unified_msg_origin
)

agent_runner = AgentRunner()
astr_agent_ctx = AstrAgentContext(
context=plugin_context,
Expand Down Expand Up @@ -1625,6 +1637,7 @@ async def build_main_agent(
read_tool=(
req.func_tool.get_tool("astrbot_file_read_tool") if req.func_tool else None
),
overflow_file_writer=overflow_file_writer,
)

if apply_reset:
Expand Down
34 changes: 34 additions & 0 deletions astrbot/core/computer/computer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,40 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
logger.warning(f"Failed to remove temp skills zip: {zip_path}")


def make_sandbox_overflow_writer(
context: Context,
unified_msg_origin: str,
):
"""Build a callback that writes tool-result overflow content directly into the sandbox.

The returned callable has the signature
``(content: str, tool_call_id: str) -> Awaitable[str]`` and returns a
sandbox-relative path that ``astrbot_file_read_tool`` can resolve inside
the sandbox container.

Bay's filesystem API requires relative paths, so we write to a file under
the sandbox working directory rather than an absolute ``/tmp/...`` path.
"""

async def _write(content: str, tool_call_id: str) -> str:
safe_id = (
"".join(
ch if ch.isalnum() or ch in {"-", "_", "."} else "_"
for ch in tool_call_id
).strip("._")
or "tool_call"
)
sandbox_path = f"astrbot_overflow_{safe_id}_{uuid.uuid4().hex[:8]}.txt"
booter = await get_booter(context, unified_msg_origin)
Comment on lines +557 to +566

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider truncating the sanitized tool_call_id to avoid excessively long filenames in the sandbox.

If tool_call_id can be very long, safe_id may create filenames that exceed filesystem limits or be unwieldy in logs. Consider truncating safe_id (e.g., to 32–64 chars) before appending the UUID so filenames stay within reasonable bounds while remaining debuggable.

Suggested change
async def _write(content: str, tool_call_id: str) -> str:
safe_id = (
"".join(
ch if ch.isalnum() or ch in {"-", "_", "."} else "_"
for ch in tool_call_id
).strip("._")
or "tool_call"
)
sandbox_path = f"astrbot_overflow_{safe_id}_{uuid.uuid4().hex[:8]}.txt"
booter = await get_booter(context, unified_msg_origin)
async def _write(content: str, tool_call_id: str) -> str:
safe_id = (
"".join(
ch if ch.isalnum() or ch in {"-", "_", "."} else "_"
for ch in tool_call_id
).strip("._")
or "tool_call"
)
max_safe_id_len = 64
if len(safe_id) > max_safe_id_len:
safe_id = safe_id[:max_safe_id_len]
sandbox_path = f"astrbot_overflow_{safe_id}_{uuid.uuid4().hex[:8]}.txt"
booter = await get_booter(context, unified_msg_origin)

await booter.fs.write_file(sandbox_path, content)
logger.debug(
"[Computer] Overflow file written to sandbox: %s", sandbox_path
)
return sandbox_path

return _write


async def get_booter(
context: Context,
session_id: str,
Expand Down
16 changes: 16 additions & 0 deletions astrbot/core/star/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,15 @@ async def tool_loop_agent(
other_kwargs.setdefault(
"read_tool", request.func_tool.get_tool("astrbot_file_read_tool")
)
if self._is_sandbox_runtime(event.unified_msg_origin):
from astrbot.core.computer.computer_client import (
make_sandbox_overflow_writer,
)

other_kwargs.setdefault(
"overflow_file_writer",
make_sandbox_overflow_writer(self, event.unified_msg_origin),
)

await agent_runner.reset(
provider=prov,
Expand Down Expand Up @@ -503,6 +512,13 @@ def get_config(self, umo: str | None = None) -> AstrBotConfig:
return self._config
return self.astrbot_config_mgr.get_conf(umo)

def _is_sandbox_runtime(self, umo: str) -> bool:
cfg = self.get_config(umo=umo)
runtime = str(
cfg.get("provider_settings", {}).get("computer_use_runtime", "local")
)
Comment on lines +516 to +519

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

如果配置中的 provider_settings 显式为 Nonecfg.get("provider_settings", {}) 将返回 None,随后调用 .get() 会抛出 AttributeError。建议使用 cfg.get("provider_settings") or {} 进行防御性保护。

        cfg = self.get_config(umo=umo)
        provider_settings = cfg.get("provider_settings") or {}
        runtime = str(provider_settings.get("computer_use_runtime", "local"))

return runtime == "sandbox"

async def send_message(
self,
session: str | MessageSesion,
Expand Down
244 changes: 244 additions & 0 deletions tests/test_tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,6 +1741,250 @@ async def test_follow_up_after_stop_not_merged_into_tool_result(
assert ticket_before.resolved.is_set()


# ---------------------------------------------------------------------------
# Tests for tool-result overflow file writer (sandbox mode fix)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_overflow_file_writer_callback_is_used(tmp_path, monkeypatch):
"""When overflow_file_writer is provided, _write_tool_result_overflow_file
MUST delegate to the callback and return its result instead of writing to
tool_result_overflow_dir."""
tool = FunctionTool(
name="test_tool",
description="test",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
read_tool = FunctionTool(
name="astrbot_file_read_tool",
description="read file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool, read_tool])
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
runner = ToolLoopAgentRunner()

# A spy callback that records calls and returns a known sandbox path
call_records: list[dict] = []
expected_sandbox_path = "/tmp/astrbot_overflow_deadbeef.txt"

async def _spy_writer(content: str, tool_call_id: str) -> str:
call_records.append({"content": content, "tool_call_id": tool_call_id})
return expected_sandbox_path

await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=cast(
Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text())
),
agent_hooks=MockHooks(),
streaming=False,
tool_result_overflow_dir=str(tmp_path),
read_tool=read_tool,
overflow_file_writer=_spy_writer,
)

async for _ in runner.step_until_done(3):
pass

# Callback must have been called exactly once
assert len(call_records) == 1, (
f"Expected 1 callback invocation, got {len(call_records)}"
)
assert call_records[0]["tool_call_id"] == "call_large_result"

# The overflow notice in the tool message MUST contain the sandbox path
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
assert len(tool_messages) == 1
tool_message_content = str(tool_messages[0].content)
assert expected_sandbox_path in tool_message_content, (
f"Expected sandbox path '{expected_sandbox_path}' in notice, "
f"got: ...{tool_message_content[-200:]}"
)
assert "Truncated tool output preview shown above." in tool_message_content
assert "`astrbot_file_read_tool`" in tool_message_content

# The tool_result_overflow_dir directory MUST NOT be used
overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt"))
assert len(overflow_files) == 0, (
f"Callback was provided but file was still written to "
f"tool_result_overflow_dir: {overflow_files}"
)


@pytest.mark.asyncio
async def test_overflow_file_writer_none_uses_disk_fallback(tmp_path):
"""When overflow_file_writer is None (default), the existing disk-based
overflow path MUST work exactly as before — no regression."""
tool = FunctionTool(
name="test_tool",
description="test",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
read_tool = FunctionTool(
name="astrbot_file_read_tool",
description="read file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool, read_tool])
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
runner = ToolLoopAgentRunner()

await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=cast(
Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text())
),
agent_hooks=MockHooks(),
streaming=False,
tool_result_overflow_dir=str(tmp_path),
read_tool=read_tool,
# overflow_file_writer NOT passed — should default to None
)

async for _ in runner.step_until_done(3):
pass

# Disk-based overflow MUST still work
tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
assert len(tool_messages) == 1
tool_message_content = str(tool_messages[0].content)
assert "Truncated tool output preview shown above." in tool_message_content
assert "`astrbot_file_read_tool`" in tool_message_content

overflow_files = list(Path(tmp_path).glob("call_large_result_*.txt"))
assert len(overflow_files) == 1
assert (
overflow_files[0].read_text(encoding="utf-8") == _make_large_tool_result_text()
)


def test_make_sandbox_overflow_writer_returns_callable():
"""make_sandbox_overflow_writer must return an async callable."""
from astrbot.core.computer.computer_client import make_sandbox_overflow_writer

writer = make_sandbox_overflow_writer(
context=SimpleNamespace(), # type: ignore[arg-type]
unified_msg_origin="test_umo",
)
assert callable(writer)
assert asyncio.iscoroutinefunction(writer)


@pytest.mark.asyncio
async def test_make_sandbox_overflow_writer_writes_via_booter(monkeypatch):
"""The writer returned by make_sandbox_overflow_writer MUST write to the
sandbox filesystem via booter.fs.write_file and return a /tmp/ path."""
from astrbot.core.computer.computer_client import make_sandbox_overflow_writer

# Fake booter that records write_file calls
write_calls: list[dict] = []

class _FakeFS:
async def write_file(self, path: str, content: str) -> None:
write_calls.append({"path": path, "content": content})

class _FakeBooter:
def __init__(self):
self.fs = _FakeFS()

_fake_booter = _FakeBooter()

async def _fake_get_booter(context, umo):
return _fake_booter

monkeypatch.setattr(
"astrbot.core.computer.computer_client.get_booter",
_fake_get_booter,
)

writer = make_sandbox_overflow_writer(
context=SimpleNamespace(), # type: ignore[arg-type]
unified_msg_origin="test_umo",
)

result_path = await writer("hello sandbox", "call_abc123")

# Must return a /tmp/ path
assert result_path.startswith("/tmp/astrbot_overflow_"), (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (testing): Avoid hard-coding a /tmp/ prefix for sandbox overflow paths in this test

This assertion couples the test to a specific /tmp-based layout and will break if the sandbox implementation changes (as suggested in the PR description, which mentions returning a sandbox-relative path). Instead, assert the stable parts of the contract, e.g. that the result is a string, includes an astrbot_overflow_ prefix and .txt suffix, and possibly excludes unsafe characters, without requiring a /tmp/ root.

f"Expected sandbox /tmp/ path, got: {result_path}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

由于 make_sandbox_overflow_writer 的实现已经改为返回相对路径(不带 /tmp/ 前缀,如其文档字符串所述),此处的断言会失败。应该将断言修改为检查是否以 astrbot_overflow_ 开头。

Suggested change
assert result_path.startswith("/tmp/astrbot_overflow_"), (
f"Expected sandbox /tmp/ path, got: {result_path}"
)
assert result_path.startswith("astrbot_overflow_"), (
f"Expected sandbox path starting with 'astrbot_overflow_', got: {result_path}"
)

assert result_path.endswith(".txt")

# Must have called write_file on the booter
assert len(write_calls) == 1
assert write_calls[0]["path"] == result_path
assert write_calls[0]["content"] == "hello sandbox"


@pytest.mark.asyncio
async def test_overflow_notice_contains_sandbox_path_not_host_path(monkeypatch):
"""End-to-end: when overflow_file_writer is wired up, the tool-message
notice MUST contain the sandbox-side /tmp/ path, not a host path."""
tool = FunctionTool(
name="test_tool",
description="test",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
read_tool = FunctionTool(
name="astrbot_file_read_tool",
description="read file",
parameters={"type": "object", "properties": {"path": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool, read_tool])
provider = SingleToolThenFinalProvider(tool.name, {"query": "large"})
request = ProviderRequest(prompt="run tool", func_tool=tool_set, contexts=[])
runner = ToolLoopAgentRunner()

sandbox_overflow_path = "/tmp/astrbot_overflow_sandbox_abc12345.txt"

async def _sandbox_writer(content: str, tool_call_id: str) -> str:
return sandbox_overflow_path

await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=cast(
Any, LargeTextToolExecutor.from_text(_make_large_tool_result_text())
),
agent_hooks=MockHooks(),
streaming=False,
tool_result_overflow_dir="/tmp/.astrbot", # host path — should NOT be used
read_tool=read_tool,
overflow_file_writer=_sandbox_writer,
)

async for _ in runner.step_until_done(3):
pass

tool_messages = [m for m in runner.run_context.messages if m.role == "tool"]
assert len(tool_messages) == 1
tool_message_content = str(tool_messages[0].content)

# The notice MUST contain the sandbox path
assert sandbox_overflow_path in tool_message_content, (
f"Expected sandbox path {sandbox_overflow_path!r} in notice"
)
# The host path MUST NOT leak into the notice
assert "/tmp/.astrbot" not in tool_message_content, (
"Host tool_result_overflow_dir path leaked into sandbox-mode notice"
)


if __name__ == "__main__":
# 运行测试
pytest.main([__file__, "-v"])
Loading