Commit 9e9c9872 by zhangxingmin

push

parent 4fdc510f
...@@ -5,10 +5,12 @@ import org.mybatis.spring.annotation.MapperScan; ...@@ -5,10 +5,12 @@ import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.cloud.openfeign.EnableFeignClients; import org.springframework.cloud.openfeign.EnableFeignClients;
import org.springframework.scheduling.annotation.EnableAsync;
@SpringBootApplication(scanBasePackages = "com.yd") @SpringBootApplication(scanBasePackages = "com.yd")
@MapperScan("com.yd.**.dao") @MapperScan("com.yd.**.dao")
@EnableFeignClients(basePackages = "com.yd") @EnableFeignClients(basePackages = "com.yd")
@EnableAsync // 开启异步执行支持
public class AiApiApplication { public class AiApiApplication {
public static void main(String[] args) { public static void main(String[] args) {
......
package com.yd.ai.api.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
@Configuration
public class AsyncConfig {
@Bean(name = "aiStreamExecutor")
public Executor taskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(5); // 核心线程数
executor.setMaxPoolSize(10); // 最大线程数
executor.setQueueCapacity(100); // 队列容量
executor.setThreadNamePrefix("ai-stream-"); // 线程名前缀
executor.initialize();
return executor;
}
}
\ No newline at end of file
package com.yd.ai.api.controller; package com.yd.ai.api.controller;
import com.yd.ai.api.service.ApiAiStreamService; import com.yd.ai.api.service.ApiAiStreamService;
import com.yd.common.utils.RedisUtil;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
@RestController @RestController
@RequestMapping("/api/ai") @RequestMapping("/api/ai")
...@@ -13,14 +17,38 @@ public class ApiAiStreamController { ...@@ -13,14 +17,38 @@ public class ApiAiStreamController {
@Autowired @Autowired
private ApiAiStreamService apiAiStreamService; private ApiAiStreamService apiAiStreamService;
@Autowired
private RedisUtil redisUtil;
private static final String REDIS_KEY_PREFIX = "ai:stream:";
private static final long EXPIRE_SECONDS = 300; // 5分钟
/** /**
* 调用大模型接口获取AI回答(流式) * 启动流式生成,返回 sessionId 供前端轮询
*/ */
@CrossOrigin(origins = "*") @PostMapping("/start-stream")
@GetMapping(value = "/stream-sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public ResponseEntity<Map<String, String>> startStream(@RequestParam String question) {
public Flux<String> streamChatSse(@RequestParam String question) { String sessionId = UUID.randomUUID().toString();
return apiAiStreamService.streamChatWithSensitiveCheck(question) // 异步执行流式生成,不阻塞主线程
.onErrorResume(e -> Flux.just("系统繁忙,请稍后重试")); apiAiStreamService.generateAndStore(sessionId, question);
Map<String, String> result = new HashMap<>();
result.put("sessionId", sessionId);
return ResponseEntity.ok(result);
} }
/**
* 轮询获取生成内容
*/
@GetMapping("/stream-content")
public ResponseEntity<Map<String, Object>> getStreamContent(@RequestParam String sessionId) {
String redisKey = REDIS_KEY_PREFIX + sessionId;
String content = redisUtil.getCacheObject(redisKey);
String doneKey = redisKey + ":done";
String done = redisUtil.getCacheObject(doneKey);
Map<String, Object> result = new HashMap<>();
result.put("content", content != null ? content : "");
result.put("finished", "true".equals(done));
return ResponseEntity.ok(result);
}
} }
\ No newline at end of file
package com.yd.ai.api.service; package com.yd.ai.api.service;
import org.springframework.scheduling.annotation.Async;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
public interface ApiAiStreamService { public interface ApiAiStreamService {
Flux<String> streamChatWithSensitiveCheck(String question); @Async("aiStreamExecutor")
void generateAndStore(String sessionId, String question);
} }
\ No newline at end of file
...@@ -13,24 +13,26 @@ import com.yd.auth.core.dto.AuthUserDto; ...@@ -13,24 +13,26 @@ import com.yd.auth.core.dto.AuthUserDto;
import com.yd.auth.core.utils.SecurityUtil; import com.yd.auth.core.utils.SecurityUtil;
import com.yd.common.enums.ResultCode; import com.yd.common.enums.ResultCode;
import com.yd.common.exception.BusinessException; import com.yd.common.exception.BusinessException;
import com.yd.common.utils.RedisUtil;
import com.yd.notice.feign.client.ApiNotificationTaskFeignClient; import com.yd.notice.feign.client.ApiNotificationTaskFeignClient;
import com.yd.notice.feign.request.ApiSendRequest; import com.yd.notice.feign.request.ApiSendRequest;
import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.schedulers.Schedulers; // 使用 RxJava3 的 Schedulers import io.reactivex.rxjava3.schedulers.Schedulers;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.TimeUnit;
@Slf4j @Slf4j
@Service @Service
public class ApiAiStreamServiceImpl implements ApiAiStreamService { public class ApiAiStreamServiceImpl implements ApiAiStreamService {
public static final String RESET = "\u001B[0m"; @Autowired
public static final String CYAN = "\u001B[36m"; private RedisUtil redisUtil;
public static final String GREEN = "\u001B[32m";
@Autowired @Autowired
private ApiSensitiveWordDetailService apiSensitiveWordDetailService; private ApiSensitiveWordDetailService apiSensitiveWordDetailService;
...@@ -38,22 +40,27 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService { ...@@ -38,22 +40,27 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService {
@Autowired @Autowired
private ApiNotificationTaskFeignClient apiNotificationTaskFeignClient; private ApiNotificationTaskFeignClient apiNotificationTaskFeignClient;
private static final String REDIS_KEY_PREFIX = "ai:stream:";
private static final int EXPIRE_SECONDS = 300; // 改为 int 类型,与 RedisUtil 匹配
/** /**
* 流式对话(SSE),保留敏感词检测逻辑 * 异步生成内容并存入 Redis
*/ */
@Override @Override
public Flux<String> streamChatWithSensitiveCheck(String question) { @Async("aiStreamExecutor")
// 1. 敏感词校验(与原非流式方法完全一致) public void generateAndStore(String sessionId, String question) {
String redisKey = REDIS_KEY_PREFIX + sessionId;
// 1. 敏感词校验
try { try {
apiSensitiveWordDetailService.checkWord(question); apiSensitiveWordDetailService.checkWord(question);
} catch (BusinessException e) { } catch (BusinessException e) {
int code = e.getCode(); int code = e.getCode();
String finalContent;
if (code == ResultCode.SENSITIVE_WORDS_EXIST.getCode()) { if (code == ResultCode.SENSITIVE_WORDS_EXIST.getCode()) {
log.info("检测到禁用敏感词,返回提示语"); finalContent = "抱歉,您输入的内容包含敏感词汇,无法为您提供服务。请调整后重新提问。";
return Flux.just("抱歉,您输入的内容包含敏感词汇,无法为您提供服务。请调整后重新提问。");
} else if (code == ResultCode.SENSITIVE_TZ_WORDS_EXIST.getCode()) { } else if (code == ResultCode.SENSITIVE_TZ_WORDS_EXIST.getCode()) {
log.info("检测到通知类型敏感词,发送企业微信通知"); // 发送通知
AuthUserDto authUserDto = SecurityUtil.getCurrentLoginUser(); AuthUserDto authUserDto = SecurityUtil.getCurrentLoginUser();
String userName = authUserDto.getUsername(); String userName = authUserDto.getUsername();
String params = "{\"customerName\":\"" + userName + "\"}"; String params = "{\"customerName\":\"" + userName + "\"}";
...@@ -63,10 +70,13 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService { ...@@ -63,10 +70,13 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService {
request.setReceiver("zxm|Sweet"); request.setReceiver("zxm|Sweet");
request.setParams(params); request.setParams(params);
apiNotificationTaskFeignClient.send(request); apiNotificationTaskFeignClient.send(request);
// 返回特殊标记,前端识别后展示产品列表 finalContent = "__SENSITIVE_NOTIFICATION__";
return Flux.just("__SENSITIVE_NOTIFICATION__"); } else {
finalContent = "系统错误";
} }
throw e; redisUtil.setCacheObject(redisKey, finalContent, EXPIRE_SECONDS, TimeUnit.SECONDS);
redisUtil.setCacheObject(redisKey + ":done", "true", EXPIRE_SECONDS, TimeUnit.SECONDS);
return;
} }
// 2. 正常调用大模型流式接口 // 2. 正常调用大模型流式接口
...@@ -85,52 +95,37 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService { ...@@ -85,52 +95,37 @@ public class ApiAiStreamServiceImpl implements ApiAiStreamService {
.model("qwen3-max") // 使用可用的模型 .model("qwen3-max") // 使用可用的模型
.messages(Arrays.asList(systemMsg, userMsg)) .messages(Arrays.asList(systemMsg, userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE) .resultFormat(GenerationParam.ResultFormat.MESSAGE)
.incrementalOutput(true) // 必须开启流式 .incrementalOutput(true)
.build(); .build();
return Flux.create(sink -> { try {
Publisher<GenerationResult> publisher = null; Publisher<GenerationResult> publisher = gen.streamCall(param);
try {
publisher = gen.streamCall(param);
} catch (NoApiKeyException | InputRequiredException e) {
log.error("流式调用初始化失败", e);
sink.error(e);
return;
}
Flowable<GenerationResult> flowable = Flowable.fromPublisher(publisher) Flowable<GenerationResult> flowable = Flowable.fromPublisher(publisher)
.subscribeOn(Schedulers.io()); .subscribeOn(Schedulers.io());
// 用于累计完整文本的 StringBuilder
StringBuilder fullContent = new StringBuilder(); StringBuilder fullContent = new StringBuilder();
flowable.subscribe( flowable.blockingSubscribe(
result -> { result -> {
String delta = result.getOutput().getChoices().get(0).getMessage().getContent(); String delta = result.getOutput().getChoices().get(0).getMessage().getContent();
// 拼接到累计文本中
fullContent.append(delta); fullContent.append(delta);
// 将当前完整内容发送给前端 // 实时写入 Redis
sink.next(fullContent.toString()); redisUtil.setCacheObject(redisKey, fullContent.toString(), EXPIRE_SECONDS, TimeUnit.SECONDS);
}, },
error -> { error -> {
log.error("流式调用出错", error); log.error("流式调用出错", error);
sink.error(error); redisUtil.setCacheObject(redisKey, "系统繁忙,请稍后重试", EXPIRE_SECONDS, TimeUnit.SECONDS);
redisUtil.setCacheObject(redisKey + ":done", "true", EXPIRE_SECONDS, TimeUnit.SECONDS);
}, },
() -> { () -> {
log.info("流式输出完成,总长度: {}", fullContent.length()); log.info("流式输出完成,sessionId: {}, 总长度: {}", sessionId, fullContent.length());
sink.complete(); redisUtil.setCacheObject(redisKey + ":done", "true", EXPIRE_SECONDS, TimeUnit.SECONDS);
} }
); );
}); } catch (NoApiKeyException | InputRequiredException e) {
} log.error("流式调用初始化失败", e);
redisUtil.setCacheObject(redisKey, "系统配置错误", EXPIRE_SECONDS, TimeUnit.SECONDS);
private static void printWithStyle(String text) { redisUtil.setCacheObject(redisKey + ":done", "true", EXPIRE_SECONDS, TimeUnit.SECONDS);
for (char ch : text.toCharArray()) {
if (ch == '|' || ch == '-' || ch == '=') {
System.out.print(CYAN + ch + RESET);
} else {
System.out.print(GREEN + ch + RESET);
}
} }
} }
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment