diff --git a/src/modules/rag-chat/rag-chat.service.ts b/src/modules/rag-chat/rag-chat.service.ts index 742f83d..0e0642b 100644 --- a/src/modules/rag-chat/rag-chat.service.ts +++ b/src/modules/rag-chat/rag-chat.service.ts @@ -1,6 +1,9 @@ import { Injectable, NotFoundException, Logger, Optional } from '@nestjs/common'; import { PrismaService } from '../../infrastructure/database/prisma.service'; import { ContentSafetyService } from '../content-safety/content-safety.service'; +import { AiGatewayService } from '../ai/gateway/ai-gateway.service'; + +const MAX_CONTEXT_CHARS = 4000; @Injectable() export class RagChatService { @@ -9,6 +12,7 @@ export class RagChatService { constructor( private readonly prisma: PrismaService, @Optional() private readonly safety?: ContentSafetyService, + @Optional() private readonly aiGateway?: AiGatewayService, ) {} async createSession(userId: string, knowledgeBaseId: string, title?: string) { @@ -36,7 +40,7 @@ export class RagChatService { const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } }); if (!session || session.userId !== userId) throw new NotFoundException('对话不存在'); - // Content safety check on user input + // Content safety const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' }); if (inputCheck && !inputCheck.safe) { return { blocked: true, message: '输入包含违规内容,请修改后重试' }; @@ -47,16 +51,59 @@ export class RagChatService { data: { sessionId, role: 'user', content }, }); - // Generate AI response (simplified — real RAG pipeline in M3) - const reply = `感谢提问。基于知识库内容,我暂时无法生成完整回答(RAG 检索管道将在后续版本完善)。`; + // Retrieve knowledge base context + const context = await this.loadContext(session.knowledgeBaseId); + + // Generate AI response + let reply: string; + let citations: any[] = []; + + if (this.aiGateway && context.text) { + try { + const messages = [ + { role: 'system' as const, content: this.buildSystemPrompt(context.text) }, + { role: 'user' as const, content }, + ]; + const resp = await this.aiGateway.generate({ + feature: 'rag-chat', + userId, + tier: 'primary', + promptKey: 'rag-chat', + promptVersion: 'v1', + messages, + maxTokens: 2048, + }); + reply = resp.parsed?.answer ?? String(resp.parsed?.content ?? '抱歉,AI 暂时无法生成回答。'); + citations = context.citations; + } catch (err: any) { + this.logger.error('AI Gateway failed, falling back', err?.message); + reply = this.fallbackReply(context.isEmpty); + } + } else { + reply = this.fallbackReply(context.isEmpty); + } + + // Save AI message const aiMsg = await this.prisma.chatMessage.create({ data: { sessionId, role: 'ai', content: reply, tokens: reply.length }, }); + // Save citations + for (const c of citations.slice(0, 5)) { + await this.prisma.chatCitation.create({ + data: { + messageId: aiMsg.id, + chunkId: c.id, + content: c.text.slice(0, 500), + score: c.score ?? 0, + }, + }); + } + // Update session timestamp await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } }); - return { message: aiMsg, citations: [] }; + return { message: aiMsg, citations }; } async deleteSession(sessionId: string) { @@ -65,4 +112,55 @@ export class RagChatService { await this.prisma.chatSession.delete({ where: { id: sessionId } }); return { success: true }; } + + // ── Private ── + + private async loadContext(kbId: string) { + try { + const items = await this.prisma.knowledgeItem.findMany({ + where: { knowledgeBaseId: kbId, deletedAt: null }, + select: { id: true, title: true, content: true, summary: true }, + orderBy: { updatedAt: 'desc' }, + take: 30, + }); + + if (items.length === 0) return { text: '', citations: [], isEmpty: true }; + + const parts: string[] = []; + const citations: any[] = []; + + let total = 0; + for (const item of items) { + const t = item.content || item.summary || ''; + if (!t || total >= MAX_CONTEXT_CHARS) break; + const snippet = t.slice(0, Math.min(t.length, 500)); + parts.push(`【${item.title}】${snippet}`); + citations.push({ id: item.id, text: snippet, score: 1.0, title: item.title }); + total += snippet.length; + } + + return { text: parts.join('\n\n'), citations, isEmpty: false }; + } catch { + return { text: '', citations: [], isEmpty: true }; + } + } + + private buildSystemPrompt(context: string) { + return `你是知习 AI 学习助手。基于以下知识库内容回答用户问题,回答应准确、简洁、有依据。 + +## 知识库内容 +${context} + +## 回答要求 +- 基于提供的知识库内容回答,不要编造信息 +- 如果知识库内容不足以回答问题,请诚实告知 +- 回答时可以用「根据知识库中的《xxx》...」引用来源`; + } + + private fallbackReply(isEmpty: boolean) { + if (isEmpty) { + return '当前知识库还没有知识点内容。请先上传资料或添加知识点,我就可以基于它们回答你的问题了。'; + } + return '抱歉,AI 服务暂时不可用,请稍后再试。'; + } }