【同步】BOOT 和 CLOUD 的功能(AI 知识库)

pull/143/MERGE
YunaiV 2024-10-02 14:27:52 +08:00
parent 5480caa3df
commit 5858aac8c3
31 changed files with 444 additions and 195 deletions

View File

@ -34,7 +34,12 @@ public enum AiChatRoleEnum {
### ###
### ###
"""); """),
AI_KNOWLEDGE_ROLE("知识库助手", """
{info},
""");
/** /**
* *

View File

@ -10,4 +10,7 @@ public class AiChatConversationCreateMyReqVO {
@Schema(description = "聊天角色编号", example = "666") @Schema(description = "聊天角色编号", example = "666")
private Long roleId; private Long roleId;
@Schema(description = "知识库编号", example = "1204")
private Long knowledgeId;
} }

View File

@ -21,6 +21,9 @@ public class AiChatConversationUpdateMyReqVO {
@Schema(description = "模型编号", example = "1") @Schema(description = "模型编号", example = "1")
private Long modelId; private Long modelId;
@Schema(description = "知识库编号", example = "1")
private Long knowledgeId;
@Schema(description = "角色设定", example = "一个快乐的程序员") @Schema(description = "角色设定", example = "一个快乐的程序员")
private String systemMessage; private String systemMessage;

View File

@ -1,12 +1,12 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge; package cn.iocoder.yudao.module.ai.controller.admin.knowledge;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeRespVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService; import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -28,24 +28,23 @@ public class AiKnowledgeController {
@Resource @Resource
private AiKnowledgeService knowledgeService; private AiKnowledgeService knowledgeService;
@GetMapping("/my-page") @GetMapping("/page")
@Operation(summary = "获取【我的】知识库分页") @Operation(summary = "获取知识库分页")
public CommonResult<PageResult<AiKnowledgeRespVO>> getKnowledgePageMy(@Validated PageParam pageReqVO) { public CommonResult<PageResult<AiKnowledgeRespVO>> getKnowledgePage(@Valid AiKnowledgePageReqVO pageReqVO) {
PageResult<AiKnowledgeDO> pageResult = knowledgeService.getKnowledgePageMy(getLoginUserId(), pageReqVO); PageResult<AiKnowledgeDO> pageResult = knowledgeService.getKnowledgePage(getLoginUserId(), pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class));
} }
@PostMapping("/create-my") @PostMapping("/create")
@Operation(summary = "创建【我的】知识库") @Operation(summary = "创建知识库")
public CommonResult<Long> createKnowledgeMy(@RequestBody @Valid AiKnowledgeCreateMyReqVO createReqVO) { public CommonResult<Long> createKnowledge(@RequestBody @Valid AiKnowledgeCreateReqVO createReqVO) {
return success(knowledgeService.createKnowledgeMy(createReqVO, getLoginUserId())); return success(knowledgeService.createKnowledge(createReqVO, getLoginUserId()));
} }
@PutMapping("/update-my") @PutMapping("/update")
@Operation(summary = "更新【我的】知识库") @Operation(summary = "更新知识库")
public CommonResult<Boolean> updateKnowledgeMy(@RequestBody @Valid AiKnowledgeUpdateMyReqVO updateReqVO) { public CommonResult<Boolean> updateKnowledge(@RequestBody @Valid AiKnowledgeUpdateReqVO updateReqVO) {
knowledgeService.updateKnowledgeMy(updateReqVO, getLoginUserId()); knowledgeService.updateKnowledge(updateReqVO, getLoginUserId());
return success(true); return success(true);
} }
} }

View File

@ -36,7 +36,7 @@ public class AiKnowledgeDocumentController {
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获取文档分页") @Operation(summary = "获取文档分页")
public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPageMy(@Valid AiKnowledgeDocumentPageReqVO pageReqVO) { public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPage(@Valid AiKnowledgeDocumentPageReqVO pageReqVO) {
PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO); PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class));
} }

View File

@ -29,7 +29,7 @@ public class AiKnowledgeSegmentController {
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获取段落分页") @Operation(summary = "获取段落分页")
public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPageMy(@Valid AiKnowledgeSegmentPageReqVO pageReqVO) { public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPage(@Valid AiKnowledgeSegmentPageReqVO pageReqVO) {
PageResult<AiKnowledgeSegmentDO> pageResult = segmentService.getKnowledgeSegmentPage(pageReqVO); PageResult<AiKnowledgeSegmentDO> pageResult = segmentService.getKnowledgeSegmentPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiKnowledgeSegmentRespVO.class)); return success(BeanUtils.toBean(pageResult, AiKnowledgeSegmentRespVO.class));
} }

View File

@ -7,9 +7,9 @@ import lombok.Data;
import java.util.List; import java.util.List;
@Schema(description = "管理后台 - AI 知识库创建【我的】 Request VO") @Schema(description = "管理后台 - AI 知识库创建 Request VO")
@Data @Data
public class AiKnowledgeCreateMyReqVO { public class AiKnowledgeCreateReqVO {
@Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南") @Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南")
@NotBlank(message = "知识库名称不能为空") @NotBlank(message = "知识库名称不能为空")
@ -18,11 +18,19 @@ public class AiKnowledgeCreateMyReqVO {
@Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "存储 ruoyi-vue-pro 操作文档") @Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "存储 ruoyi-vue-pro 操作文档")
private String description; private String description;
@Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "[1]") @Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "[1,2,3]")
private List<Long> visibilityPermissions; private List<Long> visibilityPermissions;
@Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
@NotNull(message = "嵌入模型不能为空") @NotNull(message = "嵌入模型不能为空")
private Long modelId; private Long modelId;
@Schema(description = "相似性阈值", requiredMode = Schema.RequiredMode.REQUIRED, example = "0.5")
@NotNull(message = "相似性阈值不能为空")
private Double similarityThreshold;
@Schema(description = "topK", requiredMode = Schema.RequiredMode.REQUIRED, example = "3")
@NotNull(message = "topK 不能为空")
private Integer topK;
} }

View File

@ -23,4 +23,24 @@ public class AiKnowledgeDocumentCreateReqVO {
@URL(message = "文档 URL 格式不正确") @URL(message = "文档 URL 格式不正确")
private String url; private String url;
@Schema(description = "每个段落的目标 token 数", requiredMode = Schema.RequiredMode.REQUIRED, example = "800")
@NotNull(message = "每个段落的目标 token 数不能为空")
private Integer defaultSegmentTokens;
@Schema(description = "每个段落的最小字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "350")
@NotNull(message = "每个段落的最小字符数不能为空")
private Integer minSegmentWordCount;
@Schema(description = "丢弃阈值:低于此阈值的段落会被丢弃", requiredMode = Schema.RequiredMode.REQUIRED, example = "5")
@NotNull(message = "丢弃阈值不能为空")
private Integer minChunkLengthToEmbed;
@Schema(description = "最大段落数", requiredMode = Schema.RequiredMode.REQUIRED, example = "10000")
@NotNull(message = "最大段落数不能为空")
private Integer maxNumSegments;
@Schema(description = "分块是否保留分隔符", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
@NotNull(message = "分块是否保留分隔符不能为空")
private Boolean keepSeparator;
} }

View File

@ -0,0 +1,14 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Admin-console request object for paging through AI knowledge bases.
 *
 * Extends PageParam and adds a knowledge-base name as an extra query field.
 */
@Schema(description = "管理后台 - AI 知识库的分页 Request VO")
@Data
public class AiKnowledgePageReqVO extends PageParam {
// Knowledge-base name to filter by; no validation constraint is declared, so it may be absent.
@Schema(description = "知识库名称", example = "Java 开发手册")
private String name;
}

View File

@ -9,7 +9,7 @@ import java.util.List;
@Schema(description = "管理后台 - AI 知识库更新【我的】 Request VO") @Schema(description = "管理后台 - AI 知识库更新【我的】 Request VO")
@Data @Data
public class AiKnowledgeUpdateMyReqVO { public class AiKnowledgeUpdateReqVO {
@Schema(description = "对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1204") @Schema(description = "对话编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1204")
@NotNull(message = "知识库编号不能为空") @NotNull(message = "知识库编号不能为空")
@ -22,7 +22,7 @@ public class AiKnowledgeUpdateMyReqVO {
@Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "") @Schema(description = "知识库描述", requiredMode = Schema.RequiredMode.REQUIRED, example = "")
private String description; private String description;
@Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "[1]") @Schema(description = "可见权限,只能选择哪些人可见", requiredMode = Schema.RequiredMode.REQUIRED, example = "1,2,3")
private List<Long> visibilityPermissions; private List<Long> visibilityPermissions;
@Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "嵌入模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")

View File

@ -0,0 +1,17 @@
package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
 * Admin-console request object for recalling (searching) knowledge-base segments.
 *
 * NOTE(review): both fields are documented as REQUIRED in the OpenAPI schema, but no
 * bean-validation constraints are declared on them here - confirm callers guard against null.
 */
