diff --git a/prisma/schema.prisma b/prisma/schema.prisma index eea0bab..501f332 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -1001,6 +1001,7 @@ model ChatSession { userId String knowledgeBaseId String title String @default("新对话") @db.VarChar(200) + knowledgeItemIds Json? @default("[]") createdAt DateTime @default(now()) updatedAt DateTime @updatedAt diff --git a/src/modules/rag-chat/rag-chat.controller.ts b/src/modules/rag-chat/rag-chat.controller.ts index 0c07e43..a71dae6 100644 --- a/src/modules/rag-chat/rag-chat.controller.ts +++ b/src/modules/rag-chat/rag-chat.controller.ts @@ -13,8 +13,8 @@ export class RagChatController { @Post('sessions') @ApiOperation({ summary: '创建对话' }) - async createSession(@CurrentUser() user: UserPayload, @Body() dto: { knowledgeBaseId: string; title?: string }) { - return this.svc.createSession(String(user.id), dto.knowledgeBaseId, dto.title); + async createSession(@CurrentUser() user: UserPayload, @Body() dto: { knowledgeBaseId: string; title?: string; knowledgeItemIds?: string[] }) { + return this.svc.createSession(String(user.id), dto.knowledgeBaseId, dto.title, dto.knowledgeItemIds); } @Get('sessions') diff --git a/src/modules/rag-chat/rag-chat.service.ts b/src/modules/rag-chat/rag-chat.service.ts index 2dd1df0..e64c854 100644 --- a/src/modules/rag-chat/rag-chat.service.ts +++ b/src/modules/rag-chat/rag-chat.service.ts @@ -17,9 +17,14 @@ export class RagChatService { @Optional() private readonly aiGateway?: AiGatewayService, ) {} - async createSession(userId: string, knowledgeBaseId: string, title?: string) { + async createSession(userId: string, knowledgeBaseId: string, title?: string, knowledgeItemIds?: string[]) { return this.prisma.chatSession.create({ - data: { userId, knowledgeBaseId, title: title || '新对话' }, + data: { + userId, + knowledgeBaseId, + title: title || '新对话', + knowledgeItemIds: knowledgeItemIds ?? [], + }, }); } @@ -55,7 +60,8 @@ export class RagChatService { // Retrieve knowledge base context this.logger.log(`RAG: kbId=${session.knowledgeBaseId}, content preview: ${content.substring(0, 30)}`); - const context = await this.loadContext(session.knowledgeBaseId); + const itemIds = (session as any).knowledgeItemIds as string[] | undefined; + const context = await this.loadContext(session.knowledgeBaseId, itemIds?.length ? itemIds : undefined); this.logger.log(`RAG context: isEmpty=${context.isEmpty}, textLen=${context.text.length}, citations=${context.citations.length}, aiGateway=${!!this.aiGateway}`); // Generate AI response @@ -141,7 +147,8 @@ export class RagChatService { // Also auto-title in sendMessage (this is the sync method) // Load context - const context = await this.loadContext(session.knowledgeBaseId); + const itemIds = (session as any).knowledgeItemIds as string[] | undefined; + const context = await this.loadContext(session.knowledgeBaseId, itemIds?.length ? itemIds : undefined); if (!context.text) { yield { type: 'content', content: this.fallbackReply(true) }; } else { @@ -186,10 +193,14 @@ export class RagChatService { // ── Private ── - private async loadContext(kbId: string) { + private async loadContext(kbId: string, itemIds?: string[]) { try { const items = await this.prisma.knowledgeItem.findMany({ - where: { knowledgeBaseId: kbId, deletedAt: null }, + where: { + knowledgeBaseId: kbId, + deletedAt: null, + ...(itemIds && itemIds.length > 0 ? { id: { in: itemIds } } : {}), + }, select: { id: true, title: true, content: true, summary: true }, orderBy: { updatedAt: 'desc' }, take: 30,