Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions astrbot/builtin_stars/astrbot/group_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)

"""
Group chat context awareness.
Expand Down Expand Up @@ -78,6 +82,9 @@ def cfg(self, event: AstrMessageEvent):
"image_caption": image_caption,
"image_caption_prompt": image_caption_prompt,
"image_caption_provider_id": image_caption_provider_id,
"image_caption_cache_ttl": resolve_image_caption_cache_ttl(
cfg.get("provider_settings", {})
),
"enable_active_reply": enable_active_reply,
"ar_method": ar_method,
"ar_possibility": ar_possibility,
Expand All @@ -90,17 +97,46 @@ async def get_image_caption(
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
cache_ttl: int = 0,
) -> str:
if not image_caption_provider_id:
provider = self.context.get_using_provider()
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
raise Exception(
f"Provider `{image_caption_provider_id}` was not found."
)

if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
response = await provider.text_chat(
raise Exception(
f"Provider type is invalid for image captioning: {type(provider)}."
)
provider_id = _resolve_provider_cache_identity(
provider,
configured_provider_id=image_caption_provider_id,
Comment thread
FloranceYeh marked this conversation as resolved.
)

return await image_caption_cache.get_or_create(
provider_id=provider_id,
prompt=image_caption_prompt,
image_urls=[image_url],
ttl_seconds=cache_ttl,
caption_factory=lambda: self._fetch_image_caption(
provider,
image_caption_prompt,
image_url,
),
)

async def _fetch_image_caption(
self,
provider: Provider,
prompt: str,
image_url: str,
) -> str:
response = await provider.text_chat(
prompt=prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
Expand Down Expand Up @@ -206,15 +242,16 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
try:
url = comp.url if comp.url else comp.file
if not url:
raise Exception("图片 URL 为空")
raise Exception("Image URL is empty.")
caption = await self.get_image_caption(
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
cfg["image_caption_cache_ttl"],
)
parts.append(f" [Image: {caption}]")
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
logger.error(f"Failed to get image caption: {e}")
else:
parts.append(" [Image]")
elif isinstance(comp, At):
Expand All @@ -223,7 +260,7 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
"all",
)
if is_at_self:
parts.insert(1, "⚠️[DIRECTED AT YOU] ")
parts.insert(1, "[DIRECTED AT YOU] ")
parts.append(f" [At: {comp.name}]")
elif isinstance(comp, Reply):
if comp.message_str:
Expand Down Expand Up @@ -300,3 +337,27 @@ def _trim_left(

def _format_group_history_block(records: list[str]) -> str:
    """Build the group-history text block from individual record lines.

    Args:
        records: History lines, already formatted, in display order.

    Returns:
        ``GROUP_HISTORY_HEADER``, then the records joined with newlines,
        then ``GROUP_HISTORY_FOOTER``. No separator is inserted here between
        the header/footer and the records.
    """
    joined_records = "\n".join(records)
    return "".join((GROUP_HISTORY_HEADER, joined_records, GROUP_HISTORY_FOOTER))


def _resolve_provider_cache_identity(
    provider: Provider,
    configured_provider_id: str,
) -> str:
    """Return the string that identifies ``provider`` in image-caption cache keys.

    Resolution order:

    1. ``configured_provider_id`` when it is non-empty.
    2. The ``id`` entry of ``provider.provider_config`` when it is a
       non-empty string.
    3. A synthesized ``module:qualname:type:model`` string built from the
       provider's class, its configured ``type`` and ``provider.get_model()``.

    Args:
        provider: The provider instance that will produce the caption.
        configured_provider_id: Provider ID taken from configuration; may be
            empty when the currently active provider is used instead.

    Returns:
        A string identifying the provider for cache-key purposes.
    """
    if configured_provider_id:
        return configured_provider_id

    # Only trust ``provider_config`` when it really is a dict, matching the
    # other caption call sites. A missing or malformed config then falls
    # through to the synthesized identity instead of raising on ``.get``.
    raw_config = provider.provider_config
    provider_config = raw_config if isinstance(raw_config, dict) else {}

    provider_id = provider_config.get("id", "")
    if isinstance(provider_id, str) and provider_id:
        return provider_id

    provider_type = provider_config.get("type", "")
    model = provider.get_model()
    # ``None`` components become empty segments so the identity always has
    # exactly four colon-separated fields.
    return ":".join(
        [
            provider.__class__.__module__,
            provider.__class__.__qualname__,
            "" if provider_type is None else str(provider_type),
            "" if model is None else str(model),
        ]
    )
64 changes: 55 additions & 9 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
get_astrbot_workspaces_path,
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
Expand Down Expand Up @@ -614,11 +618,41 @@ async def _ensure_persona_and_skills(
pass


async def _request_img_caption_with_provider(
    prov: Provider,
    provider_id: str,
    image_urls: list[str],
    prompt: str,
    cache_ttl: int | None = None,
) -> str:
    """Request an image caption from ``prov`` through ``image_caption_cache``.

    Args:
        prov: Provider whose ``text_chat`` produces the caption.
        provider_id: Identifier passed to the cache and written to the debug log.
        image_urls: Images forwarded to the provider and to the cache.
        prompt: Caption prompt forwarded to the provider and to the cache.
        cache_ttl: TTL handed to the cache as ``ttl_seconds``. When ``None``,
            it is resolved with ``resolve_image_caption_cache_ttl`` from the
            provider's own ``provider_config`` (or from ``None`` when that
            config is not a dict).

    Returns:
        Whatever ``image_caption_cache.get_or_create`` yields. Presumably this
        is a cached caption or the factory's fresh result; confirm against
        the cache implementation.
    """
    ttl_seconds = cache_ttl
    if ttl_seconds is None:
        # Fall back to the provider's own settings only when they are a dict.
        config_source = None
        if isinstance(prov.provider_config, dict):
            config_source = prov.provider_config
        ttl_seconds = resolve_image_caption_cache_ttl(config_source)

    logger.debug("Processing image caption with provider: %s", provider_id)

    async def _ask_provider() -> str:
        # Factory handed to the cache; it performs the actual provider call.
        reply = await prov.text_chat(prompt=prompt, image_urls=image_urls)
        return reply.completion_text

    caption = await image_caption_cache.get_or_create(
        provider_id=provider_id,
        prompt=prompt,
        image_urls=image_urls,
        ttl_seconds=ttl_seconds,
        caption_factory=_ask_provider,
    )
    return caption


async def _request_img_caption(
provider_id: str,
cfg: dict,
image_urls: list[str],
plugin_context: Context,
prompt: str | None = None,
) -> str:
prov = plugin_context.get_provider_by_id(provider_id)
if prov is None:
Expand All @@ -630,16 +664,18 @@ async def _request_img_caption(
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.",
)

img_cap_prompt = cfg.get(
img_cap_prompt = prompt or cfg.get(
"image_caption_prompt",
"Please describe the image.",
)
logger.debug("Processing image caption with provider: %s", provider_id)
llm_resp = await prov.text_chat(
prompt=img_cap_prompt,
cache_ttl = resolve_image_caption_cache_ttl(cfg)
return await _request_img_caption_with_provider(
prov=prov,
provider_id=provider_id,
image_urls=image_urls,
prompt=img_cap_prompt,
cache_ttl=cache_ttl,
)
return llm_resp.completion_text


async def _ensure_img_caption(
Expand Down Expand Up @@ -865,13 +901,23 @@ async def _process_quote_message(
path, compress_path
):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
provider_config = (
prov.provider_config
if isinstance(prov.provider_config, dict)
else {}
)
caption = await _request_img_caption_with_provider(
prov=prov,
provider_id=provider_config.get("id", img_cap_prov_id or ""),
image_urls=[compress_path],
prompt="Please describe the image content.",
cache_ttl=resolve_image_caption_cache_ttl(
config.provider_settings if config else None
),
)
if llm_resp.completion_text:
if caption:
content_parts.append(
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
f"[Image Caption in quoted message]: {caption}"
)
else:
logger.warning("No provider found for image captioning in quote.")
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
"request_max_retries": 5,
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"image_caption_cache_ttl": 600,
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "",
"web_search": False,
Expand Down Expand Up @@ -3223,6 +3224,11 @@
"description": "图片转述提示词",
"type": "text",
},
"provider_settings.image_caption_cache_ttl": {
"description": "图片转述缓存时长(秒)",
"type": "int",
"hint": "在缓存时间内再次收到相同图片时,直接复用已缓存的视觉识别结果;设为 0 表示禁用缓存",
},
},
"condition": {
"provider_settings.enable": True,
Expand Down
Loading