Commit 95a346f4 by zhangxingmin

push

parent 5ccf6816
package com.yd.ai.api.controller;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import io.reactivex.Flowable;
import com.yd.ai.api.service.ApiAiStreamService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.Arrays;
@RestController
@RequestMapping("/api/ai")
public class ApiAiStreamController {
@Autowired
private ApiAiStreamService apiAiStreamService;
/**
* 调用大模型接口获取AI输出的流信息
* @param question
* @return
*/
@CrossOrigin(origins = "*") // 开发时允许跨域
@GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamChat(@RequestParam String question) {
Generation gen = new Generation();
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
Message userMsg = Message.builder().role(Role.USER.getValue()).content(question).build();
GenerationParam param = GenerationParam.builder()
.apiKey("sk-d6551c67cfbe4a759a78dc3625729291")
.model("qwen-plus")
.messages(Arrays.asList(systemMsg, userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
.incrementalOutput(true)
.build();
Flowable<GenerationResult> flowable = null;
try {
flowable = gen.streamCall(param);
} catch (NoApiKeyException e) {
e.printStackTrace();
} catch (InputRequiredException e) {
e.printStackTrace();
}
// 将 RxJava Flowable 转为 Reactor Flux
return Flux.from(flowable)
.map(result -> {
String delta = result.getOutput().getChoices().get(0).getMessage().getContent();
return ServerSentEvent.builder(delta).build();
});
return apiAiStreamService.streamChat(question);
}
}
\ No newline at end of file
package com.yd.ai.api.service;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
public interface ApiAiStreamService {
Flux<ServerSentEvent<String>> streamChat(String question);
}
package com.yd.ai.api.service;
public interface ApiSensitiveWordDetailService {
void checkWord(String word);
}
package com.yd.ai.api.service.impl;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.yd.ai.api.service.ApiAiStreamService;
import com.yd.ai.api.service.ApiSensitiveWordDetailService;
import com.yd.auth.core.dto.AuthUserDto;
import com.yd.auth.core.utils.SecurityUtil;
import com.yd.common.enums.ResultCode;
import com.yd.common.exception.BusinessException;
import com.yd.notice.feign.client.ApiNotificationTaskFeignClient;
import com.yd.notice.feign.request.ApiSendRequest;
import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.Arrays;
@Slf4j
@Service
public class ApiAiStreamServiceImpl implements ApiAiStreamService {
@Autowired
private ApiSensitiveWordDetailService apiSensitiveWordDetailService;
@Autowired
private ApiNotificationTaskFeignClient apiNotificationTaskFeignClient;
/**
* 调用大模型接口获取AI输出的流信息
* @param question
* @return
*/
@Override
public Flux<ServerSentEvent<String>> streamChat(String question) {
// 敏感词校验
try {
apiSensitiveWordDetailService.checkWord(question);
} catch (BusinessException e) {
int code = e.getCode();
if (code == ResultCode.SENSITIVE_WORDS_EXIST.getCode()) {
// 禁用类型敏感词:返回温馨提示语,不调用大模型
log.info("检测到禁用敏感词,返回提示语,question: {}", question);
String tipMsg = "抱歉,您输入的内容包含敏感词汇,无法为您提供服务。请调整后重新提问。";
return Flux.just(ServerSentEvent.builder(tipMsg).build());
} else if (code == ResultCode.SENSITIVE_TZ_WORDS_EXIST.getCode()) {
// 通知类型敏感词:发送企业微信通知后继续抛出异常,让前端处理产品列表
log.info("检测到通知类型敏感词,发送企业微信通知,question: {}", question);
AuthUserDto authUserDto = SecurityUtil.getCurrentLoginUser();
String userName = authUserDto.getUsername();
String params = "{\"customerName\":\"" + userName + "\"}";
ApiSendRequest request = new ApiSendRequest();
request.setChannelBizId("wecom_default");
request.setChannelBizId("tpl_wecom_order");
request.setReceiver("zxm|Sweet");
request.setParams(params);
apiNotificationTaskFeignClient.send(request);
// 抛出异常,前端根据code=50002调用产品列表接口
throw new BusinessException(e.getCode(), e.getMsg());
} else {
// 其他异常继续抛出
throw e;
}
}
// 正常调用大模型流式接口
Generation gen = new Generation();
Message systemMsg = Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
Message userMsg = Message.builder().role(Role.USER.getValue()).content(question).build();
GenerationParam param = GenerationParam.builder()
.apiKey("sk-d6551c67cfbe4a759a78dc3625729291")
.model("qwen-plus")
.messages(Arrays.asList(systemMsg, userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
.incrementalOutput(true)
.build();
Flowable<GenerationResult> flowable = null;
try {
flowable = gen.streamCall(param);
} catch (NoApiKeyException e) {
log.error("NoApiKeyException: {}", e.getMessage());
return Flux.just(ServerSentEvent.builder("系统错误:API密钥配置异常").build());
} catch (InputRequiredException e) {
log.error("InputRequiredException: {}", e.getMessage());
return Flux.just(ServerSentEvent.builder("系统错误:输入参数异常").build());
}
// 将 RxJava Flowable 转为 Reactor Flux
return Flux.from(flowable)
.map(result -> {
String delta = result.getOutput().getChoices().get(0).getMessage().getContent();
return ServerSentEvent.builder(delta).build();
});
}
}
package com.yd.ai.api.service.impl;
import com.yd.ai.api.service.ApiSensitiveWordDetailService;
import com.yd.ai.service.dto.QuerySensitiveWordDetailDTO;
import com.yd.ai.service.service.ISensitiveWordDetailService;
import com.yd.common.enums.ResultCode;
import com.yd.common.exception.BusinessException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ApiSensitiveWordDetailServiceImpl implements ApiSensitiveWordDetailService {
@Autowired
private ISensitiveWordDetailService iSensitiveWordDetailService;
/**
* 调用敏感词库->校验敏感词
* @param word
* @return
*/
@Override
public void checkWord(String word) {
//查询禁用类型敏感词
long count = iSensitiveWordDetailService.countByCondition(QuerySensitiveWordDetailDTO.builder()
.wordLibBizId("LIB_200000") //通用类型的敏感词库
.word(word) //参与校验的长文本内容
.build());
if (count > 0) {
//存在禁用类型的敏感词,这个需要ai输出的时候温馨提示语。
throw new BusinessException(ResultCode.SENSITIVE_WORDS_EXIST.getCode(),ResultCode.SENSITIVE_WORDS_EXIST.getMessage());
}
//查询通知类型敏感词
long count1 = iSensitiveWordDetailService.countByCondition(QuerySensitiveWordDetailDTO.builder()
.wordLibBizId("LIB_200001")
.word(word)
.build());
if (count1 > 0) {
//存在通知类型的敏感词,需要前端调用产品购买列表的接口返回列表展示。
throw new BusinessException(ResultCode.SENSITIVE_TZ_WORDS_EXIST.getCode(),ResultCode.SENSITIVE_TZ_WORDS_EXIST.getMessage());
}
}
}
......@@ -26,5 +26,11 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</dependency>
<dependency>
<groupId>com.yd</groupId>
<artifactId>yd-notice-feign</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
</project>
\ No newline at end of file
package com.yd.ai.service.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class QuerySensitiveWordDetailDTO {
/**
* 敏感词库主表唯一业务ID
*/
private String wordLibBizId;
/**
* 敏感词内容
*/
private String word;
}
package com.yd.ai.service.service;
import com.yd.ai.service.dto.QuerySensitiveWordDetailDTO;
import com.yd.ai.service.model.SensitiveWordDetail;
import com.baomidou.mybatisplus.extension.service.IService;
import java.util.List;
/**
* <p>
* 敏感词明细表 服务类
......@@ -13,4 +16,7 @@ import com.baomidou.mybatisplus.extension.service.IService;
*/
public interface ISensitiveWordDetailService extends IService<SensitiveWordDetail> {
List<SensitiveWordDetail> queryList(QuerySensitiveWordDetailDTO dto);
long countByCondition(QuerySensitiveWordDetailDTO dto);
}
package com.yd.ai.service.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.yd.ai.service.dto.QuerySensitiveWordDetailDTO;
import com.yd.ai.service.model.SensitiveWordDetail;
import com.yd.ai.service.dao.SensitiveWordDetailMapper;
import com.yd.ai.service.service.ISensitiveWordDetailService;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.List;
/**
* <p>
......@@ -17,4 +22,51 @@ import org.springframework.stereotype.Service;
@Service
public class SensitiveWordDetailServiceImpl extends ServiceImpl<SensitiveWordDetailMapper, SensitiveWordDetail> implements ISensitiveWordDetailService {
/**
* 根据条件查询列表
* @param dto
* @return
*/
@Override
public List<SensitiveWordDetail> queryList(QuerySensitiveWordDetailDTO dto) {
LambdaQueryWrapper<SensitiveWordDetail> wrapper = new LambdaQueryWrapper<>();
// 只查询未删除的数据
wrapper.eq(SensitiveWordDetail::getIsDeleted, 0);
// 如果传入了词库ID,精确匹配
if (StringUtils.hasText(dto.getWordLibBizId())) {
wrapper.eq(SensitiveWordDetail::getWordLibBizId, dto.getWordLibBizId());
}
// 如果传入了长文本,查询出所有在该文本中出现的敏感词
if (StringUtils.hasText(dto.getWord())) {
// 使用 INSTR 函数判断敏感词是否在长文本中,参数使用占位符防止 SQL 注入
wrapper.apply("INSTR({0}, word) > 0", dto.getWord());
}
// 添加默认排序
wrapper.orderByAsc(SensitiveWordDetail::getSort)
.orderByDesc(SensitiveWordDetail::getCreateTime);
return this.list(wrapper);
}
/**
* 根据条件查询总数
* @param dto
* @return
*/
@Override
public long countByCondition(QuerySensitiveWordDetailDTO dto) {
LambdaQueryWrapper<SensitiveWordDetail> wrapper = new LambdaQueryWrapper<>();
// 只统计未删除的数据
wrapper.eq(SensitiveWordDetail::getIsDeleted, 0);
if (StringUtils.hasText(dto.getWordLibBizId())) {
wrapper.eq(SensitiveWordDetail::getWordLibBizId, dto.getWordLibBizId());
}
if (StringUtils.hasText(dto.getWord())) {
// 统计在长文本中出现的敏感词数量
wrapper.apply("INSTR({0}, word) > 0", dto.getWord());
}
return this.count(wrapper);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment