Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions astrbot/dashboard/api/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ def _to_int(value: Any, default: int) -> int:
return default


def _model_dict(payload) -> dict[str, Any]:
if payload is None:
return {}
if hasattr(payload, "model_dump"):
return payload.model_dump(exclude_none=True)
return payload if isinstance(payload, dict) else {}


async def _run(operation, *, prefix: str):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The helper function _model_dict was removed from this file. However, it is still referenced in multiple other endpoints within the same file:

  • import_knowledge_base_documents (line 210)
  • import_knowledge_base_document_url (line 224)
  • retrieve_knowledge_base (line 304)

Removing _model_dict will cause a NameError when any of these endpoints are called. Please restore _model_dict or refactor those endpoints to use payload.model_dump(exclude_none=True) directly.

def _model_dict(payload) -> dict[str, Any]:
    """Coerce a request payload into a plain dictionary.

    Args:
        payload: ``None``, an object exposing ``model_dump``, a dictionary,
            or any other value.

    Returns:
        The result of ``payload.model_dump(exclude_none=True)`` when the
        payload has a ``model_dump`` attribute; the payload itself (the same
        object, not a copy) when it is a dictionary; otherwise an empty
        dictionary.
    """
    if payload is not None:
        # The model_dump check comes first, so an object that offers it is
        # always dumped, even if it also happens to be a dict.
        if hasattr(payload, "model_dump"):
            return payload.model_dump(exclude_none=True)
        if isinstance(payload, dict):
            return payload
    return {}


async def _run(operation, *, prefix: str):

try:
result = await run_maybe_async(operation)
Expand Down Expand Up @@ -94,7 +86,11 @@ async def list_knowledge_bases(
return await _run(
lambda: service.list_kbs(
page=_to_int(request.query_params.get("page"), 1),
page_size=_to_int(request.query_params.get("page_size"), 20),
page_size=(
_to_int(request.query_params.get("page_size"), 20)
if "page" in request.query_params or "page_size" in request.query_params
else None
),
),
prefix="获取知识库列表失败",
)
Expand All @@ -107,7 +103,7 @@ async def create_knowledge_base(
service: KnowledgeBaseService = Depends(get_service),
):
return await _run(
lambda: service.create_kb(_model_dict(payload)),
lambda: service.create_kb(payload.canonical_payload()),
prefix="创建知识库失败",
)

Expand Down Expand Up @@ -140,9 +136,8 @@ async def update_knowledge_base(
_auth: AuthContext = Depends(require_kb_scope),
service: KnowledgeBaseService = Depends(get_service),
):
body = _model_dict(payload)
return await _run(
lambda: service.update_kb({"kb_id": kb_id, **body}),
lambda: service.update_kb({**payload.canonical_payload(), "kb_id": kb_id}),
prefix="更新知识库失败",
)

Expand Down Expand Up @@ -322,7 +317,11 @@ async def dashboard_list_kbs(
return await _run(
lambda: service.list_kbs(
page=_to_int(request.query_params.get("page"), 1),
page_size=_to_int(request.query_params.get("page_size"), 20),
page_size=(
_to_int(request.query_params.get("page_size"), 20)
if "page" in request.query_params or "page_size" in request.query_params
else None
),
),
prefix="获取知识库列表失败",
)
Expand Down
34 changes: 31 additions & 3 deletions astrbot/dashboard/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,42 @@ class ImMessageRequest(OpenModel):


class KnowledgeBaseRequest(OpenModel):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
kb_id: str | None = None
name: str | None = None
kb_name: str | None = None
description: str | None = None
emoji: str | None = None
embedding_provider_id: str | None = None
rerank_provider_id: str | None = None
chunk_size: int | None = None
chunk_overlap: int | None = None

top_k_dense: int | None = None
top_k_sparse: int | None = None
top_m_final: int | None = None

def canonical_payload(self) -> dict[str, Any]:
"""Return the service-facing knowledge base payload.

Returns:
Dictionary accepted by KnowledgeBaseService.
"""
data = self.model_dump(
exclude_unset=True,
include={
"kb_name",
"description",
"emoji",
"embedding_provider_id",
"rerank_provider_id",
"chunk_size",
"chunk_overlap",
"top_k_dense",
"top_k_sparse",
"top_m_final",
},
)
legacy_name = getattr(self, "name", None)
if data.get("kb_name") is None and legacy_name is not None:
data["kb_name"] = legacy_name
return data

class KnowledgeBaseImportRequest(OpenModel):
Comment on lines 249 to 250

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The PR description mentions syncing KnowledgeBaseCreateRequest Pydantic models, and the OpenAPI spec/frontend types expect it. However, KnowledgeBaseCreateRequest is not defined in astrbot/dashboard/schemas.py.

Please define KnowledgeBaseCreateRequest inheriting from KnowledgeBaseRequest with kb_name and embedding_provider_id as required fields.

Suggested change
class KnowledgeBaseImportRequest(OpenModel):
class KnowledgeBaseCreateRequest(KnowledgeBaseRequest):
kb_name: str
embedding_provider_id: str
class KnowledgeBaseImportRequest(OpenModel):

documents: list[dict[str, Any]] | None = None
Expand Down
63 changes: 41 additions & 22 deletions astrbot/dashboard/services/knowledge_base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ def __init__(self, core_lifecycle: AstrBotCoreLifecycle) -> None:
def _payload(data: object) -> dict[str, Any]:
return data if isinstance(data, dict) else {}

@staticmethod
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
def _canonical_kb_payload(data: object) -> dict[str, Any]:
"""Normalize knowledge base create/update payloads.

Args:
data: Request payload from v1 or legacy Dashboard routes.

Returns:
Payload using the service's canonical field names.
"""
payload = KnowledgeBaseService._payload(data).copy()
if payload.get("kb_name") is None and payload.get("name") is not None:
payload["kb_name"] = payload["name"]
payload.pop("name", None)
return payload

def get_kb_manager(self):
return self.core_lifecycle.kb_manager

Expand Down Expand Up @@ -263,19 +279,30 @@ async def background_import_task(
logger.error(traceback.format_exc())
self.set_task_result(task_id, "failed", error=str(exc))

async def list_kbs(self, *, page: int, page_size: int) -> dict[str, Any]:
async def list_kbs(self, *, page: int, page_size: int | None) -> dict[str, Any]:
kb_manager = self.get_kb_manager()
kbs = await kb_manager.list_kbs()

kb_list = []
for kb in kbs:
selected_kbs = kbs
if page_size is not None:
start = max(page - 1, 0) * page_size
end = start + page_size
selected_kbs = kbs[start:end]

for kb in selected_kbs:
kb_dict = kb.model_dump()
kb_helper = await kb_manager.get_kb(kb.kb_id)
if kb_helper and kb_helper.init_error:
kb_dict["init_error"] = kb_helper.init_error
kb_list.append(kb_dict)

return {"items": kb_list, "page": page, "page_size": page_size}
return {
"items": kb_list,
"page": page,
"page_size": page_size if page_size is not None else len(kbs),
"total": len(kbs),
}

async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, Any]:
return await self.list_kbs(
Expand All @@ -285,7 +312,7 @@ async def list_kbs_from_dashboard_query(self, *, page, page_size) -> dict[str, A

async def create_kb(self, data: object) -> tuple[dict[str, Any], str]:
kb_manager = self.get_kb_manager()
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_name = payload.get("kb_name")
if not kb_name:
raise KnowledgeBaseServiceError("知识库名称不能为空")
Expand Down Expand Up @@ -355,7 +382,7 @@ async def get_kb_from_dashboard_query(self, kb_id: str | None) -> dict[str, Any]
return await self.get_kb(kb_id)

async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
payload = self._payload(data)
payload = self._canonical_kb_payload(data)
kb_id = payload.get("kb_id")
if not kb_id:
raise KnowledgeBaseServiceError("缺少参数 kb_id")
Expand All @@ -372,28 +399,20 @@ async def update_kb(self, data: object) -> tuple[dict[str, Any], str]:
"top_k_sparse",
"top_m_final",
]
if all(payload.get(key) is None for key in update_keys):
provided_updates = {key: payload[key] for key in update_keys if key in payload}
if not provided_updates:
raise KnowledgeBaseServiceError("至少需要提供一个更新字段")

current_kb = await self.get_kb_manager().get_kb(kb_id)
kb_name = payload.get("kb_name")
if kb_name is None:
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
kb_name = current_kb.kb.kb_name
if not current_kb:
raise KnowledgeBaseServiceError("知识库不存在")
current = current_kb.kb
update_data = {key: getattr(current, key) for key in update_keys}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using getattr(current, key) without a default value can raise an AttributeError if any of the update_keys are missing from the current object (e.g., due to database schema mismatches or pending migrations).

Using getattr(current, key, None) is safer and prevents potential runtime crashes.

Suggested change
update_data = {key: getattr(current, key) for key in update_keys}
update_data = {key: getattr(current, key, None) for key in update_keys}

update_data.update(provided_updates)

kb_helper = await self.get_kb_manager().update_kb(
kb_id=kb_id,
kb_name=kb_name,
description=payload.get("description"),
emoji=payload.get("emoji"),
embedding_provider_id=payload.get("embedding_provider_id"),
rerank_provider_id=payload.get("rerank_provider_id"),
chunk_size=payload.get("chunk_size"),
chunk_overlap=payload.get("chunk_overlap"),
top_k_dense=payload.get("top_k_dense"),
top_k_sparse=payload.get("top_k_sparse"),
top_m_final=payload.get("top_m_final"),
**update_data,
)
if not kb_helper:
raise KnowledgeBaseServiceError("知识库不存在")
Expand Down Expand Up @@ -738,11 +757,11 @@ async def retrieve(self, data: object) -> dict[str, Any]:

if not query:
raise KnowledgeBaseServiceError("缺少参数 query")
kb_manager = self.get_kb_manager()
if not kb_names or not isinstance(kb_names, list):
raise KnowledgeBaseServiceError("缺少参数 kb_names 或格式错误")

top_k = payload.get("top_k", 5)
kb_manager = self.get_kb_manager()
results = await kb_manager.retrieve(
query=query,
kb_names=kb_names,
Expand Down
22 changes: 15 additions & 7 deletions dashboard/src/api/generated/openapi-v1/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,22 @@ export type JsonSchema = {
[key: string]: unknown;
};

export type KnowledgeBaseCreateRequest = KnowledgeBaseRequest & {
kb_name: string;
embedding_provider_id: string;
};

export type KnowledgeBaseRequest = {
name: string;
kb_name?: string;
description?: string;
embedding_provider_id?: string;
rerank_provider_id?: string;
chunking?: DynamicConfig;
metadata?: DynamicConfig;
emoji?: string;
embedding_provider_id?: (string) | null;
rerank_provider_id?: (string) | null;
chunk_size?: number;
chunk_overlap?: number;
top_k_dense?: number;
top_k_sparse?: number;
top_m_final?: number;
};

export type KnowledgeDocumentImportRequest = {
Expand All @@ -271,7 +280,6 @@ export type KnowledgeDocumentImportRequest = {

export type KnowledgeDocumentUploadRequest = {
file: (Blob | File);
parser?: string;
};

export type KnowledgeDocumentUrlImportRequest = {
Expand Down Expand Up @@ -2569,7 +2577,7 @@ export type ListKnowledgeBasesResponse = (SuccessEnvelope);
export type ListKnowledgeBasesError = unknown;

export type CreateKnowledgeBaseData = {
body: KnowledgeBaseRequest;
body: KnowledgeBaseCreateRequest;
};

export type CreateKnowledgeBaseResponse = (SuccessEnvelope);
Expand Down
10 changes: 6 additions & 4 deletions dashboard/src/api/v1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import {
type DynamicConfig,
type EnabledPatch,
type GhproxyTestRequest,
type KnowledgeBaseCreateRequest,
type KnowledgeBaseRequest,
type LoginRequest,
type ListConversationsData,
type McpServerConfig,
Expand Down Expand Up @@ -1352,16 +1354,16 @@ export const knowledgeApi = {
openApiV1.getKnowledgeBase({ path: { kb_id: kbId } }),
);
},
create(config: OpenConfig) {
create(config: KnowledgeBaseCreateRequest) {
return typed<OpenConfig>(
openApiV1.createKnowledgeBase({ body: config as any }),
openApiV1.createKnowledgeBase({ body: config }),
);
},
update(kbId: string, config: OpenConfig) {
update(kbId: string, config: KnowledgeBaseRequest) {
return typed<OpenConfig>(
openApiV1.updateKnowledgeBase({
path: { kb_id: kbId },
body: config as any,
body: config,
}),
);
},
Expand Down
9 changes: 7 additions & 2 deletions dashboard/src/views/knowledge-base/KBList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@

<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null" hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null"
:rules="[v => editingKB !== null || !!v || t('create.embeddingModelRequired')]" required
hint="嵌入模型选择后无法修改,如需更换请创建新的知识库。" persistent-hint>
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
Expand Down Expand Up @@ -441,7 +443,10 @@ const submitForm = async () => {
if (editingKB.value) {
response = await knowledgeApi.update(editingKB.value.kb_id, payload)
} else {
response = await knowledgeApi.create(payload)
response = await knowledgeApi.create({
...payload,
embedding_provider_id: formData.value.embedding_provider_id!
})
}

if (response.data.status === 'ok') {
Expand Down
39 changes: 28 additions & 11 deletions openspec/openapi-v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3347,7 +3347,7 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/KnowledgeBaseRequest"
$ref: "#/components/schemas/KnowledgeBaseCreateRequest"
responses:
"200":
$ref: "#/components/responses/Ok"
Expand Down Expand Up @@ -5623,31 +5623,48 @@ components:

KnowledgeBaseRequest:
type: object
required: [name]
properties:
name:
kb_name:
type: string
description:
type: string
embedding_provider_id:
emoji:
type: string
embedding_provider_id:
type: [string, "null"]
rerank_provider_id:
type: string
chunking:
$ref: "#/components/schemas/DynamicConfig"
metadata:
$ref: "#/components/schemas/DynamicConfig"
type: [string, "null"]
chunk_size:
type: integer
chunk_overlap:
type: integer
top_k_dense:
type: integer
top_k_sparse:
type: integer
top_m_final:
type: integer
additionalProperties: false

KnowledgeBaseCreateRequest:
allOf:
- $ref: "#/components/schemas/KnowledgeBaseRequest"
- type: object
required: [kb_name, embedding_provider_id]
properties:
kb_name:
type: string
embedding_provider_id:
type: string
additionalProperties: false

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

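In OpenAPI, placing additionalProperties: false inside one branch of an allOf causes validation to fail for every property that is defined only in the other branches (such as description, emoji, etc. from KnowledgeBaseRequest). The additionalProperties keyword only recognizes the properties and patternProperties declared in the same schema object, so this inline branch rejects everything except kb_name and embedding_provider_id, even though the referenced schema allows those fields. Unknown properties are still rejected after the removal, because KnowledgeBaseRequest declares additionalProperties: false itself.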

Please remove additionalProperties: false from the inline schema of KnowledgeBaseCreateRequest.

    KnowledgeBaseCreateRequest:
      allOf:
        - $ref: "#/components/schemas/KnowledgeBaseRequest"
        - type: object
          required: [kb_name, embedding_provider_id]
          properties:
            kb_name:
              type: string
            embedding_provider_id:
              type: string


KnowledgeDocumentUploadRequest:
type: object
required: [file]
properties:
file:
type: string
format: binary
parser:
type: string

KnowledgeDocumentImportRequest:
type: object
Expand Down
Loading
Loading