feat: #71 RAG Chat SSE 流式输出 + DeepSeek V4 Pro 思考过程
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 20s
Some checks failed
Deploy API Server / build-and-deploy (push) Failing after 20s
- AiProvider 接口新增 StreamChunk 类型 + generateStream() 方法 - DeepSeekProvider 实现 generateStream():stream=true,读 reader 逐 chunk yield - AiGatewayService 新增 generateStream():透传 provider stream + 记录用量 - RagChatService 新增 sendMessageStream():流式调用 + 保存最终 AI 回复到 DB - POST /api/rag-chat/sessions/:id/stream 新 SSE endpoint - thinking chunk:DeepSeek V4 Pro reasoning_content → type: "thinking" - 流式模式下禁用 response_format json_object,不阻塞思考过程 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
f4de598d96
commit
6f77162cf8
@ -10,6 +10,7 @@ import { BaseDomainEvent } from '../../../common/events/base-domain.event';
|
||||
import { PrismaService } from '../../../infrastructure/database/prisma.service';
|
||||
import type { AiProvider } from '../providers/ai-provider.interface';
|
||||
import type { GatewayRequest, GatewayResponse, ModelTier } from './ai-gateway.types';
|
||||
import type { StreamChunk } from '../providers/ai-provider.interface';
|
||||
|
||||
class AIUsageRecorded extends BaseDomainEvent {
|
||||
eventType = 'ai.usage.recorded';
|
||||
@ -207,6 +208,63 @@ export class AiGatewayService {
|
||||
}
|
||||
}
|
||||
|
||||
async *generateStream(request: GatewayRequest, timeoutMs = this.DEFAULT_TIMEOUT_MS): AsyncGenerator<StreamChunk, void, undefined> {
|
||||
const tierConfig = this.modelRouter.resolve(request.tier);
|
||||
const prompt = this.promptTemplate.get(request.promptKey, request.promptVersion);
|
||||
const messages = [
|
||||
{ role: 'system' as const, content: prompt.systemPrompt },
|
||||
...request.messages,
|
||||
];
|
||||
|
||||
const target = tierConfig.preferred;
|
||||
const provider = this.resolveProviderForTarget(target.provider);
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
|
||||
|
||||
let fullContent = '';
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
try {
|
||||
for await (const chunk of provider.generateStream!({
|
||||
model: target.model,
|
||||
messages,
|
||||
temperature: 0.3,
|
||||
maxTokens: request.maxTokens ?? 4096,
|
||||
signal: controller.signal,
|
||||
})) {
|
||||
if (chunk.type === 'error') {
|
||||
yield { type: 'error', content: chunk.content };
|
||||
return;
|
||||
}
|
||||
if (chunk.type === 'done' && chunk.usage) {
|
||||
inputTokens = chunk.usage.inputTokens;
|
||||
outputTokens = chunk.usage.outputTokens;
|
||||
}
|
||||
if (chunk.type === 'content') {
|
||||
fullContent += (chunk.content || '');
|
||||
}
|
||||
yield chunk;
|
||||
}
|
||||
|
||||
// Record usage after stream completes
|
||||
if (inputTokens > 0) {
|
||||
const estimatedCost = this.costCalculator.calculate(target.provider, target.model, inputTokens, outputTokens);
|
||||
this.usageLog.log({
|
||||
userId: request.userId, feature: request.feature,
|
||||
provider: target.provider, model: target.model, tier: request.tier,
|
||||
promptKey: request.promptKey, promptVersion: prompt.version,
|
||||
inputTokens, outputTokens, estimatedCost, latencyMs: 0, success: true,
|
||||
}).catch(() => {});
|
||||
}
|
||||
} catch (error: any) {
|
||||
yield { type: 'error', content: error.message };
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
private buildSystemPrompt(systemPrompt: string, schemaDesc: string): string {
|
||||
return `${systemPrompt}\n\n请严格按照以下 JSON Schema 输出,只输出 JSON,不要包含其他内容:\n${schemaDesc}`;
|
||||
}
|
||||
|
||||
@ -17,7 +17,14 @@ export interface AiGenerateOutput {
|
||||
latencyMs: number;
|
||||
}
|
||||
|
||||
export interface StreamChunk {
|
||||
type: 'thinking' | 'content' | 'error' | 'done';
|
||||
content?: string;
|
||||
usage?: { inputTokens: number; outputTokens: number };
|
||||
}
|
||||
|
||||
export interface AiProvider {
|
||||
readonly name: string;
|
||||
generate(input: AiGenerateInput): Promise<AiGenerateOutput>;
|
||||
generateStream?(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined>;
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { ConfigService } from '@nestjs/config';
|
||||
import type { AiProvider, AiGenerateInput, AiGenerateOutput } from './ai-provider.interface';
|
||||
import type { AiProvider, AiGenerateInput, AiGenerateOutput, StreamChunk } from './ai-provider.interface';
|
||||
|
||||
@Injectable()
|
||||
export class DeepSeekProvider implements AiProvider {
|
||||
@ -64,4 +64,95 @@ export class DeepSeekProvider implements AiProvider {
|
||||
latencyMs,
|
||||
};
|
||||
}
|
||||
|
||||
async *generateStream(input: AiGenerateInput): AsyncGenerator<StreamChunk, void, undefined> {
|
||||
if (!this.apiKey) {
|
||||
yield { type: 'error', content: 'DeepSeek API key not configured' };
|
||||
return;
|
||||
}
|
||||
|
||||
const body: Record<string, any> = {
|
||||
model: input.model,
|
||||
messages: input.messages,
|
||||
temperature: input.temperature ?? 0.3,
|
||||
max_tokens: input.maxTokens ?? 4096,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// When streaming, do NOT use response_format json_object — it disables reasoning_content
|
||||
const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: input.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => 'unknown');
|
||||
this.logger.error(`DeepSeek stream error ${response.status}: ${errorText}`);
|
||||
yield { type: 'error', content: `DeepSeek API returned ${response.status}` };
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
yield { type: 'error', content: 'No response body' };
|
||||
return;
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue;
|
||||
const data = line.slice(6).trim();
|
||||
if (data === '[DONE]') {
|
||||
yield { type: 'done', usage: { inputTokens, outputTokens } };
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const json = JSON.parse(data);
|
||||
const delta = json.choices?.[0]?.delta;
|
||||
if (!delta) continue;
|
||||
|
||||
// Track usage
|
||||
if (json.usage) {
|
||||
inputTokens = json.usage.prompt_tokens ?? inputTokens;
|
||||
outputTokens = json.usage.completion_tokens ?? outputTokens;
|
||||
}
|
||||
|
||||
// reasoning_content = thinking process (DeepSeek V4 Pro)
|
||||
if (delta.reasoning_content) {
|
||||
yield { type: 'thinking', content: delta.reasoning_content };
|
||||
}
|
||||
|
||||
// content = actual response
|
||||
if (delta.content) {
|
||||
yield { type: 'content', content: delta.content };
|
||||
}
|
||||
} catch {
|
||||
// Skip unparseable chunks
|
||||
}
|
||||
}
|
||||
}
|
||||
yield { type: 'done', usage: { inputTokens, outputTokens } };
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { Controller, Get, Post, Delete, Body, Param, UseGuards } from '@nestjs/common';
|
||||
import { Controller, Get, Post, Delete, Body, Param, Res } from '@nestjs/common';
|
||||
import { ApiTags, ApiOperation, ApiBearerAuth } from '@nestjs/swagger';
|
||||
import { Response } from 'express';
|
||||
import { RagChatService } from './rag-chat.service';
|
||||
import { CurrentUser } from '../../common/decorators/current-user.decorator';
|
||||
import type { UserPayload } from '../../common/types';
|
||||
@ -29,11 +30,27 @@ export class RagChatController {
|
||||
}
|
||||
|
||||
@Post('sessions/:id/messages')
|
||||
@ApiOperation({ summary: '发送消息' })
|
||||
@ApiOperation({ summary: '发送消息(同步)' })
|
||||
async sendMessage(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }) {
|
||||
return this.svc.sendMessage(String(user.id), id, dto.content);
|
||||
}
|
||||
|
||||
@Post('sessions/:id/stream')
|
||||
@ApiOperation({ summary: '发送消息(SSE 流式,支持思考过程)' })
|
||||
async sendMessageStream(@CurrentUser() user: UserPayload, @Param('id') id: string, @Body() dto: { content: string }, @Res() res: Response) {
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
|
||||
const userId = String(user.id);
|
||||
for await (const chunk of this.svc.sendMessageStream(userId, id, dto.content)) {
|
||||
if (res.destroyed) break;
|
||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
}
|
||||
res.end();
|
||||
}
|
||||
|
||||
@Delete('sessions/:id')
|
||||
@ApiOperation({ summary: '删除对话' })
|
||||
async deleteSession(@Param('id') id: string) {
|
||||
|
||||
@ -3,6 +3,7 @@ import { PrismaService } from '../../infrastructure/database/prisma.service';
|
||||
import { ContentSafetyService } from '../content-safety/content-safety.service';
|
||||
import { AiGatewayService } from '../ai/gateway/ai-gateway.service';
|
||||
import { RagChatOutputSchema } from '../ai/prompts/schemas/rag-chat.schema';
|
||||
import type { StreamChunk } from '../ai/providers/ai-provider.interface';
|
||||
|
||||
const MAX_CONTEXT_CHARS = 4000;
|
||||
|
||||
@ -113,6 +114,59 @@ export class RagChatService {
|
||||
return { message: aiMsg, citations };
|
||||
}
|
||||
|
||||
async *sendMessageStream(userId: string, sessionId: string, content: string): AsyncGenerator<StreamChunk, void, undefined> {
|
||||
const session = await this.prisma.chatSession.findUnique({ where: { id: sessionId } });
|
||||
if (!session || session.userId !== userId) {
|
||||
yield { type: 'error', content: '对话不存在' };
|
||||
return;
|
||||
}
|
||||
|
||||
const inputCheck = await this.safety?.check(content, { userId, contentType: 'rag_input' });
|
||||
if (inputCheck && !inputCheck.safe) {
|
||||
yield { type: 'error', content: '输入包含违规内容' };
|
||||
return;
|
||||
}
|
||||
|
||||
// Save user message
|
||||
await this.prisma.chatMessage.create({ data: { sessionId, role: 'user', content } });
|
||||
|
||||
// Load context
|
||||
const context = await this.loadContext(session.knowledgeBaseId);
|
||||
if (!context.text) {
|
||||
yield { type: 'content', content: this.fallbackReply(true) };
|
||||
} else {
|
||||
const messages = [
|
||||
{ role: 'system' as const, content: this.buildSystemPrompt(context.text) },
|
||||
{ role: 'user' as const, content },
|
||||
];
|
||||
|
||||
let fullContent = '';
|
||||
for await (const chunk of this.aiGateway!.generateStream({
|
||||
feature: 'rag-chat', userId, tier: 'primary',
|
||||
promptKey: 'rag-chat', promptVersion: 'v1', messages, maxTokens: 2048,
|
||||
})) {
|
||||
if (chunk.type === 'content' && chunk.content) {
|
||||
fullContent += chunk.content;
|
||||
}
|
||||
yield chunk;
|
||||
}
|
||||
|
||||
// Save AI reply
|
||||
if (fullContent) {
|
||||
const aiMsg = await this.prisma.chatMessage.create({
|
||||
data: { sessionId, role: 'ai', content: fullContent, tokens: fullContent.length },
|
||||
});
|
||||
for (const c of context.citations.slice(0, 5)) {
|
||||
await this.prisma.chatCitation.create({
|
||||
data: { messageId: aiMsg.id, chunkId: c.id, sourceTitle: c.title ?? null, excerptText: c.text.slice(0, 500) },
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await this.prisma.chatSession.update({ where: { id: sessionId }, data: { updatedAt: new Date() } });
|
||||
}
|
||||
|
||||
async deleteSession(sessionId: string) {
|
||||
await this.prisma.chatCitation.deleteMany({ where: { message: { sessionId } } });
|
||||
await this.prisma.chatMessage.deleteMany({ where: { sessionId } });
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user