@Schema(description = "管理后台 - AI 知识库段落召回 Request VO")
@Data
public class AiKnowledgeSegmentSearchReqVO {
// Id of the knowledge base to search in.
@Schema(description = "知识库编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
private Long knowledgeId;
// Text content to search with.
@Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线")
private String content;
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat; package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
@ -64,6 +65,13 @@ public class AiChatConversationDO extends BaseDO {
*/ */
private Long roleId; private Long roleId;
/**
*
* <p>
* {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/** /**
* *
* *

View File

@ -1,14 +1,19 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat; package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.KeySequence; import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.*; import lombok.*;
import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.MessageType;
import java.util.List;
/** /**
* AI Chat DO * AI Chat DO
* *
@ -66,6 +71,15 @@ public class AiChatMessageDO extends BaseDO {
*/ */
private Long roleId; private Long roleId;
/**
*
*
* {@link AiKnowledgeSegmentDO#getId()}
*/
@TableField(typeHandler = JacksonTypeHandler.class)
private List<Long> segmentIds;
/** /**
* *
*/ */

View File

@ -2,10 +2,10 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
@ -38,11 +38,13 @@ public class AiKnowledgeDO extends BaseDO {
* *
*/ */
private String description; private String description;
// TODO @新:如果全部可见,需要怎么设置?
/** /**
* , * ,
* <p>
* -1
*/ */
@TableField(typeHandler = JacksonTypeHandler.class) @TableField(typeHandler = LongListTypeHandler.class)
private List<Long> visibilityPermissions; private List<Long> visibilityPermissions;
/** /**
* *
@ -52,10 +54,21 @@ public class AiKnowledgeDO extends BaseDO {
* *
*/ */
private String model; private String model;
/**
* topK
*/
private Integer topK;
/**
*
*/
private Double similarityThreshold;
/** /**
* *
* <p> * <p>
* {@link CommonStatusEnum} * {@link CommonStatusEnum}
*/ */
private Integer status; private Integer status;
} }

View File

@ -23,7 +23,7 @@ public class AiKnowledgeDocumentDO extends BaseDO {
private Long id; private Long id;
/** /**
* *
* * <p>
* {@link AiKnowledgeDO#getId()} * {@link AiKnowledgeDO#getId()}
*/ */
private Long knowledgeId; private Long knowledgeId;
@ -40,13 +40,39 @@ public class AiKnowledgeDocumentDO extends BaseDO {
*/ */
private String url; private String url;
/** /**
* token * token
*/ */
private Integer tokens; private Integer tokens;
/** /**
* *
*/ */
private Integer wordCount; private Integer wordCount;
// ========== 自定义分段所用参数 ==========
// TODO @新3defaultChunkSize、defaultChunkSize、minChunkSizeChars、maxNumChunks 这几个字段的命名,可能要微信一起讨论下。尽量命名保持风格统一哈。
/**
* token
*/
private Integer defaultSegmentTokens;
/**
*
*/
private Integer minSegmentWordCount;
/**
*
*/
private Integer minChunkLengthToEmbed;
/**
*
*/
private Integer maxNumSegments;
/**
*
*/
private Boolean keepSeparator;
// ===================================
/** /**
* *
* <p> * <p>

View File

@ -28,13 +28,13 @@ public class AiKnowledgeSegmentDO extends BaseDO {
private String vectorId; private String vectorId;
/** /**
* *
* * <p>
* {@link AiKnowledgeDO#getId()} * {@link AiKnowledgeDO#getId()}
*/ */
private Long knowledgeId; private Long knowledgeId;
/** /**
* *
* * <p>
* {@link AiKnowledgeDocumentDO#getId()} * {@link AiKnowledgeDocumentDO#getId()}
*/ */
private Long documentId; private Long documentId;
@ -52,7 +52,7 @@ public class AiKnowledgeSegmentDO extends BaseDO {
private Integer tokens; private Integer tokens;
/** /**
* *
* * <p>
* {@link CommonStatusEnum} * {@link CommonStatusEnum}
*/ */
private Integer status; private Integer status;

View File

@ -1,10 +1,10 @@
package cn.iocoder.yudao.module.ai.dal.mysql.knowledge; package cn.iocoder.yudao.module.ai.dal.mysql.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@ -16,10 +16,11 @@ import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface AiKnowledgeMapper extends BaseMapperX<AiKnowledgeDO> { public interface AiKnowledgeMapper extends BaseMapperX<AiKnowledgeDO> {
default PageResult<AiKnowledgeDO> selectPageByMy(Long userId, PageParam pageReqVO) { default PageResult<AiKnowledgeDO> selectPage(Long userId, AiKnowledgePageReqVO pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiKnowledgeDO>() return selectPage(pageReqVO, new LambdaQueryWrapperX<AiKnowledgeDO>()
.eq(AiKnowledgeDO::getUserId, userId)
.eq(AiKnowledgeDO::getStatus, CommonStatusEnum.ENABLE.getStatus()) .eq(AiKnowledgeDO::getStatus, CommonStatusEnum.ENABLE.getStatus())
.likeIfPresent(AiKnowledgeDO::getName, pageReqVO.getName())
.and(e -> e.apply("FIND_IN_SET(" + userId + ",visibility_permissions)").or(m -> m.apply("FIND_IN_SET(-1,visibility_permissions)")))
.orderByDesc(AiKnowledgeDO::getId)); .orderByDesc(AiKnowledgeDO::getId));
} }
} }

View File

@ -7,6 +7,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import java.util.List;
/** /**
* AI - Mapper * AI - Mapper
* *
@ -22,4 +24,11 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX<AiKnowledgeSegment
.likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getKeyword()) .likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getKeyword())
.orderByDesc(AiKnowledgeSegmentDO::getId)); .orderByDesc(AiKnowledgeSegmentDO::getId));
} }
/**
 * Loads the knowledge segments whose vector id is contained in the given list.
 *
 * @param vectorIdList vector ids to look up; may be null or empty
 * @return matching segments ordered by id descending; an empty list when no ids are supplied
 */
default List<AiKnowledgeSegmentDO> selectListByVectorIds(List<String> vectorIdList) {
    // An empty collection would be rendered as "IN ()", which is invalid SQL,
    // so skip the query entirely. A mutable list is returned, like a normal query result.
    if (vectorIdList == null || vectorIdList.isEmpty()) {
        return new java.util.ArrayList<>();
    }
    return selectList(new LambdaQueryWrapperX<AiKnowledgeSegmentDO>()
            .in(AiKnowledgeSegmentDO::getVectorId, vectorIdList)
            .orderByDesc(AiKnowledgeSegmentDO::getId));
}
} }

View File

@ -13,6 +13,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -22,6 +23,7 @@ import org.springframework.validation.annotation.Validated;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.List; import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList; import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
@ -45,6 +47,8 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiChatRoleService chatRoleService; private AiChatRoleService chatRoleService;
@Resource
private AiKnowledgeService knowledgeService;
@Override @Override
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) { public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
@ -56,9 +60,14 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
Assert.notNull(model, "必须找到默认模型"); Assert.notNull(model, "必须找到默认模型");
validateChatModel(model); validateChatModel(model);
// 1.3 校验知识库
if (Objects.nonNull(createReqVO.getKnowledgeId())) {
knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
}
// 2. 创建 AiChatConversationDO 聊天对话 // 2. 创建 AiChatConversationDO 聊天对话
AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false) AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
.setModelId(model.getId()).setModel(model.getModel()) .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
.setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts()); .setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
if (role != null) { if (role != null) {
conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage()); conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());
@ -82,6 +91,11 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
model = chatModalService.validateChatModel(updateReqVO.getModelId()); model = chatModalService.validateChatModel(updateReqVO.getModelId());
} }
// 1.3 校验知识库是否存在
if (updateReqVO.getKnowledgeId() != null) {
knowledgeService.validateKnowledgeExists(updateReqVO.getKnowledgeId());
}
// 2. 更新对话信息 // 2. 更新对话信息
AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class); AiChatConversationDO updateObj = BeanUtils.toBean(updateReqVO, AiChatConversationDO.class);
if (Boolean.TRUE.equals(updateReqVO.getPinned())) { if (Boolean.TRUE.equals(updateReqVO.getPinned())) {

View File

@ -12,11 +12,15 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -30,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -62,6 +67,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
private AiChatModelService chatModalService; private AiChatModelService chatModalService;
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) { public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@ -83,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt // 3.2 召回段落
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
ChatResponse chatResponse = chatModel.call(prompt); ChatResponse chatResponse = chatModel.call(prompt);
// 3.3 段式返回 // 3.4 段式返回
String newContent = chatResponse.getResult().getOutput().getContent(); String newContent = chatResponse.getResult().getOutput().getContent();
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent)); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class)) return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)); .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
} }
@ -114,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model, AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); // 3.2 召回段落
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
// 3.3 构建 Prompt并进行调用
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatModel.stream(prompt); Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
// 3.3 流式返回 // 3.4 流式返回
// TODO 注意Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 // TODO 注意Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
@ -131,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).doOnComplete(() -> { }).doOnComplete(() -> {
// 忽略租户,因为 Flux 异步无法透传租户 // 忽略租户,因为 Flux 异步无法透传租户
TenantUtils.executeIgnore(() -> TenantUtils.executeIgnore(() ->
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString()))); chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
.setContent(contentBuffer.toString())));
}).doOnError(throwable -> { }).doOnError(throwable -> {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable); log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
// 忽略租户,因为 Flux 异步无法透传租户 // 忽略租户,因为 Flux 异步无法透传租户
@ -140,18 +155,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
if (Objects.isNull(knowledgeId)) {
return Collections.emptyList();
}
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) { AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定
// 1.1 召回内容消息构建
if (CollUtil.isNotEmpty(segmentList)) {
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
StringBuilder infoBuilder = StrUtil.builder();
segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
chatMessages.add(message);
}
// 1.2 system context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) { if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
chatMessages.add(new SystemMessage(conversation.getSystemMessage())); chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
} }
// 1.2 history message 历史消息 // 1.3 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO); List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent()))); contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
// 1.3 user message 新发送消息 // 1.4 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象
@ -163,12 +195,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
/** /**
* n * n
* * <p>
* n user + assistant * n user + assistant
* *
* @param messages * @param messages
* @param conversation * @param conversation
* @param sendReqVO * @param sendReqVO
* @return * @return
*/ */
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
@ -185,7 +217,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
} }
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1); AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId()) if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|| StrUtil.isEmpty(assistantMessage.getContent())) { || StrUtil.isEmpty(assistantMessage.getContent())) {
continue; continue;
} }
// 由于后续要 reverse 反转,所以先添加 assistantMessage // 由于后续要 reverse 反转,所以先添加 assistantMessage

View File

@ -9,15 +9,11 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum; import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
@ -48,24 +44,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource @Resource
private AiKnowledgeSegmentMapper segmentMapper; private AiKnowledgeSegmentMapper segmentMapper;
@Resource
private TokenTextSplitter tokenTextSplitter;
@Resource @Resource
private TokenCountEstimator tokenCountEstimator; private TokenCountEstimator tokenCountEstimator;
@Resource
private AiApiKeyService apiKeyService;
@Resource @Resource
private AiKnowledgeService knowledgeService; private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 0. 校验 // 0. 校验并获取向量存储实例
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); VectorStore vectorStore = knowledgeService.getVectorStoreById(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 1.1 下载文档 // 1.1 下载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl())); TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl()));
@ -82,6 +70,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
return documentId; return documentId;
} }
// 2 构造文本分段器
TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(createReqVO.getDefaultSegmentTokens(), createReqVO.getMinSegmentWordCount(), createReqVO.getMinChunkLengthToEmbed(),
createReqVO.getMaxNumSegments(), createReqVO.getKeepSeparator());
// 2.1 文档分段 // 2.1 文档分段
List<Document> segments = tokenTextSplitter.apply(documents); List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库 // 2.2 分段内容入库
@ -92,9 +83,7 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
.setStatus(CommonStatusEnum.ENABLE.getStatus())); .setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList); segmentMapper.insertBatch(segmentDOList);
// 3.1 获取向量存储实例 // 3. 向量化并存储
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.2 向量化并存储
segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId())); segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
vectorStore.add(segments); vectorStore.add(segments);
return documentId; return documentId;

View File

@ -2,10 +2,13 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import java.util.List;
/** /**
* AI Service * AI Service
* *
@ -35,4 +38,12 @@ public interface AiKnowledgeSegmentService {
*/ */
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO); void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
/**
 * Runs a similarity search for knowledge segments.
 *
 * @param reqVO search request carrying the knowledge base id and the query content
 * @return the knowledge segments matched by the search
 */
List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO);
} }

View File

@ -1,16 +1,34 @@
package cn.iocoder.yudao.module.ai.service.knowledge; package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
/** /**
* AI Service * AI Service
* *
@ -23,6 +41,13 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Resource @Resource
private AiKnowledgeSegmentMapper segmentMapper; private AiKnowledgeSegmentMapper segmentMapper;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
@Resource
private AiApiKeyService apiKeyService;
@Override @Override
public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) { public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
return segmentMapper.selectPage(pageReqVO); return segmentMapper.selectPage(pageReqVO);
@ -30,13 +55,80 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Override @Override
public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) { public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) {
segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class)); // 1. 校验
// TODO @xin 重新向量化 AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
// 2.1 获取知识库向量实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
// 2.2 删除原向量
vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
// 2.3 重新向量化
Document document = new Document(reqVO.getContent());
document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
vectorStore.add(List.of(document));
// 3. 更新段落内容
AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
knowledgeSegment.setVectorId(document.getId());
segmentMapper.updateById(knowledgeSegment);
} }
@Override @Override
public void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO) { public void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO) {
segmentMapper.updateById(BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class)); // 0 校验
// TODO @xin 1.禁用删除向量 2.启用重新向量化 AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
// 1 获取知识库向量实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
if (Objects.equals(reqVO.getStatus(), CommonStatusEnum.ENABLE.getStatus())) {
// 2.1 启用重新向量化
Document document = new Document(oldKnowledgeSegment.getContent());
document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
vectorStore.add(List.of(document));
knowledgeSegment.setVectorId(document.getId());
} else {
// 2.2 禁用删除向量
vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
knowledgeSegment.setVectorId("");
}
// 3 更新段落状态
segmentMapper.updateById(knowledgeSegment);
} }
/**
 * Searches the vector store for documents similar to the request content and
 * returns the knowledge segments whose vector ids match the hits.
 *
 * @param reqVO search request; its knowledge id selects the knowledge base and its content is the query
 * @return the recalled segments, or an empty list when the vector search yields nothing
 */
@Override
public List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
    // 1. Validate the knowledge base and the chat model it references
    AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
    AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());

    // 2. Obtain the vector store instance for the model's API key
    VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());

    // 3.1 Vector search; the filter expression restricts hits to this knowledge base
    SearchRequest searchRequest = SearchRequest.query(reqVO.getContent())
            .withTopK(knowledge.getTopK())
            .withSimilarityThreshold(knowledge.getSimilarityThreshold())
            .withFilterExpression(new FilterExpressionBuilder()
                    .eq(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, reqVO.getKnowledgeId())
                    .build());
    List<Document> documentList = vectorStore.similaritySearch(searchRequest);
    if (CollUtil.isEmpty(documentList)) {
        return ListUtil.empty();
    }

    // 3.2 Recall the segments by vector id. The ids are read through the typed getter
    //     instead of a reflective lookup of a field named "id", so a rename fails at compile time.
    List<String> vectorIds = documentList.stream().map(Document::getId).toList();
    return segmentMapper.selectListByVectorIds(vectorIds);
}
/**
 * Loads a knowledge segment by id and fails when it does not exist.
 *
 * @param id segment id
 * @return the stored segment; never {@code null}
 */
private AiKnowledgeSegmentDO validateKnowledgeSegmentExists(Long id) {
    AiKnowledgeSegmentDO segment = segmentMapper.selectById(id);
    if (segment != null) {
        return segment;
    }
    // No row for this id: report it with the dedicated error code
    throw exception(KNOWLEDGE_SEGMENT_NOT_EXISTS);
}
} }

View File

@ -1,10 +1,11 @@
package cn.iocoder.yudao.module.ai.service.knowledge; package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import org.springframework.ai.vectorstore.VectorStore;
/** /**
* AI - Service * AI - Service
@ -14,23 +15,21 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
public interface AiKnowledgeService { public interface AiKnowledgeService {
/** /**
* *
* *
* @param createReqVO * @param createReqVO
* @param userId * @param userId
* @return * @return
*/ */
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId); Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId);
/** /**
* *
* *
* @param updateReqVO * @param updateReqVO
* @param userId * @param userId
*/ */
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId); void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId);
/** /**
* *
@ -40,11 +39,20 @@ public interface AiKnowledgeService {
AiKnowledgeDO validateKnowledgeExists(Long id); AiKnowledgeDO validateKnowledgeExists(Long id);
/** /**
* *
* *
* @param userId * @param userId
* @param pageReqVO * @param pageReqVO
* @return * @return
*/ */
PageResult<AiKnowledgeDO> getKnowledgePageMy(Long userId, PageParam pageReqVO); PageResult<AiKnowledgeDO> getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO);
/**
 * Gets the vector store used by the knowledge base with the given id.
 *
 * @param id knowledge base id
 * @return the VectorStore for that knowledge base
 */
VectorStore getVectorStoreById(Long id);
} }

View File

@ -2,17 +2,19 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateMyReqVO; import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper; import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -27,16 +29,18 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_NOT_
@Slf4j @Slf4j
public class AiKnowledgeServiceImpl implements AiKnowledgeService { public class AiKnowledgeServiceImpl implements AiKnowledgeService {
@Resource
private AiChatModelService chatModalService;
@Resource @Resource
private AiKnowledgeMapper knowledgeMapper; private AiKnowledgeMapper knowledgeMapper;
@Resource
private AiChatModelService chatModelService;
@Resource
private AiApiKeyService apiKeyService;
@Override @Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) { public Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId) {
// 1. 校验模型配置 // 1. 校验模型配置
AiChatModelDO model = chatModalService.validateChatModel(createReqVO.getModelId()); AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getModelId());
// 2. 插入知识库 // 2. 插入知识库
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
@ -46,14 +50,14 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
} }
@Override @Override
public void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId) { public void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId) {
// 1.1 校验知识库存在 // 1.1 校验知识库存在
AiKnowledgeDO knowledgeBaseDO = validateKnowledgeExists(updateReqVO.getId()); AiKnowledgeDO knowledgeBaseDO = validateKnowledgeExists(updateReqVO.getId());
if (ObjUtil.notEqual(knowledgeBaseDO.getUserId(), userId)) { if (ObjUtil.notEqual(knowledgeBaseDO.getUserId(), userId)) {
throw exception(KNOWLEDGE_NOT_EXISTS); throw exception(KNOWLEDGE_NOT_EXISTS);
} }
// 1.2 校验模型配置 // 1.2 校验模型配置
AiChatModelDO model = chatModalService.validateChatModel(updateReqVO.getModelId()); AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getModelId());
// 2. 更新知识库 // 2. 更新知识库
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class); AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
@ -71,8 +75,16 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
} }
@Override @Override
public PageResult<AiKnowledgeDO> getKnowledgePageMy(Long userId, PageParam pageReqVO) { public PageResult<AiKnowledgeDO> getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO) {
return knowledgeMapper.selectPageByMy(userId, pageReqVO); return knowledgeMapper.selectPage(userId, pageReqVO);
}
@Override
public VectorStore getVectorStoreById(Long id) {
AiKnowledgeDO knowledge = validateKnowledgeExists(id);
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 创建或获取 VectorStore 对象
return apiKeyService.getOrCreateVectorStore(model.getKeyId());
} }
} }

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -39,8 +38,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource @Resource
private AiModelFactory modelFactory; private AiModelFactory modelFactory;
@Resource
private AiVectorStoreFactory vectorFactory;
@Override @Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) { public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@ -149,7 +146,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
public VectorStore getOrCreateVectorStore(Long id) { public VectorStore getOrCreateVectorStore(Long id) {
AiApiKeyDO apiKey = validateApiKey(id); AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl()); // 创建或获取 VectorStore 对象
return modelFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
} }
} }

View File

@ -2,8 +2,6 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@ -38,11 +36,6 @@ public class YudaoAiAutoConfiguration {
return new AiModelFactoryImpl(); return new AiModelFactoryImpl();
} }
@Bean
public AiVectorStoreFactory aiVectorFactory() {
return new AiVectorStoreFactoryImpl();
}
// ========== 各种 AI Client 创建 ========== // ========== 各种 AI Client 创建 ==========
@ -89,7 +82,7 @@ public class YudaoAiAutoConfiguration {
// TODO @xin 免费版本 // TODO @xin 免费版本
// @Bean // @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动」 // @Lazy // TODO 芋艿:临时注释,避免无法启动」
// public EmbeddingModel transformersEmbeddingClient() { // public TransformersEmbeddingModel transformersEmbeddingClient() {
// return new TransformersEmbeddingModel(MetadataMode.EMBED); // return new TransformersEmbeddingModel(MetadataMode.EMBED);
// } // }
@ -98,23 +91,24 @@ public class YudaoAiAutoConfiguration {
*/ */
// @Bean // @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动 // @Lazy // TODO 芋艿:临时注释,避免无法启动
// public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties, // public RedisVectorStore vectorStore(TransformersEmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
// RedisProperties redisProperties) { // RedisProperties redisProperties) {
// var config = RedisVectorStore.RedisVectorStoreConfig.builder() // var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(properties.getIndex()) // .withIndexName(properties.getIndex())
// .withPrefix(properties.getPrefix()) // .withPrefix(properties.getPrefix())
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
// .build(); // .build();
// //
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel, // RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), // new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// properties.isInitializeSchema()); // properties.isInitializeSchema());
// redisVectorStore.afterPropertiesSet(); // redisVectorStore.afterPropertiesSet();
// return redisVectorStore; // return redisVectorStore;
// } // }
@Bean @Bean
@Lazy // TODO 芋艿:临时注释,避免无法启动 @Lazy // TODO 芋艿:临时注释,避免无法启动
public TokenTextSplitter tokenTextSplitter() { public TokenTextSplitter tokenTextSplitter() {
//TODO @xin 配置提取
return new TokenTextSplitter(500, 100, 5, 10000, true); return new TokenTextSplitter(500, 100, 5, 10000, true);
} }

View File

@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel; import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
/** /**
* AI Model * AI Model
@ -92,4 +93,17 @@ public interface AiModelFactory {
*/ */
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url); EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
/**
 * Gets the {@link VectorStore} for the given platform credentials,
 * creating it when it does not exist yet.
 *
 * @param embeddingModel embedding model handed to the vector store
 * @param platform       AI platform
 * @param apiKey         API KEY
 * @param url            API URL
 * @return the VectorStore object
 */
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
} }

View File

@ -13,6 +13,7 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties; import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel; import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
@ -54,13 +55,18 @@ import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi; import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel; import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel; import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel; import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.retry.support.RetryTemplate; import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.search.Schema;
import java.util.List; import java.util.List;
@ -191,6 +197,25 @@ public class AiModelFactoryImpl implements AiModelFactory {
}); });
} }
/**
 * Gets the {@link VectorStore} cached for the given platform / API key / URL combination,
 * creating a Redis-backed store on first use.
 *
 * @param embeddingModel embedding model passed to the {@link RedisVectorStore} constructor
 * @param platform       AI platform; part of the cache key, the index name and the key prefix
 * @param apiKey         API key; part of the cache key, the index name and the key prefix
 * @param url            API URL; part of the cache key and the index name
 * @return the cached or newly created vector store
 */
@Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
    String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
    return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
        // NOTE(review): the raw API key ends up inside the Redis index name (via cacheKey) and the
        // key prefix, so it is visible to anyone who can list Redis keys/indexes. Consider hashing it;
        // left unchanged here because renaming would orphan existing indexes.
        String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
        var config = RedisVectorStore.RedisVectorStoreConfig.builder()
                .withIndexName(cacheKey)
                .withPrefix(prefix)
                // Declares knowledgeId as a numeric metadata field of the index.
                .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
                .build();
        RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
        // Pass the configured credentials as well: connecting with host/port only fails against
        // a password-protected Redis. Null username/password means "no authentication".
        JedisPooled jedis = new JedisPooled(redisProperties.getHost(), redisProperties.getPort(),
                redisProperties.getUsername(), redisProperties.getPassword());
        // The trailing "true" is the constructor's schema-initialisation flag — TODO confirm against
        // the Spring AI version in use.
        RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, jedis, true);
        // The store is constructed by hand rather than by the Spring container, so the
        // initialisation callback has to be invoked explicitly.
        redisVectorStore.afterPropertiesSet();
        return redisVectorStore;
    });
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) { private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) { if (ArrayUtil.isEmpty(params)) {
return clazz.getName(); return clazz.getName();

View File

@ -1,28 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
// TODO @xin: move this into AiModelFactory as well; rename that one to AiFactory later
/**
 * Factory interface for AI {@link VectorStore} instances.
 *
 * @author xiaoxin
 */
public interface AiVectorStoreFactory {
/**
 * Gets the {@link VectorStore} for the given arguments, creating it when needed.
 * <p>
 * How instances are cached and which backing store is used is up to the implementation.
 *
 * @param embeddingModel embedding model the vector store is created with
 * @param platform       AI platform
 * @param apiKey         API KEY
 * @param url            API URL
 * @return the VectorStore
 */
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
}

View File

@ -1,52 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import redis.clients.jedis.JedisPooled;
/**
 * AI vector store factory implementation.
 * Uses {@link RedisVectorStore} as the {@link VectorStore}.
 *
 * @author xiaoxin
 */
public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {

    /**
     * Gets the {@link VectorStore} cached for the given platform / API key / URL combination,
     * creating a Redis-backed store on first use.
     *
     * @param embeddingModel embedding model passed to the {@link RedisVectorStore} constructor
     * @param platform       AI platform; part of the cache key
     * @param apiKey         API key; part of the cache key
     * @param url            API URL; part of the cache key
     * @return the cached or newly created vector store
     */
    @Override
    public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
        String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
        return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
            // TODO yunai @xin: where should these two settings come from?
            // TODO: different models may produce vectors of different dimensions; the index appears to be
            //  what separates them, and vectors of different dimensions cannot share one index
            // TODO reply: OK
            String index = "default-index";
            String prefix = "default:";
            var config = RedisVectorStore.RedisVectorStoreConfig.builder()
                    .withIndexName(index)
                    .withPrefix(prefix)
                    .build();
            RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
            // Pass the configured credentials as well: connecting with host/port only fails against
            // a password-protected Redis. Null username/password means "no authentication".
            JedisPooled jedis = new JedisPooled(redisProperties.getHost(), redisProperties.getPort(),
                    redisProperties.getUsername(), redisProperties.getPassword());
            // The trailing "true" is the constructor's schema-initialisation flag — TODO confirm against
            // the Spring AI version in use.
            RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, jedis, true);
            // The store is constructed by hand rather than by the Spring container, so the
            // initialisation callback has to be invoked explicitly.
            redisVectorStore.afterPropertiesSet();
            return redisVectorStore;
        });
    }

    /**
     * Builds the singleton cache key for a client type and its distinguishing parameters.
     *
     * @param clazz  client type; its class name forms the first part of the key
     * @param params values that distinguish one client from another; may be empty
     * @return the class name alone when there are no params, otherwise {@code className#p1_p2_...}
     */
    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
        if (ArrayUtil.isEmpty(params)) {
            return clazz.getName();
        }
        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
    }

}