parent
4139769131
commit
a2a6e9ad2e
|
@ -18,7 +18,7 @@
|
|||
|
||||
<name>${project.artifactId}</name>
|
||||
<description>
|
||||
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。
|
||||
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。
|
||||
目前已接入各种模型,不限于:
|
||||
国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
|
||||
国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno
|
||||
|
|
|
@ -22,7 +22,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
|
|||
除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。
|
||||
"""),
|
||||
|
||||
AI_MIND_MAP_ROLE(2, "脑图助手", """
|
||||
AI_MIND_MAP_ROLE(2, "导图助手", """
|
||||
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
|
||||
# Geek-AI 助手
|
||||
## 完整的开源系统
|
||||
|
|
|
@ -45,9 +45,11 @@ public interface ErrorCodeConstants {
|
|||
// ========== API 音乐 1-040-006-000 ==========
|
||||
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
|
||||
|
||||
|
||||
// ========== API 写作 1-022-007-000 ==========
|
||||
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
|
||||
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!");
|
||||
|
||||
// ========== API 思维导图 1-040-008-000 ==========
|
||||
ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!");
|
||||
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
<name>${project.artifactId}</name>
|
||||
<description>
|
||||
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。
|
||||
ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。
|
||||
目前已接入各种模型,不限于:
|
||||
国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
|
||||
国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno
|
||||
|
|
|
@ -3,9 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
|
|||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天消息发送 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -5,10 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
|||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
|
@ -45,6 +42,13 @@ public class AiImageController {
|
|||
return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/public-page")
|
||||
@Operation(summary = "获取公开的绘图分页")
|
||||
public CommonResult<PageResult<AiImageRespVO>> getImagePagePublic(AiImagePublicPageReqVO pageReqVO) {
|
||||
PageResult<AiImageDO> pageResult = imageService.getImagePagePublic(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/get-my")
|
||||
@Operation(summary = "获取【我的】绘图记录")
|
||||
@Parameter(name = "id", required = true, description = "绘画编号", example = "1024")
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI 绘画公开的分页 Request VO")
|
||||
@Data
|
||||
public class AiImagePublicPageReqVO extends PageParam {
|
||||
|
||||
@Schema(description = "提示词")
|
||||
private String prompt;
|
||||
|
||||
}
|
|
@ -1,20 +1,25 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.mindmap;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.annotation.security.PermitAll;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
|
||||
|
||||
@Tag(name = "管理后台 - AI 思维导图")
|
||||
|
@ -26,10 +31,29 @@ public class AiMindMapController {
|
|||
private AiMindMapService mindMapService;
|
||||
|
||||
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
@Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快")
|
||||
@Operation(summary = "导图生成(流式)", description = "流式返回,响应较快")
|
||||
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
||||
public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) {
|
||||
return mindMapService.generateMindMap(generateReqVO, getLoginUserId());
|
||||
}
|
||||
|
||||
// ================ 导图管理 ================
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
@Operation(summary = "删除思维导图")
|
||||
@Parameter(name = "id", description = "编号", required = true)
|
||||
@PreAuthorize("@ss.hasPermission('ai:mind-map:delete')")
|
||||
public CommonResult<Boolean> deleteMindMap(@RequestParam("id") Long id) {
|
||||
mindMapService.deleteMindMap(id);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@GetMapping("/page")
|
||||
@Operation(summary = "获得思维导图分页")
|
||||
@PreAuthorize("@ss.hasPermission('ai:mind-map:query')")
|
||||
public CommonResult<PageResult<AiMindMapRespVO>> getMindMapPage(@Valid AiMindMapPageReqVO pageReqVO) {
|
||||
PageResult<AiMindMapDO> pageResult = mindMapService.getMindMapPage(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageResult, AiMindMapRespVO.class));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
import org.springframework.format.annotation.DateTimeFormat;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
|
||||
|
||||
@Schema(description = "管理后台 - AI 思维导图分页 Request VO")
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ToString(callSuper = true)
|
||||
public class AiMindMapPageReqVO extends PageParam {
|
||||
|
||||
@Schema(description = "用户编号", example = "4325")
|
||||
private Long userId;
|
||||
|
||||
@Schema(description = "生成内容提示", example = "Java 学习路线")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "创建时间")
|
||||
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
|
||||
private LocalDateTime[] createTime;
|
||||
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Schema(description = "管理后台 - AI 思维导图 Response VO")
|
||||
@Data
|
||||
public class AiMindMapRespVO {
|
||||
|
||||
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "3373")
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "4325")
|
||||
private Long userId;
|
||||
|
||||
@Schema(description = "生成内容提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "生成的思维导图内容")
|
||||
private String generatedContent;
|
||||
|
||||
@Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI")
|
||||
private String platform;
|
||||
|
||||
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo-0125")
|
||||
private String model;
|
||||
|
||||
@Schema(description = "错误信息")
|
||||
private String errorMessage;
|
||||
|
||||
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private LocalDateTime createTime;
|
||||
|
||||
}
|
|
@ -8,7 +8,6 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyRespV
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey;
|
||||
|
||||
import lombok.*;
|
||||
import java.util.*;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import org.springframework.format.annotation.DateTimeFormat;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI API 密钥分页 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI API 密钥 Response VO")
|
||||
@Data
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import java.util.*;
|
||||
import jakarta.validation.constraints.*;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI API 密钥新增/修改 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
|
||||
|
||||
import lombok.*;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - API 聊天模型分页 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -3,8 +3,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel;
|
|||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import jakarta.validation.constraints.*;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - API 聊天模型新增/修改 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
|
||||
|
||||
import lombok.*;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天角色分页 Request VO")
|
||||
@Data
|
||||
|
|
|
@ -3,8 +3,9 @@ package cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole;
|
|||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import jakarta.validation.constraints.*;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
import org.hibernate.validator.constraints.URL;
|
||||
|
||||
@Schema(description = "管理后台 - AI 聊天角色新增/修改 Request VO")
|
||||
|
|
|
@ -6,8 +6,6 @@ import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
|
|||
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.ToString;
|
||||
import org.springframework.format.annotation.DateTimeFormat;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.*;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
|
||||
/**
|
||||
* AI Chat 消息 DO
|
||||
|
|
|
@ -2,7 +2,9 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model;
|
|||
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.*;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.*;
|
||||
|
||||
/**
|
||||
|
|
|
@ -6,9 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|||
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
|
|
@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
|
@ -41,6 +42,13 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> {
|
|||
.orderByDesc(AiImageDO::getId));
|
||||
}
|
||||
|
||||
default PageResult<AiImageDO> selectPage(AiImagePublicPageReqVO pageReqVO) {
|
||||
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>()
|
||||
.eqIfPresent(AiImageDO::getPublicStatus, Boolean.TRUE)
|
||||
.likeIfPresent(AiImageDO::getPrompt, pageReqVO.getPrompt())
|
||||
.orderByDesc(AiImageDO::getId));
|
||||
}
|
||||
|
||||
default List<AiImageDO> selectListByStatusAndPlatform(Integer status, String platform) {
|
||||
return selectList(AiImageDO::getStatus, status,
|
||||
AiImageDO::getPlatform, platform);
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.mysql.mindmap;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
|
@ -11,4 +14,13 @@ import org.apache.ibatis.annotations.Mapper;
|
|||
*/
|
||||
@Mapper
|
||||
public interface AiMindMapMapper extends BaseMapperX<AiMindMapDO> {
|
||||
|
||||
default PageResult<AiMindMapDO> selectPage(AiMindMapPageReqVO reqVO) {
|
||||
return selectPage(reqVO, new LambdaQueryWrapperX<AiMindMapDO>()
|
||||
.eqIfPresent(AiMindMapDO::getUserId, reqVO.getUserId())
|
||||
.eqIfPresent(AiMindMapDO::getPrompt, reqVO.getPrompt())
|
||||
.betweenIfPresent(AiMindMapDO::getCreateTime, reqVO.getCreateTime())
|
||||
.orderByDesc(AiMindMapDO::getId));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatMode
|
|||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。
|
||||
* ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。
|
||||
* 目前已接入各种模型,不限于:
|
||||
* 国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek
|
||||
* 国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno
|
||||
|
|
|
@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
||||
|
||||
import java.util.List;
|
||||
|
|
|
@ -21,7 +21,10 @@ import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
|
|
|
@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
|||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
|
||||
|
@ -28,6 +29,14 @@ public interface AiImageService {
|
|||
*/
|
||||
PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 获取公开的绘图分页
|
||||
*
|
||||
* @param pageReqVO 分页条件
|
||||
* @return 绘图分页
|
||||
*/
|
||||
PageResult<AiImageDO> getImagePagePublic(AiImagePublicPageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 获得绘图记录
|
||||
*
|
||||
|
|
|
@ -14,6 +14,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
|
||||
|
@ -70,6 +71,11 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
return imageMapper.selectPageMy(userId, pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<AiImageDO> getImagePagePublic(AiImagePublicPageReqVO pageReqVO) {
|
||||
return imageMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiImageDO getImage(Long id) {
|
||||
return imageMapper.selectById(id);
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package cn.iocoder.yudao.module.ai.service.knowledge;
|
||||
|
||||
/**
|
||||
* AI 知识库 Service 接口
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
public interface DocService {
|
||||
|
||||
/**
|
||||
* 向量化文档
|
||||
*/
|
||||
void embeddingDoc();
|
||||
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package cn.iocoder.yudao.module.ai.service.knowledge;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AI 知识库 Service 实现类
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
//@Service // TODO 芋艿:临时注释,避免无法启动
|
||||
@Slf4j
|
||||
public class DocServiceImpl implements DocService {
|
||||
|
||||
@Resource
|
||||
private RedisVectorStore vectorStore;
|
||||
@Resource
|
||||
private TokenTextSplitter tokenTextSplitter;
|
||||
|
||||
// TODO @xin 临时测试用,后续删
|
||||
@Value("classpath:/webapp/test/Fel.pdf")
|
||||
private org.springframework.core.io.Resource data;
|
||||
|
||||
@Override
|
||||
public void embeddingDoc() {
|
||||
// 读取文件
|
||||
TikaDocumentReader loader = new TikaDocumentReader(data);
|
||||
List<Document> documents = loader.get();
|
||||
// 文档分段
|
||||
List<Document> segments = tokenTextSplitter.apply(documents);
|
||||
// 向量化并存储
|
||||
vectorStore.add(segments);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,7 +1,10 @@
|
|||
package cn.iocoder.yudao.module.ai.service.mindmap;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
|
@ -20,4 +23,19 @@ public interface AiMindMapService {
|
|||
*/
|
||||
Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId);
|
||||
|
||||
/**
|
||||
* 删除思维导图
|
||||
*
|
||||
* @param id 编号
|
||||
*/
|
||||
void deleteMindMap(Long id);
|
||||
|
||||
/**
|
||||
* 获得思维导图分页
|
||||
*
|
||||
* @param pageReqVO 分页查询
|
||||
* @return 思维导图分页
|
||||
*/
|
||||
PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO);
|
||||
|
||||
}
|
||||
|
|
|
@ -6,9 +6,11 @@ import cn.hutool.core.util.StrUtil;
|
|||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||
|
@ -33,8 +35,10 @@ import reactor.core.publisher.Flux;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_EXISTS;
|
||||
|
||||
/**
|
||||
* AI 思维导图 Service 实现类
|
||||
|
@ -57,10 +61,10 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
|||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
|
||||
// 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
|
||||
// 1. 获取导图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
|
||||
AiChatRoleDO role = CollUtil.getFirst(
|
||||
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
||||
// 1.1 获取脑图执行模型
|
||||
// 1.1 获取导图执行模型
|
||||
AiChatModelDO model = getModel(role);
|
||||
// 1.2 获取角色设定消息
|
||||
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
|
||||
|
@ -131,4 +135,23 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
|||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteMindMap(Long id) {
|
||||
// 校验存在
|
||||
validateMindMapExists(id);
|
||||
// 删除
|
||||
mindMapMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private void validateMindMapExists(Long id) {
|
||||
if (mindMapMapper.selectById(id) == null) {
|
||||
throw exception(MIND_MAP_NOT_EXISTS);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO) {
|
||||
return mindMapMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -9,8 +9,6 @@ import jakarta.validation.Valid;
|
|||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* AI 聊天模型 Service 接口
|
||||
*
|
||||
|
|
|
@ -21,7 +21,8 @@ import java.util.List;
|
|||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_ROLE_DISABLE;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_ROLE_NOT_EXISTS;
|
||||
|
||||
/**
|
||||
* AI 聊天角色 Service 实现类
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package cn.iocoder.yudao.module.ai.service.music;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicUpdateMyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicUpdateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
|
|
|
@ -104,14 +104,22 @@ xxl:
|
|||
|
||||
spring:
|
||||
ai:
|
||||
vectorstore: # 向量存储
|
||||
redis:
|
||||
index: default-index
|
||||
prefix: "default:"
|
||||
qianfan: # 文心一言
|
||||
api-key: x0cuLZ7XsaTCU08vuJWO87Lg
|
||||
secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
|
||||
zhipuai: # 智谱 AI
|
||||
api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs
|
||||
openai:
|
||||
openai: # OpenAI 官方
|
||||
api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
|
||||
base-url: https://api.gptsapi.net
|
||||
azure: # OpenAI 微软
|
||||
openai:
|
||||
endpoint: https://eastusprejade.openai.azure.com
|
||||
api-key: xxx
|
||||
ollama:
|
||||
base-url: http://127.0.0.1:11434
|
||||
chat:
|
||||
|
|
|
@ -23,12 +23,16 @@
|
|||
<artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
|
||||
|
@ -40,6 +44,30 @@
|
|||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 向量化,基于 Redis 存储,Tika 解析内容 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-transformers-spring-boot-starter</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-tika-document-reader</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-redis-store</artifactId>
|
||||
<version>${spring-ai.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- TODO @xin:引入我们项目的 starter -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.data</groupId>
|
||||
<artifactId>spring-data-redis</artifactId>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.cloud</groupId>
|
||||
<artifactId>yudao-common</artifactId>
|
||||
|
|
|
@ -10,11 +10,20 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
|||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||||
import org.springframework.ai.transformers.TransformersEmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
/**
|
||||
* 芋道 AI 自动配置
|
||||
|
@ -73,4 +82,36 @@ public class YudaoAiAutoConfiguration {
|
|||
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
|
||||
}
|
||||
|
||||
// ========== rag 相关 ==========
|
||||
@Bean
|
||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
||||
public EmbeddingModel transformersEmbeddingClient() {
|
||||
return new TransformersEmbeddingModel(MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
/**
|
||||
* 我们启动有加载很多 Embedding 模型,不晓得取哪个好,先 new 个 TransformersEmbeddingModel 跑
|
||||
*/
|
||||
@Bean
|
||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
||||
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
|
||||
RedisProperties redisProperties) {
|
||||
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||
.withIndexName(properties.getIndex())
|
||||
.withPrefix(properties.getPrefix())
|
||||
.build();
|
||||
|
||||
RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel,
|
||||
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||
properties.isInitializeSchema());
|
||||
redisVectorStore.afterPropertiesSet();
|
||||
return redisVectorStore;
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Lazy // TODO 芋艿:临时注释,避免无法启动
|
||||
public TokenTextSplitter tokenTextSplitter() {
|
||||
return new TokenTextSplitter(500, 100, 5, 10000, true);
|
||||
}
|
||||
|
||||
}
|
|
@ -22,7 +22,8 @@ public enum AiPlatformEnum {
|
|||
|
||||
// ========== 国外平台 ==========
|
||||
|
||||
OPENAI("OpenAI", "OpenAI"),
|
||||
OPENAI("OpenAI", "OpenAI"), // OpenAI 官方
|
||||
AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软
|
||||
OLLAMA("Ollama", "Ollama"),
|
||||
|
||||
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
||||
|
|
|
@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
|||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.azure.ai.openai.OpenAIClient;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||
|
@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
|||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
|
@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return buildXingHuoChatModel(apiKey);
|
||||
case OPENAI:
|
||||
return buildOpenAiChatModel(apiKey, url);
|
||||
case AZURE_OPENAI:
|
||||
return buildAzureOpenAiChatModel(apiKey, url);
|
||||
case OLLAMA:
|
||||
return buildOllamaChatModel(url);
|
||||
default:
|
||||
|
@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return SpringUtil.getBean(XingHuoChatModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiChatModel.class);
|
||||
case AZURE_OPENAI:
|
||||
return SpringUtil.getBean(AzureOpenAiChatModel.class);
|
||||
case OLLAMA:
|
||||
return SpringUtil.getBean(OllamaChatModel.class);
|
||||
default:
|
||||
|
@ -179,7 +188,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
|
||||
*/
|
||||
private static TongYiChatModel buildTongYiChatModel(String key) {
|
||||
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
|
||||
Generation generation = SpringUtil.getBean(Generation.class);
|
||||
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
|
||||
// TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下
|
||||
// TODO @芋艿:貌似阿里云不是增量返回的
|
||||
|
@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return new OpenAiChatModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link AzureOpenAiAutoConfiguration}
|
||||
*/
|
||||
private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) {
|
||||
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
|
||||
// 创建 OpenAIClient 对象
|
||||
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
|
||||
connectionProperties.setApiKey(apiKey);
|
||||
connectionProperties.setEndpoint(url);
|
||||
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
|
||||
// 获取 AzureOpenAiChatProperties 对象
|
||||
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
|
||||
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration}
|
||||
*/
|
||||
|
|
|
@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
|
@ -35,6 +36,9 @@ public class AiUtils {
|
|||
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
|
||||
case OPENAI:
|
||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case AZURE_OPENAI:
|
||||
// TODO 芋艿:貌似没 model 字段???!
|
||||
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||
default:
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.autoconfigure.vectorstore.redis;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.RedisVectorStore;
|
||||
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
/**
|
||||
* TODO @xin 先拿 spring-ai 最新代码覆盖,1.0.0-M1 跟 redis 自动配置会冲突
|
||||
*
|
||||
* TODO 这个官方,有说啥时候 fix 哇?
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Eddú Meléndez
|
||||
*/
|
||||
@AutoConfiguration(after = RedisAutoConfiguration.class)
|
||||
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
|
||||
//@ConditionalOnBean(JedisConnectionFactory.class)
|
||||
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
|
||||
public class RedisVectorStoreAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
|
||||
JedisConnectionFactory jedisConnectionFactory) {
|
||||
|
||||
var config = RedisVectorStoreConfig.builder()
|
||||
.withIndexName(properties.getIndex())
|
||||
.withPrefix(properties.getPrefix())
|
||||
.build();
|
||||
|
||||
return new RedisVectorStore(config, embeddingModel,
|
||||
new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()),
|
||||
properties.isInitializeSchema());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,456 @@
|
|||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.vectorstore;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
import redis.clients.jedis.Pipeline;
|
||||
import redis.clients.jedis.json.Path2;
|
||||
import redis.clients.jedis.search.*;
|
||||
import redis.clients.jedis.search.Schema.FieldType;
|
||||
import redis.clients.jedis.search.schemafields.*;
|
||||
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
|
||||
|
||||
import java.text.MessageFormat;
|
||||
import java.util.*;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* The RedisVectorStore is for managing and querying vector data in a Redis database. It
|
||||
* offers functionalities like adding, deleting, and performing similarity searches on
|
||||
* documents.
|
||||
*
|
||||
* The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
|
||||
* search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for
|
||||
* efficient similarity searches. Additionally, it allows for custom metadata fields in
|
||||
* the documents to be stored alongside the vector and content data.
|
||||
*
|
||||
* This class requires a RedisVectorStoreConfig configuration object for initialization,
|
||||
* which includes settings like Redis URI, index name, field names, and vector algorithms.
|
||||
* It also requires an EmbeddingModel to convert documents into embeddings before storing
|
||||
* them.
|
||||
*
|
||||
* @author Julien Ruaux
|
||||
* @author Christian Tzolov
|
||||
* @author Eddú Meléndez
|
||||
* @see VectorStore
|
||||
* @see RedisVectorStoreConfig
|
||||
* @see EmbeddingModel
|
||||
*/
|
||||
public class RedisVectorStore implements VectorStore, InitializingBean {
|
||||
|
||||
public enum Algorithm {
|
||||
|
||||
FLAT, HSNW
|
||||
|
||||
}
|
||||
|
||||
public record MetadataField(String name, FieldType fieldType) {
|
||||
|
||||
public static MetadataField text(String name) {
|
||||
return new MetadataField(name, FieldType.TEXT);
|
||||
}
|
||||
|
||||
public static MetadataField numeric(String name) {
|
||||
return new MetadataField(name, FieldType.NUMERIC);
|
||||
}
|
||||
|
||||
public static MetadataField tag(String name) {
|
||||
return new MetadataField(name, FieldType.TAG);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for the Redis vector store.
|
||||
*/
|
||||
public static final class RedisVectorStoreConfig {
|
||||
|
||||
private final String indexName;
|
||||
|
||||
private final String prefix;
|
||||
|
||||
private final String contentFieldName;
|
||||
|
||||
private final String embeddingFieldName;
|
||||
|
||||
private final Algorithm vectorAlgorithm;
|
||||
|
||||
private final List<MetadataField> metadataFields;
|
||||
|
||||
private RedisVectorStoreConfig() {
|
||||
this(builder());
|
||||
}
|
||||
|
||||
private RedisVectorStoreConfig(Builder builder) {
|
||||
this.indexName = builder.indexName;
|
||||
this.prefix = builder.prefix;
|
||||
this.contentFieldName = builder.contentFieldName;
|
||||
this.embeddingFieldName = builder.embeddingFieldName;
|
||||
this.vectorAlgorithm = builder.vectorAlgorithm;
|
||||
this.metadataFields = builder.metadataFields;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start building a new configuration.
|
||||
* @return The entry point for creating a new configuration.
|
||||
*/
|
||||
public static Builder builder() {
|
||||
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* {@return the default config}
|
||||
*/
|
||||
public static RedisVectorStoreConfig defaultConfig() {
|
||||
|
||||
return builder().build();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String indexName = DEFAULT_INDEX_NAME;
|
||||
|
||||
private String prefix = DEFAULT_PREFIX;
|
||||
|
||||
private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
|
||||
|
||||
private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
|
||||
|
||||
private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
|
||||
|
||||
private List<MetadataField> metadataFields = new ArrayList<>();
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the Redis index name to use.
|
||||
* @param name the index name to use
|
||||
* @return this builder
|
||||
*/
|
||||
public Builder withIndexName(String name) {
|
||||
this.indexName = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the Redis key prefix to use (default: "embedding:").
|
||||
* @param prefix the prefix to use
|
||||
* @return this builder
|
||||
*/
|
||||
public Builder withPrefix(String prefix) {
|
||||
this.prefix = prefix;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the Redis content field name to use.
|
||||
* @param name the content field name to use
|
||||
* @return this builder
|
||||
*/
|
||||
public Builder withContentFieldName(String name) {
|
||||
this.contentFieldName = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the Redis embedding field name to use.
|
||||
* @param name the embedding field name to use
|
||||
* @return this builder
|
||||
*/
|
||||
public Builder withEmbeddingFieldName(String name) {
|
||||
this.embeddingFieldName = name;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures the Redis vector algorithmto use.
|
||||
* @param algorithm the vector algorithm to use
|
||||
* @return this builder
|
||||
*/
|
||||
public Builder withVectorAlgorithm(Algorithm algorithm) {
|
||||
this.vectorAlgorithm = algorithm;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMetadataFields(MetadataField... fields) {
|
||||
return withMetadataFields(Arrays.asList(fields));
|
||||
}
|
||||
|
||||
public Builder withMetadataFields(List<MetadataField> fields) {
|
||||
this.metadataFields = fields;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@return the immutable configuration}
|
||||
*/
|
||||
public RedisVectorStoreConfig build() {
|
||||
|
||||
return new RedisVectorStoreConfig(this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private final boolean initializeSchema;
|
||||
|
||||
public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
|
||||
|
||||
public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
|
||||
|
||||
public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
|
||||
|
||||
public static final String DEFAULT_PREFIX = "embedding:";
|
||||
|
||||
public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
|
||||
|
||||
private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
|
||||
|
||||
private static final Path2 JSON_SET_PATH = Path2.of("$");
|
||||
|
||||
private static final String JSON_PATH_PREFIX = "$.";
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
|
||||
|
||||
private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
|
||||
|
||||
private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l);
|
||||
|
||||
private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
|
||||
|
||||
private static final String EMBEDDING_PARAM_NAME = "BLOB";
|
||||
|
||||
public static final String DISTANCE_FIELD_NAME = "vector_score";
|
||||
|
||||
private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
|
||||
|
||||
private final JedisPooled jedis;
|
||||
|
||||
private final EmbeddingModel embeddingModel;
|
||||
|
||||
private final RedisVectorStoreConfig config;
|
||||
|
||||
private FilterExpressionConverter filterExpressionConverter;
|
||||
|
||||
public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
|
||||
boolean initializeSchema) {
|
||||
|
||||
Assert.notNull(config, "Config must not be null");
|
||||
Assert.notNull(embeddingModel, "Embedding model must not be null");
|
||||
this.initializeSchema = initializeSchema;
|
||||
|
||||
this.jedis = jedis;
|
||||
this.embeddingModel = embeddingModel;
|
||||
this.config = config;
|
||||
this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
|
||||
}
|
||||
|
||||
public JedisPooled getJedis() {
|
||||
return this.jedis;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(List<Document> documents) {
|
||||
try (Pipeline pipeline = this.jedis.pipelined()) {
|
||||
for (Document document : documents) {
|
||||
var embedding = this.embeddingModel.embed(document);
|
||||
document.setEmbedding(embedding);
|
||||
|
||||
var fields = new HashMap<String, Object>();
|
||||
fields.put(this.config.embeddingFieldName, embedding);
|
||||
fields.put(this.config.contentFieldName, document.getContent());
|
||||
fields.putAll(document.getMetadata());
|
||||
pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
|
||||
}
|
||||
List<Object> responses = pipeline.syncAndReturnAll();
|
||||
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
|
||||
if (errResponse.isPresent()) {
|
||||
String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error(message);
|
||||
}
|
||||
throw new RuntimeException(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String key(String id) {
|
||||
return this.config.prefix + id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<Boolean> delete(List<String> idList) {
|
||||
try (Pipeline pipeline = this.jedis.pipelined()) {
|
||||
for (String id : idList) {
|
||||
pipeline.jsonDel(key(id));
|
||||
}
|
||||
List<Object> responses = pipeline.syncAndReturnAll();
|
||||
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
|
||||
if (errResponse.isPresent()) {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error("Could not delete document: {}", errResponse.get());
|
||||
}
|
||||
return Optional.of(false);
|
||||
}
|
||||
return Optional.of(true);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Document> similaritySearch(SearchRequest request) {
|
||||
|
||||
Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
|
||||
Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
|
||||
"The similarity score is bounded between 0 and 1; least to most similar respectively.");
|
||||
|
||||
String filter = nativeExpressionFilter(request);
|
||||
|
||||
String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
|
||||
EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
|
||||
|
||||
List<String> returnFields = new ArrayList<>();
|
||||
this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add);
|
||||
returnFields.add(this.config.embeddingFieldName);
|
||||
returnFields.add(this.config.contentFieldName);
|
||||
returnFields.add(DISTANCE_FIELD_NAME);
|
||||
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
|
||||
Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
|
||||
.returnFields(returnFields.toArray(new String[0]))
|
||||
.setSortBy(DISTANCE_FIELD_NAME, true)
|
||||
.dialect(2);
|
||||
|
||||
SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
|
||||
return result.getDocuments()
|
||||
.stream()
|
||||
.filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
|
||||
.map(this::toDocument)
|
||||
.toList();
|
||||
}
|
||||
|
||||
private Document toDocument(redis.clients.jedis.search.Document doc) {
|
||||
var id = doc.getId().substring(this.config.prefix.length());
|
||||
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
|
||||
: null;
|
||||
Map<String, Object> metadata = this.config.metadataFields.stream()
|
||||
.map(MetadataField::name)
|
||||
.filter(doc::hasProperty)
|
||||
.collect(Collectors.toMap(Function.identity(), doc::getString));
|
||||
metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
|
||||
return new Document(id, content, metadata);
|
||||
}
|
||||
|
||||
private float similarityScore(redis.clients.jedis.search.Document doc) {
|
||||
return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
|
||||
}
|
||||
|
||||
private String nativeExpressionFilter(SearchRequest request) {
|
||||
if (request.getFilterExpression() == null) {
|
||||
return "*";
|
||||
}
|
||||
return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
|
||||
if (!this.initializeSchema) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If index already exists don't do anything
|
||||
if (this.jedis.ftList().contains(this.config.indexName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String response = this.jedis.ftCreate(this.config.indexName,
|
||||
FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields());
|
||||
if (!RESPONSE_OK.test(response)) {
|
||||
String message = MessageFormat.format("Could not create index: {0}", response);
|
||||
throw new RuntimeException(message);
|
||||
}
|
||||
}
|
||||
|
||||
private Iterable<SchemaField> schemaFields() {
|
||||
Map<String, Object> vectorAttrs = new HashMap<>();
|
||||
vectorAttrs.put("DIM", this.embeddingModel.dimensions());
|
||||
vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
|
||||
vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32);
|
||||
List<SchemaField> fields = new ArrayList<>();
|
||||
fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0));
|
||||
fields.add(VectorField.builder()
|
||||
.fieldName(jsonPath(this.config.embeddingFieldName))
|
||||
.algorithm(vectorAlgorithm())
|
||||
.attributes(vectorAttrs)
|
||||
.as(this.config.embeddingFieldName)
|
||||
.build());
|
||||
|
||||
if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
|
||||
for (MetadataField field : this.config.metadataFields) {
|
||||
fields.add(schemaField(field));
|
||||
}
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
private SchemaField schemaField(MetadataField field) {
|
||||
String fieldName = jsonPath(field.name);
|
||||
switch (field.fieldType) {
|
||||
case NUMERIC:
|
||||
return NumericField.of(fieldName).as(field.name);
|
||||
case TAG:
|
||||
return TagField.of(fieldName).as(field.name);
|
||||
case TEXT:
|
||||
return TextField.of(fieldName).as(field.name);
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType));
|
||||
}
|
||||
}
|
||||
|
||||
private VectorAlgorithm vectorAlgorithm() {
|
||||
if (config.vectorAlgorithm == Algorithm.HSNW) {
|
||||
return VectorAlgorithm.HNSW;
|
||||
}
|
||||
return VectorAlgorithm.FLAT;
|
||||
}
|
||||
|
||||
private String jsonPath(String field) {
|
||||
return JSON_PATH_PREFIX + field;
|
||||
}
|
||||
|
||||
private static float[] toFloatArray(List<Double> embeddingDouble) {
|
||||
float[] embeddingFloat = new float[embeddingDouble.size()];
|
||||
int i = 0;
|
||||
for (Double d : embeddingDouble) {
|
||||
embeddingFloat[i++] = d.floatValue();
|
||||
}
|
||||
return embeddingFloat;
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,70 @@
|
|||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import com.azure.ai.openai.OpenAIClient;
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
import com.azure.core.util.ClientOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME;
|
||||
|
||||
/**
|
||||
* {@link AzureOpenAiChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AzureOpenAIChatModelTests {
|
||||
|
||||
private final OpenAIClient openAiApi = (new OpenAIClientBuilder())
|
||||
.endpoint("https://eastusprejade.openai.azure.com")
|
||||
.credential(new AzureKeyCredential("xxx"))
|
||||
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"))
|
||||
.buildClient();
|
||||
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
|
||||
AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
|
@ -17,7 +16,7 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link XingHuoChatModel} 集成测试
|
||||
* {@link OpenAiChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue