Commit 511201a3 by 陈健智

ai搜索初版

parent 939698ae
......@@ -269,7 +269,7 @@
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>3.8.0</version>
<version>3.12.0</version>
</dependency>
<!-- dubbo -->
<dependency>
......@@ -324,6 +324,12 @@
<artifactId>ansj_seg</artifactId>
<version>5.0.2</version>
</dependency>
<!--火山引擎 豆包大模型-->
<dependency>
<groupId>com.volcengine</groupId>
<artifactId>volcengine-java-sdk-ark-runtime</artifactId>
<version>0.1.121</version>
</dependency>
</dependencies>
<build>
<plugins>
......
package com.zhiwei.brandkbs2.common;
import com.volcengine.ark.runtime.service.ArkService;
import com.zhiwei.brandkbs2.pojo.ai.AccessModel;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
/**
* @ClassName: DoubaoAIAccountFactor
* @Description DoubaoAIAccountFactor
* @author: sjj
* @date: 2024-07-24 14:01
*/
public class DoubaoAIAccountFactor {
// public static Account getPersonalAccount() {
// String apiKey = "1ff5c1f2-4fa0-4e40-8ea8-3c9674e05a33";
// List<AccessModel> modelList = new ArrayList<>();
// modelList.add(new AccessModel("ep-20240722180424-45xkt", AccessModel.Model.DOUBAO_PRO_4K));
// modelList.add(new AccessModel("ep-20240723162722-7p4xw", AccessModel.Model.DOUBAO_PRO_32K));
// modelList.add(new AccessModel("ep-20240723153621-6hd22", AccessModel.Model.DOUBAO_PRO_128K));
// modelList.add(new AccessModel("ep-20240723164110-qhtxm", AccessModel.Model.DOUBAO_LITE_4K));
// modelList.add(new AccessModel("ep-20240723164127-44ksx", AccessModel.Model.DOUBAO_LITE_32K));
// modelList.add(new AccessModel("ep-20240723164141-xxmhs", AccessModel.Model.DOUBAO_LITE_128K));
// return new Account(apiKey, modelList);
// }
public static Account getCompanyAccount() {
String apiKey = "607764fc-c9d9-47e4-a673-a310852917a0";
List<AccessModel> modelList = new ArrayList<>();
modelList.add(new AccessModel("ep-20240617061616-8d2ls", AccessModel.Model.DOUBAO_PRO_4K));
modelList.add(new AccessModel("ep-20240618021538-t6dpf", AccessModel.Model.DOUBAO_PRO_32K));
return new Account(apiKey, modelList);
}
@Data
@AllArgsConstructor
public static class Account {
String apiKey;
List<AccessModel> modelList;
}
public static ArkService arkService;
static {
arkService = new ArkService(getCompanyAccount().getApiKey());
}
}
......@@ -352,6 +352,12 @@ public class AppSearchController extends BaseController {
return ResponseResult.success(markDataService.getContendSearchCriteria(contendId));
}
@ApiOperation("搜索-AI搜索")
@GetMapping("/ai")
public ResponseResult getAISearchResult(@RequestParam(value = "question") String question) {
return ResponseResult.success(markDataService.getAISearchResult(question));
}
@ApiOperation("搜索-搜索关键词历史记录")
@GetMapping("/keyword/cache")
public ResponseResult getSearchKeywordCache(@ApiParam(name = "searchType",
......
package com.zhiwei.brandkbs2.es;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.zhiwei.brandkbs2.common.GenericAttribute;
import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.pojo.ChannelIndex;
import com.zhiwei.brandkbs2.pojo.ai.FieldMapping;
import com.zhiwei.brandkbs2.util.TextUtil;
import com.zhiwei.brandkbs2.util.Tools;
import lombok.Getter;
import lombok.Setter;
......@@ -30,9 +34,13 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.joda.time.Period;
import org.joda.time.PeriodType;
import org.springframework.beans.factory.annotation.Value;
......@@ -331,6 +339,56 @@ public class EsClientDao {
return Pair.of(new Long[]{startTime, endTime}, res);
}
public List<JSONObject> findSearch(List<FieldMapping> fieldMappings) throws IOException {
List<JSONObject> list = new ArrayList<>();
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
List<JSONObject> searchHits = searchScroll(query, 10000, new String[]{"id", GenericAttribute.ES_TIME, GenericAttribute.ES_IND_TITLE});
ImmutableMap<String, JSONObject> idMap = Maps.uniqueIndex(searchHits, json -> json.getString("id"));
Map<String, String> idTitle = searchHits.stream().filter(json -> Objects.nonNull(json.getString("ind_title"))).collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString("ind_title")));
if (idTitle.isEmpty()){
return list;
}
// 按标题聚合,取聚合结果集前9,并取结果集中最新的文章的id
List<String> idList = TextUtil.getKResult(idTitle).stream()
.sorted(Comparator.comparingInt(List<String>::size).reversed())
.limit(9)
.map(ids -> ids.stream().map(idMap::get).max(Comparator.comparingInt(json -> (int) json.getLongValue(GenericAttribute.ES_TIME))).orElse(null))
.filter(Objects::nonNull)
.map(json -> json.getString("id"))
.collect(Collectors.toList());
// 反查原数据
for (String id : idList) {
list.add(getTopTitleLatest(getBoolQueryBuilder(fieldMappings), id));
}
return list;
}
private JSONObject getTopTitleLatest(BoolQueryBuilder query, String id) throws IOException {
query.must(QueryBuilders.termQuery("id", id));
FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC);
SearchHits searchHits = searchHits(getIndexes(), query, null, null, sort, 0, 1, null);
return new JSONObject(searchHits.getAt(0).getSourceAsMap());
}
// private JSONObject getTopTitleLatest(BoolQueryBuilder query, String title) throws IOException {
// query.must(QueryBuilders.termQuery("agg_title.keyword", title));
// FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC);
// SearchHits searchHits = searchHits(getIndexes(), query, null, null, sort, 0, 1, null);
// return new JSONObject(searchHits.getAt(0).getSourceAsMap());
// }
private BoolQueryBuilder getBoolQueryBuilder(List<FieldMapping> fieldMappings) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
Map<String, List<FieldMapping>> groupMap = fieldMappings.stream().collect(Collectors.groupingBy(mapping -> mapping.getFieldMap().getFatherName()));
groupMap.forEach((fatherName, list) -> {
if (list.size() > 2) {
throw new IllegalStateException("构建搜索条件分组异常");
}
boolQueryBuilder.must(list.get(0).buildQuery(list.size() > 1 ? list.get(1) : null));
});
return boolQueryBuilder;
}
public String[] getIndexes() {
return getIndexList().toArray(new String[0]);
}
......
package com.zhiwei.brandkbs2.pojo.ai;
import lombok.Data;
import lombok.Getter;
/**
* @ClassName: AccessModel
* @Description AccessModel
* @author: sjj
* @date: 2024-07-24 14:05
*/
@Data
public class AccessModel {
// 模型名称
private String modelName;
// 接入点id
private String modelId;
// 输入单价(每千个token)
private Double inputPrice;
// 输出单价(每千个token)
private Double outputPrice;
public AccessModel(String modelId, Model model) {
this.modelId = modelId;
this.modelName = model.modelName;
this.inputPrice = model.inputPrice;
this.outputPrice = model.outputPrice;
}
public enum Model {
DOUBAO_PRO_4K("Doubao-pro-4k", 0.0008, 0.0020),
DOUBAO_PRO_32K("Doubao-pro-32k", 0.0008, 0.0020),
DOUBAO_PRO_128K("Doubao-pro-128k", 0.0090, 0.0500),
DOUBAO_LITE_4K("doubao-lite-4k", 0.0003, 0.0006),
DOUBAO_LITE_32K("Doubao-lite-32k", 0.0003, 0.0006),
DOUBAO_LITE_128K("Doubao-lite-128k", 0.0008, 0.0010);
@Getter
private final String modelName;
// 输入单价(每千个token)
@Getter
private final Double inputPrice;
// 输出单价(每千个token)
@Getter
private final Double outputPrice;
Model(String modelName, Double inputPrice, Double outputPrice) {
this.modelName = modelName;
this.inputPrice = inputPrice;
this.outputPrice = outputPrice;
}
}
}
package com.zhiwei.brandkbs2.pojo.ai;
import com.zhiwei.brandkbs2.common.GlobalPojo;
import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.pojo.AbstractProject;
import com.zhiwei.brandkbs2.pojo.Contend;
import com.zhiwei.brandkbs2.pojo.Project;
import lombok.Data;
import lombok.Getter;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.*;
import java.util.*;
import java.util.stream.Collectors;
/**
* @ClassName: FieldMap
* @Description FieldMap
* @author: sjj
* @date: 2024-08-02 14:01
*/
@Data
public class FieldMapping {
private FieldMap fieldMap;
private Object value;
public FieldMapping(FieldMap fieldMap, Object value) {
this.fieldMap = fieldMap;
this.value = value;
}
public QueryBuilder buildQuery(FieldMapping fieldMapping) {
boolean existsAnd = null != fieldMapping;
RangeQueryBuilder timeRangeBuilder;
String contendId = "0";
// 项目组需绑定查询
switch (fieldMap) {
case START_TIME:
timeRangeBuilder = QueryBuilders.rangeQuery(fieldMap.databaseName).gte(value);
if (existsAnd && fieldMapping.fieldMap.equals(FieldMap.END_TIME)) {
timeRangeBuilder.lt(fieldMapping.value);
}
return timeRangeBuilder;
case END_TIME:
timeRangeBuilder = QueryBuilders.rangeQuery(fieldMap.databaseName).lt(value);
if (existsAnd && fieldMapping.fieldMap.equals(FieldMap.START_TIME)) {
timeRangeBuilder.gte(fieldMapping.value);
}
return timeRangeBuilder;
case PROJECT:
if (existsAnd && fieldMapping.fieldMap == FieldMap.BRAND) {
contendId = (String) fieldMapping.value;
}
BoolQueryBuilder nestedBoolBuilder = QueryBuilders.boolQuery();
// 必要条件
nestedBoolBuilder.must(QueryBuilders.termQuery(fieldMap.databaseName, value + "_" + contendId));
return new NestedQueryBuilder("brandkbs_cache_maps", nestedBoolBuilder, ScoreMode.None);
case BRAND:
if (!existsAnd || fieldMapping.fieldMap != FieldMap.PROJECT) {
throw new IllegalStateException("项目条件缺失");
}
return fieldMapping.buildQuery(this);
case IND_FULL_TEXT:
return QueryBuilders.matchPhraseQuery(fieldMap.databaseName, value);
case SOURCE:
case MTAG:
return QueryBuilders.termQuery(fieldMap.databaseName, value);
}
return null;
}
public enum FieldMap {
START_TIME("起始时间", "时间", "time"),
END_TIME("结束时间", "时间", "time"),
PROJECT("项目", "项目", "brandkbs_cache_maps.key.keyword"),
BRAND("品牌", "项目", "brandkbs_cache_maps.key.keyword"),
SOURCE("渠道", "渠道", "source"),
MTAG("标签", "标签", "mark_cache_maps.name.keyword"),
IND_FULL_TEXT("搜索条件", "搜索条件", "ind_full_text");
@Getter
private final String name;
@Getter
private final String fatherName;
@Getter
private final String databaseName;
FieldMap(String name, String fatherName, String databaseName) {
this.name = name;
this.fatherName = fatherName;
this.databaseName = databaseName;
}
}
public static FieldMapping createFromNameAndValue(String name, Object value, List<FieldMapping> fieldMappings) {
FieldMap fieldMap = null;
// TODO 字段转换待完善,引入数据库
for (FieldMap f : FieldMap.values()) {
if (name.equals(f.getName())) {
// 项目名需要转成id
if (FieldMap.PROJECT == f) {
Map<String, Project> projectMap = GlobalPojo.PROJECT_MAP.values().stream().collect(Collectors.toMap(AbstractProject::getProjectName, o -> o));
if (projectMap.containsKey(String.valueOf(value))) {
value = projectMap.get(String.valueOf(value)).getId();
}
}
// 品牌需要转换
if (FieldMap.BRAND == f) {
if ("主品牌".equals(value)) {
value = Constant.PRIMARY_CONTEND_ID;
} else {
// 寻找对应的竞品id
Optional<FieldMapping> project = fieldMappings.stream().filter(field -> Objects.equals(FieldMap.PROJECT, field.getFieldMap())).findFirst();
if (project.isPresent()){
List<Contend> contendList = GlobalPojo.PROJECT_MAP.get(String.valueOf(project.get().getValue())).getContendList();
Object finalValue = value;
Optional<Contend> contendOptional = contendList.stream().filter(contend -> Objects.equals(contend.getBrandName(), finalValue)).findFirst();
if (contendOptional.isPresent()){
value = contendOptional.get().getId();
}else {
value = Constant.PRIMARY_CONTEND_ID;
}
}else {
value = Constant.PRIMARY_CONTEND_ID;
}
}
}
// 标签只包含正负中
if (FieldMap.MTAG == f) {
if (!Arrays.asList("正面", "中性", "负面").contains(String.valueOf(value))) {
return null;
}
}
fieldMap = f;
break;
}
}
if (null == fieldMap) {
return null;
}
return new FieldMapping(fieldMap, value);
}
}
......@@ -836,4 +836,6 @@ public interface MarkDataService {
* @return
*/
List<String> expandOriginRange(MarkSearchDTO dto);
JSONObject getAISearchResult(String question);
}
......@@ -6,13 +6,14 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.HanLP;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionResult;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.zhiwei.base.category.ClassB;
import com.zhiwei.base.entity.subclass.mark.MarkInfo;
import com.zhiwei.brandkbs2.auth.UserThreadLocal;
import com.zhiwei.brandkbs2.common.ChannelType;
import com.zhiwei.brandkbs2.common.GenericAttribute;
import com.zhiwei.brandkbs2.common.GlobalPojo;
import com.zhiwei.brandkbs2.common.RedisKeyPrefix;
import com.zhiwei.brandkbs2.common.*;
import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.dao.*;
import com.zhiwei.brandkbs2.easyexcel.EasyExcelUtil;
......@@ -29,6 +30,8 @@ import com.zhiwei.brandkbs2.listener.ApplicationProjectListener;
import com.zhiwei.brandkbs2.model.CommonCodeEnum;
import com.zhiwei.brandkbs2.model.ResponseResult;
import com.zhiwei.brandkbs2.pojo.*;
import com.zhiwei.brandkbs2.pojo.ai.AccessModel;
import com.zhiwei.brandkbs2.pojo.ai.FieldMapping;
import com.zhiwei.brandkbs2.pojo.dto.*;
import com.zhiwei.brandkbs2.pojo.vo.*;
import com.zhiwei.brandkbs2.service.*;
......@@ -83,6 +86,7 @@ import org.springframework.web.client.RestTemplate;
import javax.annotation.Resource;
import java.io.IOException;
import java.text.MessageFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.CompletableFuture;
......@@ -108,6 +112,33 @@ public class MarkDataServiceImpl implements MarkDataService {
private static final String XHS_PLATFORM_ID = "6433c2251701316728003be4";
private static final String QUESTION_PROMPT = "###\n" +
"假如你是专业的问题提炼人员,你将根据用户提供的内容,来提炼问题要素和条件。根据以下规则一步步执行:\n" +
"1.提及到年、月、日、天、礼拜的定义为时间要素,未提及到默认条件算近一周,条件给到具体的起始时间和结束时间(结束时间为当前时间则不用返回)的时间戳。\n" +
"2.提及到 XX 项目 XX 品牌的定义为项目及品牌要素,条件给到具体的项目及品牌(未提及品牌默认主品牌)。\n" +
"3.提及到 XX 渠道的定义为渠道要素,条件给到该渠道名。\n" +
"4.提及到正面、中性、负面的定义为标签要素,条件给到该标签名。\n" +
"5.提及到针对 XX ,针对 XX 相关或 XX 相关的定义为搜索条件要素(必须包含 针对/相关 字样),条件给到具体值\n" +
"5.时间和项目要素为必需要素,若不满足则返回“无法回答”。\n" +
"\n" +
"参考例子:\n" +
"示例 1:\n" +
"{用户:今年 7 月腾讯项目张清相关的正面数据}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000,\"结束时间\":1722355200000},\"项目\":\"腾讯\",\"品牌\":\"主品牌\",\"标签\":\"正面\",\"搜索条件\":\"张清\"}\n" +
"示例 2:\n" +
"{用户:近一个月老乡鸡竞品1品牌新浪网渠道数据}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000},\"项目\":\"老乡鸡\",\"品牌\":\"竞品1\",\"渠道\":\"新浪网\"}\n" +
"示例 3:\n" +
"{用户:近一年数据}\n" +
"输出:无法回答\n" +
"\n" +
"要求:\n" +
"1 按照指定输出格式输出。\n" +
"2 严格按照规则进行提炼。\n" +
"###";
private static final String RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的内容,给出自己的详细分析和见解。并注明理由(提供对应的数据文本1-{0})" +
"请分析:";
@Value("${istarshine.addIStarShineKSData.url}")
private String addIStarShineKSDataUrl;
......@@ -3919,6 +3950,92 @@ public class MarkDataServiceImpl implements MarkDataService {
return Collections.emptyList();
}
@Override
public JSONObject getAISearchResult(String question) {
JSONObject res = new JSONObject();
try {
// 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
// 根据AI生成条件
ChatCompletionResult result = standardRequest(question, modelName);
JSONObject json = JSON.parseObject((String) result.getChoices().get(0).getMessage().getContent());
// 数据条件
List<FieldMapping> filedMapping = getFiledMapping(json, question);
List<JSONObject> list = esClientDao.findSearch(filedMapping);
// AI回答
StringBuilder sb = new StringBuilder();
List<BaseMap> articles = list.stream().map(Tools::getBaseFromEsMap).collect(Collectors.toList());
int count = 1;
for (BaseMap baseMap : articles) {
String text = baseMap.getContent();
sb.append(count++).append("、").append(text).append(";");
}
String sbContent = sb.toString();
result = standardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question);
Object resultContent = result.getChoices().get(0).getMessage().getContent();
res.put("result", resultContent);
res.put("articles", articles);
}catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e);
}
return res;
}
private ChatCompletionResult standardRequest(String content, String modelName) {
return standardRequest(content, modelName, QUESTION_PROMPT);
}
private ChatCompletionResult standardRequest(String content, String modelName, String prompt) {
// DoubaoAIAccountFactor.Account account = DoubaoAIAccountFactor.getCompanyAccount();
AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName);
// ArkService service = new ArkService(account.getApiKey());
ChatCompletionResult chatCompletion = null;
try {
final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(prompt).build();
final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
messages.add(systemMessage);
messages.add(userMessage);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder().model(model.getModelId()).messages(messages).build();
chatCompletion = DoubaoAIAccountFactor.arkService.createChatCompletion(chatCompletionRequest);
log.info("chatCompletion-content:{}", JSONObject.toJSONString(chatCompletion.getChoices().get(0).getMessage().getContent()));
if (chatCompletion.getChoices().size() > 1) {
log.error("异常chatCompletion:{}", JSON.toJSONString(chatCompletion));
return null;
}
} catch (Exception e) {
log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(chatCompletion), e);
}
return chatCompletion;
}
/**
* 获取
* @param json
* @param content
* @return
*/
private static List<FieldMapping> getFiledMapping(JSONObject json, String content) {
List<FieldMapping> res = new ArrayList<>();
for (Map.Entry<String, Object> entry : json.entrySet()) {
if (entry.getValue() instanceof JSONObject) {
res.addAll(getFiledMapping((JSONObject) entry.getValue(), content));
} else {
// 文本限定关键字
if (entry.getKey().equals("搜索条件")) {
if (!(content.contains("针对") || content.contains("相关"))) {
continue;
}
}
FieldMapping fieldMapping = FieldMapping.createFromNameAndValue(entry.getKey(), entry.getValue(), res);
if (null != fieldMapping) {
res.add(fieldMapping);
}
}
}
return res;
}
/**
* 原发溯源大库es查询
* @param dto
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment