api-server/src/modules/rag-chat/rag-chat.service.ts
wangdl f4de598d96
All checks were successful
Deploy API Server / build-and-deploy (push) Successful in 42s
fix: rag-chat 传入 outputSchema= RagChatOutputSchema,修复 parsed 为空对象
parseJson 无 schema 时直接返回 {},导致 resp.parsed?.answer 始终为 null。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-06 14:30:53 +08:00

174 lines
6.3 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { Injectable, NotFoundException, Logger, Optional } from '@nestjs/common';
import { PrismaService } from '../../infrastructure/database/prisma.service';
import { ContentSafetyService } from '../content-safety/content-safety.service';
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema';
/**
 * Soft budget (characters) for knowledge-base text injected into the system prompt.
 * Snippets stop being added once the running total reaches this value, so the final
 * total may exceed it by at most one snippet.
 */
const MAX_CONTEXT_CHARS = 4000;
/** Maximum characters taken from a single knowledge item. */
const MAX_SNIPPET_CHARS = 500;
/** Number of most recently updated knowledge items considered as context candidates. */
const MAX_CONTEXT_ITEMS = 30;
/** Maximum number of citations persisted for one AI message. */
const MAX_SAVED_CITATIONS = 5;
/** Maximum length of a persisted citation excerpt. */
const MAX_EXCERPT_CHARS = 500;

/**
 * Knowledge-base chat ("RAG chat") service.
 *
 * Manages chat sessions and messages, and answers user questions by injecting
 * knowledge-base content into the prompt sent to the AI gateway. The content-safety
 * service and the AI gateway are optional dependencies; without the gateway (or
 * without any usable context) a canned fallback reply is stored instead.
 *
 * NOTE(review): "retrieval" here is not a similarity search — `loadContext` takes the
 * most recently updated items of the knowledge base regardless of the question.
 */
@Injectable()
export class RagChatService {
  private readonly logger = new Logger(RagChatService.name);

  constructor(
    private readonly prisma: PrismaService,
    @Optional() private readonly safety?: ContentSafetyService,
    @Optional() private readonly aiGateway?: AiGatewayService,
  ) {}

  /**
   * Creates a chat session bound to a knowledge base.
   * A missing or empty title defaults to '新对话'.
   */
  async createSession(userId: string, knowledgeBaseId: string, title?: string) {
    return this.prisma.chatSession.create({
      data: { userId, knowledgeBaseId, title: title || '新对话' },
    });
  }

  /** Lists a user's sessions, optionally restricted to one knowledge base, most recently updated first. */
  async listSessions(userId: string, kbId?: string) {
    return this.prisma.chatSession.findMany({
      where: { userId, ...(kbId ? { knowledgeBaseId: kbId } : {}) },
      orderBy: { updatedAt: 'desc' },
    });
  }

  /**
   * Returns a session's messages in chronological order, each with its citations.
   *
   * @param sessionId Session to read.
   * @param userId    Optional owner. When given, the session must belong to this user,
   *                  otherwise NotFoundException is thrown. When omitted, no ownership
   *                  check is done (legacy behavior — the caller must authorize).
   */
  async getMessages(sessionId: string, userId?: string) {
    if (userId !== undefined) {
      await this.findOwnedSession(sessionId, userId);
    }
    return this.prisma.chatMessage.findMany({
      where: { sessionId },
      orderBy: { createdAt: 'asc' },
      include: { citations: true },
    });
  }

  /**
   * Stores a user message, generates an AI reply grounded in the session's knowledge
   * base, and stores the reply together with its citations.
   *
   * @returns `{ blocked: true, message }` when the input fails the content-safety check
   *          (no chat message is stored in that case). Otherwise returns
   *          `{ message, citations }`, where `message` is the saved AI message and
   *          `citations` is every context snippet given to the model. Only the first
   *          MAX_SAVED_CITATIONS citations are persisted. `citations` is empty when
   *          the fallback reply was used.
   * @throws NotFoundException if the session does not exist or belongs to another user.
   */
  async sendMessage(userId: string, sessionId: string, content: string) {
    const session = await this.findOwnedSession(sessionId, userId);

    // Content safety: a blocked input is rejected before anything is persisted.
    const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
    if (inputCheck && !inputCheck.safe) {
      return { blocked: true, message: '输入包含违规内容,请修改后重试' };
    }

    await this.prisma.chatMessage.create({
      data: { sessionId, role: 'user', content },
    });

    this.logger.log(`RAG: kbId=${session.knowledgeBaseId}, content preview: ${content.substring(0, 30)}`);
    const context = await this.loadContext(session.knowledgeBaseId);
    this.logger.log(`RAG context: isEmpty=${context.isEmpty}, textLen=${context.text.length}, citations=${context.citations.length}, aiGateway=${!!this.aiGateway}`);

    let reply: string;
    let citations: typeof context.citations = [];
    const gateway = this.aiGateway;
    if (gateway && context.text) {
      try {
        reply = await this.generateReply(gateway, userId, content, context.text);
        citations = context.citations;
      } catch (err: unknown) {
        const error = err instanceof Error ? err : undefined;
        this.logger.error(`AI Gateway FAILED: ${error?.message ?? String(err)}`, error?.stack?.substring(0, 300));
        reply = this.fallbackReply(context.isEmpty);
      }
    } else {
      this.logger.warn(`Falling back: aiGateway=${!!this.aiGateway}, hasText=${!!context.text}`);
      reply = this.fallbackReply(context.isEmpty);
    }

    const aiMsg = await this.prisma.chatMessage.create({
      // NOTE(review): `tokens` stores the character count of the reply, not a model token count.
      data: { sessionId, role: 'ai', content: reply, tokens: reply.length },
    });
    await this.saveCitations(aiMsg.id, citations);

    await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } });
    return { message: aiMsg, citations };
  }

  /**
   * Deletes a session with all of its messages and citations.
   *
   * @param sessionId Session to delete.
   * @param userId    Optional owner. When given, the session must belong to this user,
   *                  otherwise NotFoundException is thrown. When omitted, no ownership
   *                  check is done (legacy behavior — the caller must authorize).
   */
  async deleteSession(sessionId: string, userId?: string) {
    if (userId !== undefined) {
      await this.findOwnedSession(sessionId, userId);
    }
    // Children are deleted before parents (citations -> messages -> session). Running
    // them in one transaction means a failure cannot leave a half-deleted conversation.
    await this.prisma.$transaction([
      this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } }),
      this.prisma.chatMessage.deleteMany({ where: { sessionId } }),
      this.prisma.chatSession.delete({ where: { id: sessionId } }),
    ]);
    return { success: true };
  }

  // ── Private ──

  /**
   * Loads a session and verifies that it belongs to `userId`.
   * Throws NotFoundException for both "missing" and "not yours", so the existence of
   * other users' sessions is not revealed.
   */
  private async findOwnedSession(sessionId: string, userId: string) {
    const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
    if (!session || session.userId !== userId) throw new NotFoundException('对话不存在');
    return session;
  }

  /**
   * Calls the AI gateway with the knowledge-base context and extracts the reply text.
   * Uses the structured `answer` when it is a non-blank string, otherwise `content`,
   * otherwise a canned apology. Gateway errors propagate to the caller.
   */
  private async generateReply(
    gateway: AiGatewayService,
    userId: string,
    content: string,
    contextText: string,
  ): Promise<string> {
    this.logger.log(`Calling AI Gateway with ${contextText.length} chars context`);
    const messages = [
      { role: 'system' as const, content: this.buildSystemPrompt(contextText) },
      { role: 'user' as const, content },
    ];
    const resp = await gateway.generate({
      feature: 'rag-chat',
      userId,
      tier: 'primary',
      promptKey: 'rag-chat',
      promptVersion: 'v1',
      messages,
      maxTokens: 2048,
      outputSchema: RagChatOutputSchema,
    });
    const parsed = resp.parsed;
    // JSON.stringify(undefined) returns undefined at runtime (despite its `string` typing);
    // calling .substring on it directly would throw and turn a parse miss into a gateway failure.
    const rawPreview = (JSON.stringify(parsed) ?? 'undefined').substring(0, 300);
    this.logger.log(`AI Gateway response: parsed=${!!parsed}, keys=${parsed ? Object.keys(parsed).join(',') : 'null'}, raw=${rawPreview}`);

    const answer = parsed?.answer;
    if (typeof answer === 'string' && answer.trim().length > 0) {
      return answer;
    }
    return String(parsed?.content ?? '抱歉AI 暂时无法生成回答。');
  }

  /** Persists up to MAX_SAVED_CITATIONS citations for an AI message, in order. */
  private async saveCitations(
    messageId: Parameters<typeof this.prisma.chatCitation.create>[0]['data']['messageId'],
    citations: Awaited<ReturnType<RagChatService['loadContext']>>['citations'],
  ) {
    for (const c of citations.slice(0, MAX_SAVED_CITATIONS)) {
      await this.prisma.chatCitation.create({
        data: {
          messageId,
          chunkId: c.id,
          sourceTitle: c.title ?? null,
          excerptText: c.text.slice(0, MAX_EXCERPT_CHARS),
        },
      });
    }
  }

  /**
   * Fetches the context candidates: the MAX_CONTEXT_ITEMS most recently updated,
   * non-deleted items of the knowledge base. This is best-effort: on a database error
   * it logs a warning and returns an empty list, so the caller falls back to the
   * "empty knowledge base" reply instead of failing the request.
   */
  private async fetchContextItems(kbId: string) {
    try {
      return await this.prisma.knowledgeItem.findMany({
        where: { knowledgeBaseId: kbId, deletedAt: null },
        select: { id: true, title: true, content: true, summary: true },
        orderBy: { updatedAt: 'desc' },
        take: MAX_CONTEXT_ITEMS,
      });
    } catch (err: unknown) {
      this.logger.warn(`RAG context load failed for kbId=${kbId}: ${err instanceof Error ? err.message : String(err)}`);
      return [];
    }
  }

  /**
   * Builds the prompt context and its citations from a knowledge base.
   *
   * Each item contributes its `content` (or `summary` when content is empty), truncated
   * to MAX_SNIPPET_CHARS. Items with no text are skipped, not treated as the end of the
   * list. Collection stops once MAX_CONTEXT_CHARS is reached. Titled snippets are
   * prefixed with `《title》` so the model can cite them as the system prompt asks.
   *
   * @returns `text` (snippets joined by blank lines), `citations` (one per snippet;
   *          `score` is a constant placeholder, not a relevance score), and `isEmpty`,
   *          which is true when no snippet could be produced.
   */
  private async loadContext(kbId: string) {
    const items = await this.fetchContextItems(kbId);

    const selected: Array<{ item: (typeof items)[number]; snippet: string }> = [];
    let total = 0;
    for (const item of items) {
      if (total >= MAX_CONTEXT_CHARS) break;
      const source = item.content || item.summary || '';
      if (!source) continue; // one empty item must not hide the rest of the knowledge base
      const snippet = source.slice(0, MAX_SNIPPET_CHARS);
      selected.push({ item, snippet });
      total += snippet.length;
    }

    const text = selected
      .map(({ item, snippet }) => (item.title ? `《${item.title}》\n${snippet}` : snippet))
      .join('\n\n');
    const citations = selected.map(({ item, snippet }) => ({
      id: item.id,
      text: snippet,
      score: 1.0,
      title: item.title,
    }));
    return { text, citations, isEmpty: selected.length === 0 };
  }

  /** System prompt: assistant persona, the knowledge-base context, and answering rules. */
  private buildSystemPrompt(context: string) {
    return `你是知习 AI 学习助手。基于以下知识库内容回答用户问题,回答应准确、简洁、有依据。
## 知识库内容
${context}
## 回答要求
- 基于提供的知识库内容回答,不要编造信息
- 如果知识库内容不足以回答问题,请诚实告知
- 回答时可以用「根据知识库中的《xxx》...」引用来源`;
  }

  /**
   * Canned reply used when no AI answer is produced: an "upload content first" hint when
   * the knowledge base yielded no context, otherwise a "service unavailable" message.
   */
  private fallbackReply(isEmpty: boolean) {
    if (isEmpty) {
      return '当前知识库还没有知识点内容。请先上传资料或添加知识点,我就可以基于它们回答你的问题了。';
    }
    return '抱歉AI 服务暂时不可用,请稍后再试。';
  }
}