Commit 8c3c38a8 by 陈健智

ai搜索调整2

parent 0d50d0f7
...@@ -121,8 +121,13 @@ public class RedisKeyPrefix { ...@@ -121,8 +121,13 @@ public class RedisKeyPrefix {
public static final String CUSTOM_YUQING_ANALYZE_HIGH_WORD = "BRANDKBS:CUSTOM:YUQING:ANALYZE:HIGH:WORD:"; public static final String CUSTOM_YUQING_ANALYZE_HIGH_WORD = "BRANDKBS:CUSTOM:YUQING:ANALYZE:HIGH:WORD:";
/**
* 搜索相关缓存
*/
public static final String SEARCH_KEYWORD = "BRANDKBS:SEARCH:KEYWORD:"; public static final String SEARCH_KEYWORD = "BRANDKBS:SEARCH:KEYWORD:";
public static final String AI_SEARCH_QUESTION = "BRANDKBS:AI:SEARCH:QUESTION:";
public static String projectWarnHotTopKeyAll(String projectId, String type) { public static String projectWarnHotTopKeyAll(String projectId, String type) {
return RedisKeyPrefix.generateRedisKey(RedisKeyPrefix.PROJECT_WARN_HOT_TOP, projectId, Tools.concat(type, "*")); return RedisKeyPrefix.generateRedisKey(RedisKeyPrefix.PROJECT_WARN_HOT_TOP, projectId, Tools.concat(type, "*"));
} }
......
...@@ -358,13 +358,19 @@ public class AppSearchController extends BaseController { ...@@ -358,13 +358,19 @@ public class AppSearchController extends BaseController {
return ResponseResult.success(markDataService.getAISearchResult(question)); return ResponseResult.success(markDataService.getAISearchResult(question));
} }
@ApiOperation("搜索-AI参考提问") @ApiOperation("搜索-AI推荐提问")
@GetMapping("/ai/question") @GetMapping("/ai/question")
public ResponseResult getAIReferenceQuestion(@RequestParam(value = "question") String question, public ResponseResult getAIReferenceQuestion(@RequestParam(value = "question") String question,
@RequestParam(value = "size") int size) { @RequestParam(value = "size") int size) {
return ResponseResult.success(markDataService.getAIReferenceQuestion(question, size)); return ResponseResult.success(markDataService.getAIReferenceQuestion(question, size));
} }
@ApiOperation("搜索-AI参考提问")
@GetMapping("/ai/question-cache")
public ResponseResult getAIReferenceQuestion() {
return ResponseResult.success(markDataService.getAIReferenceQuestionCache(true));
}
@ApiOperation("搜索-搜索关键词历史记录") @ApiOperation("搜索-搜索关键词历史记录")
@GetMapping("/keyword/cache") @GetMapping("/keyword/cache")
public ResponseResult getSearchKeywordCache(@ApiParam(name = "searchType", public ResponseResult getSearchKeywordCache(@ApiParam(name = "searchType",
......
package com.zhiwei.brandkbs2.dao;
import com.zhiwei.brandkbs2.pojo.AISearchQuestionRecord;
import java.util.List;
/**
* @ClassName: AISearchQuestionRecordDao
* @Description AISearchQuestionRecordDao
* @author: cjz
* @date: 2024-08-12 17:07
*/
public interface AISearchQuestionRecordDao extends BaseMongoDao<AISearchQuestionRecord>{
List<String> findDistinctQuestion(String projectId);
}
package com.zhiwei.brandkbs2.dao.impl;
import com.zhiwei.brandkbs2.dao.AISearchQuestionRecordDao;
import com.zhiwei.brandkbs2.pojo.AISearchQuestionRecord;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Component;
import java.util.List;
/**
* @ClassName: AISearchQuestionRecordDao
* @Description AISearchQuestionRecordDao
* @author: cjz
* @date: 2024-08-12 17:07
*/
@Component("aiSearchQuestionRecordDao")
public class AISearchQuestionRecordDaoImpl extends BaseMongoDaoImpl<AISearchQuestionRecord> implements AISearchQuestionRecordDao {
private static final String COLLECTION_NAME = "brandkbs_ai_search_question_record";
public AISearchQuestionRecordDaoImpl() {
super(COLLECTION_NAME);
}
@Override
public List<String> findDistinctQuestion(String projectId) {
Query query = new Query().addCriteria(Criteria.where("projectId").is(projectId)).with(Sort.by(Sort.Order.desc("cTime"))).limit(10);
return mongoTemplate.findDistinct(query, "question", COLLECTION_NAME, String.class);
}
}
package com.zhiwei.brandkbs2.es; package com.zhiwei.brandkbs2.es;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.ImmutableMap; import com.hankcs.hanlp.HanLP;
import com.google.common.collect.Maps;
import com.zhiwei.brandkbs2.common.GenericAttribute; import com.zhiwei.brandkbs2.common.GenericAttribute;
import com.zhiwei.brandkbs2.config.Constant; import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.pojo.ChannelIndex; import com.zhiwei.brandkbs2.pojo.ChannelIndex;
...@@ -34,9 +33,6 @@ import org.elasticsearch.index.query.QueryBuilders; ...@@ -34,9 +33,6 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.FieldSortBuilder;
...@@ -342,17 +338,32 @@ public class EsClientDao { ...@@ -342,17 +338,32 @@ public class EsClientDao {
public List<JSONObject> findSearch(List<FieldMapping> fieldMappings) throws IOException { public List<JSONObject> findSearch(List<FieldMapping> fieldMappings) throws IOException {
List<JSONObject> list = new ArrayList<>(); List<JSONObject> list = new ArrayList<>();
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings); BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
List<JSONObject> searchHits = searchScroll(query, 10000, new String[]{"id", GenericAttribute.ES_TIME, GenericAttribute.ES_IND_TITLE}); String[] fetchSource = {"id", GenericAttribute.ES_TIME, GenericAttribute.ES_IND_TITLE};
ImmutableMap<String, JSONObject> idMap = Maps.uniqueIndex(searchHits, json -> json.getString("id")); SearchHit[] hits = searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits();
Map<String, String> idTitle = searchHits.stream().filter(json -> Objects.nonNull(json.getString("ind_title"))).collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString("ind_title"))); // Map<String, JSONObject> idBaseMap = Arrays.stream(hits).map(hit -> new JSONObject(hit.getSourceAsMap())).collect(Collectors.toMap(json -> json.getString("id"), o -> o));
// Map<String, String> idTitle = Arrays.stream(hits)
// .map(hit -> new JSONObject(hit.getSourceAsMap()))
// .filter(json -> Objects.nonNull(json.getString(GenericAttribute.ES_IND_TITLE)))
// .collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString(GenericAttribute.ES_IND_TITLE)));
Pair<Map<String, JSONObject>, Map<String, String>> searchProcess = findSearchResultProcess(hits);
Map<String, JSONObject> idBaseMap = searchProcess.getLeft();
Map<String, String> idTitle = searchProcess.getRight();
// 搜索条件未找到结果,将搜索关键词分词处理,再次查询
if (idTitle.isEmpty()){
SearchHit[] searchHitHanLP = findSearchHanLP(fieldMappings, fetchSource);
Pair<Map<String, JSONObject>, Map<String, String>> searchProcessHanLP = findSearchResultProcess(searchHitHanLP);
idBaseMap = searchProcessHanLP.getLeft();
idTitle = searchProcessHanLP.getRight();
}
if (idTitle.isEmpty()){ if (idTitle.isEmpty()){
return list; return list;
} }
// 按标题聚合,取聚合结果集前9,并取结果集中最新的文章的id // 按标题聚合,取聚合结果集前9,并取结果集中最新的文章的id
Map<String, JSONObject> finalIdBaseMap = idBaseMap;
List<String> idList = TextUtil.getKResult(idTitle).stream() List<String> idList = TextUtil.getKResult(idTitle).stream()
.sorted(Comparator.comparingInt(List<String>::size).reversed()) .sorted(Comparator.comparing(List<String>::size, Comparator.reverseOrder()))
.limit(9) .limit(9)
.map(ids -> ids.stream().map(idMap::get).max(Comparator.comparingInt(json -> (int) json.getLongValue(GenericAttribute.ES_TIME))).orElse(null)) .map(ids -> ids.stream().map(finalIdBaseMap::get).max(Comparator.comparingLong(json -> json.getLongValue(GenericAttribute.ES_TIME))).orElse(null))
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(json -> json.getString("id")) .map(json -> json.getString("id"))
.collect(Collectors.toList()); .collect(Collectors.toList());
...@@ -363,6 +374,26 @@ public class EsClientDao { ...@@ -363,6 +374,26 @@ public class EsClientDao {
return list; return list;
} }
private Pair<Map<String, JSONObject>, Map<String, String>> findSearchResultProcess(SearchHit[] hits){
Map<String, JSONObject> idBaseMap = Arrays.stream(hits).map(hit -> new JSONObject(hit.getSourceAsMap())).collect(Collectors.toMap(json -> json.getString("id"), o -> o));
Map<String, String> idTitle = Arrays.stream(hits)
.map(hit -> new JSONObject(hit.getSourceAsMap()))
.filter(json -> Objects.nonNull(json.getString(GenericAttribute.ES_IND_TITLE)))
.collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString(GenericAttribute.ES_IND_TITLE)));
return Pair.of(idBaseMap, idTitle);
}
private SearchHit[] findSearchHanLP(List<FieldMapping> fieldMappings, String[] fetchSource) throws IOException {
fieldMappings.stream().filter(fieldMapping -> Objects.equals(FieldMapping.FieldMap.IND_FULL_TEXT, fieldMapping.getFieldMap()))
.findFirst().ifPresent(fieldMapping -> {
String value = String.valueOf(fieldMapping.getValue());
String newValue = HanLP.segment(Tools.filterSpecialCharacter(value)).stream().map(s -> s.word).distinct().collect(Collectors.joining("|"));
fieldMapping.setValue(newValue);
});
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
return searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits();
}
private JSONObject getTopTitleLatest(BoolQueryBuilder query, String id) throws IOException { private JSONObject getTopTitleLatest(BoolQueryBuilder query, String id) throws IOException {
query.must(QueryBuilders.termQuery("id", id)); query.must(QueryBuilders.termQuery("id", id));
FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC); FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC);
...@@ -370,13 +401,6 @@ public class EsClientDao { ...@@ -370,13 +401,6 @@ public class EsClientDao {
return new JSONObject(searchHits.getAt(0).getSourceAsMap()); return new JSONObject(searchHits.getAt(0).getSourceAsMap());
} }
// private JSONObject getTopTitleLatest(BoolQueryBuilder query, String title) throws IOException {
// query.must(QueryBuilders.termQuery("agg_title.keyword", title));
// FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC);
// SearchHits searchHits = searchHits(getIndexes(), query, null, null, sort, 0, 1, null);
// return new JSONObject(searchHits.getAt(0).getSourceAsMap());
// }
private BoolQueryBuilder getBoolQueryBuilder(List<FieldMapping> fieldMappings) { private BoolQueryBuilder getBoolQueryBuilder(List<FieldMapping> fieldMappings) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
Map<String, List<FieldMapping>> groupMap = fieldMappings.stream().collect(Collectors.groupingBy(mapping -> mapping.getFieldMap().getFatherName())); Map<String, List<FieldMapping>> groupMap = fieldMappings.stream().collect(Collectors.groupingBy(mapping -> mapping.getFieldMap().getFatherName()));
......
package com.zhiwei.brandkbs2.pojo;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
/**
* @ClassName: AISearchQuestionRecord
* @Description ai搜索问题记录
* @author: cjz
* @date: 2024-8-12 14:58
*/
@Getter
@Setter
@AllArgsConstructor
public class AISearchQuestionRecord extends AbstractBaseMongo {
/**
* 问题
*/
private String question;
/**
* 项目id
*/
private String projectId;
/**
* 创建时间
*/
private Long cTime;
}
...@@ -2,6 +2,7 @@ package com.zhiwei.brandkbs2.pojo.ai; ...@@ -2,6 +2,7 @@ package com.zhiwei.brandkbs2.pojo.ai;
import com.zhiwei.brandkbs2.common.GlobalPojo; import com.zhiwei.brandkbs2.common.GlobalPojo;
import com.zhiwei.brandkbs2.config.Constant; import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.es.EsQueryTools;
import com.zhiwei.brandkbs2.pojo.AbstractProject; import com.zhiwei.brandkbs2.pojo.AbstractProject;
import com.zhiwei.brandkbs2.pojo.Contend; import com.zhiwei.brandkbs2.pojo.Contend;
import com.zhiwei.brandkbs2.pojo.Project; import com.zhiwei.brandkbs2.pojo.Project;
...@@ -62,7 +63,8 @@ public class FieldMapping { ...@@ -62,7 +63,8 @@ public class FieldMapping {
} }
return fieldMapping.buildQuery(this); return fieldMapping.buildQuery(this);
case IND_FULL_TEXT: case IND_FULL_TEXT:
return QueryBuilders.matchPhraseQuery(fieldMap.databaseName, value); return EsQueryTools.assembleNormalKeywordQuery(String.valueOf(value), new String[]{fieldMap.databaseName});
// return QueryBuilders.matchPhraseQuery(fieldMap.databaseName, value);
case SOURCE: case SOURCE:
case MTAG: case MTAG:
return QueryBuilders.termQuery(fieldMap.databaseName, value); return QueryBuilders.termQuery(fieldMap.databaseName, value);
......
...@@ -2,10 +2,7 @@ package com.zhiwei.brandkbs2.service; ...@@ -2,10 +2,7 @@ package com.zhiwei.brandkbs2.service;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.zhiwei.brandkbs2.model.ResponseResult; import com.zhiwei.brandkbs2.model.ResponseResult;
import com.zhiwei.brandkbs2.pojo.BaseMap; import com.zhiwei.brandkbs2.pojo.*;
import com.zhiwei.brandkbs2.pojo.DailyReport;
import com.zhiwei.brandkbs2.pojo.Event;
import com.zhiwei.brandkbs2.pojo.MarkFlowEntity;
import com.zhiwei.brandkbs2.pojo.dto.*; import com.zhiwei.brandkbs2.pojo.dto.*;
import com.zhiwei.brandkbs2.pojo.vo.LineVO; import com.zhiwei.brandkbs2.pojo.vo.LineVO;
import com.zhiwei.brandkbs2.pojo.vo.PageVO; import com.zhiwei.brandkbs2.pojo.vo.PageVO;
...@@ -838,14 +835,21 @@ public interface MarkDataService { ...@@ -838,14 +835,21 @@ public interface MarkDataService {
List<String> expandOriginRange(MarkSearchDTO dto); List<String> expandOriginRange(MarkSearchDTO dto);
/** /**
* ai搜索-ai参考提问 * AI搜索-推荐提问
* @param question * @param question
* @return * @return
*/ */
List<String> getAIReferenceQuestion(String question, int size); List<String> getAIReferenceQuestion(String question, int size);
/** /**
* ai搜索-搜索结果 * AI搜索-参考提问
* @param cache
* @return
*/
List<String> getAIReferenceQuestionCache(boolean cache);
/**
* AI搜索-搜索结果
* @param question * @param question
* @return * @return
*/ */
......
...@@ -72,4 +72,9 @@ public interface TaskService{ ...@@ -72,4 +72,9 @@ public interface TaskService{
* 定时拉取并进行渠道库更新任务 * 定时拉取并进行渠道库更新任务
*/ */
void refreshChannelRecord(); void refreshChannelRecord();
/**
* 生成ai搜索参考提问缓存
*/
void cacheAIQuestion();
} }
...@@ -136,10 +136,13 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -136,10 +136,13 @@ public class MarkDataServiceImpl implements MarkDataService {
"1 按照指定输出格式输出。\n" + "1 按照指定输出格式输出。\n" +
"2 严格按照规则进行提炼。\n" + "2 严格按照规则进行提炼。\n" +
"###"; "###";
private static final String RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的内容,给出自己提炼的5个的详细分析和见解。并在每点分析后用数字表示注明1-{0}的参考文章,分析结果和参考文章之间严格用”|“作为分隔符进行分隔,并且多个参考文章之间也严格用”|“作为分隔符进行分隔,若没有对应的参考文章则无需返回,示例:分析结果。|1|2|3" + private static final String RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的内容,给出自己的详细分析和见解。并在每点分析后用数字表示注明1-{0}的参考文章,分析结果和参考文章之间严格用”|“作为分隔符进行分隔,并且多个参考文章之间也严格用”|“作为分隔符进行分隔,若没有对应的参考文章则无需返回,示例:分析结果。|1|2|3" +
"请分析:"; "请分析:";
private static final String REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,提出自己{0}个关于的{1}参考问题,每个问题给到准确的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" + private static final String REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,提出自己{0}个关于的{1}参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" +
"请提出:";
private static final String CACHE_REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,请参考给到的问题,提出自己5个类似的的参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" +
"请提出:"; "请提出:";
@Value("${istarshine.addIStarShineKSData.url}") @Value("${istarshine.addIStarShineKSData.url}")
...@@ -229,6 +232,9 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -229,6 +232,9 @@ public class MarkDataServiceImpl implements MarkDataService {
@Resource(name = "dailyReportDao") @Resource(name = "dailyReportDao")
DailyReportDao dailyReportDao; DailyReportDao dailyReportDao;
@Resource(name = "aiSearchQuestionRecordDao")
AISearchQuestionRecordDao aiSearchQuestionRecordDao;
@Resource(name = "toolsetServiceImpl") @Resource(name = "toolsetServiceImpl")
private ToolsetService toolsetService; private ToolsetService toolsetService;
...@@ -3959,9 +3965,8 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -3959,9 +3965,8 @@ public class MarkDataServiceImpl implements MarkDataService {
// 选用的模型名称 // 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName(); String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
String projectName = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId()).getProjectName(); String projectName = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId()).getProjectName();
ChatCompletionResult result = standardRequest(question, modelName, MessageFormat.format(REFERENCE_QUESTION_PROMPT, size, projectName)); String resultContent = standardRequest(question, modelName, MessageFormat.format(REFERENCE_QUESTION_PROMPT, size, projectName));
Object resultContent = result.getChoices().get(0).getMessage().getContent(); String[] splits = resultContent.split("\\|");
String[] splits = String.valueOf(resultContent).split("\\|");
return new ArrayList<>(Arrays.asList(splits)).stream().filter(StringUtils::isNoneBlank).map(String::trim).collect(Collectors.toList()); return new ArrayList<>(Arrays.asList(splits)).stream().filter(StringUtils::isNoneBlank).map(String::trim).collect(Collectors.toList());
}catch (Exception e){ }catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "获取ai参考提问异常-", e); ExceptionCast.cast(CommonCodeEnum.FAIL, "获取ai参考提问异常-", e);
...@@ -3970,17 +3975,60 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -3970,17 +3975,60 @@ public class MarkDataServiceImpl implements MarkDataService {
} }
@Override @Override
public List<String> getAIReferenceQuestionCache(boolean cache) {
String projectId = UserThreadLocal.getProjectId();
String key = RedisUtil.getAISearchQuestionCacheKey(projectId);
String resultStr;
// 返回缓存
if (cache && StringUtils.isNotEmpty(resultStr = redisUtil.get(key))) {
return JSONObject.parseArray(resultStr).toJavaList(String.class);
}
List<String> questionList = aiSearchQuestionRecordDao.findDistinctQuestion(projectId);
String projectName = GlobalPojo.PROJECT_MAP.get(projectId).getProjectName();
if (CollectionUtils.isEmpty(questionList)){
return getAIReferenceQuestionTemplate(projectName);
}
// 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
StringBuilder sb = new StringBuilder();
int count = 1;
for (String question : questionList) {
sb.append(count++).append("、").append(question).append(";");
}
String resultContent = standardRequest(sb.toString(), modelName, MessageFormat.format(CACHE_REFERENCE_QUESTION_PROMPT, projectName));
if (Objects.isNull(resultContent)){
return getAIReferenceQuestionTemplate(projectName);
}
String[] splits = resultContent.split("\\|");
List<String> result = new ArrayList<>(Arrays.asList(splits)).stream().filter(StringUtils::isNoneBlank).map(String::trim).collect(Collectors.toList());
redisUtil.setExpire(key, JSONObject.toJSONString(result));
return result;
}
private List<String> getAIReferenceQuestionTemplate(String projectName){
String question1 = MessageFormat.format("今年 7 月{0}项目{1}相关的正面数据", projectName, projectName);
String question2 = MessageFormat.format("近一个周{0}项目有发生哪些重大舆情", projectName);
String question3 = MessageFormat.format("{0}项目的竞品舆情有哪些", projectName);
String question4 = MessageFormat.format("{0}项目最近发生了哪些事件", projectName);
String question5 = MessageFormat.format("{0}项目近期负面报道有哪些,主要是那几家集中涉及哪些事件", projectName);
return Arrays.asList(question1, question2, question3, question4, question5);
}
@Override
public JSONObject getAISearchResult(String question) { public JSONObject getAISearchResult(String question) {
JSONObject res = new JSONObject(); JSONObject res = new JSONObject();
try { try {
// 选用的模型名称 // 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName(); String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
// 根据AI生成条件 // 根据AI生成条件
ChatCompletionResult result = standardRequest(question, modelName); String result = standardRequest(question, modelName, QUESTION_PROMPT);
JSONObject json = JSON.parseObject((String) result.getChoices().get(0).getMessage().getContent()); JSONObject json = JSON.parseObject(result);
// 数据条件 // 数据条件
List<FieldMapping> filedMapping = getFiledMapping(json, question); List<FieldMapping> filedMapping = getFiledMapping(json, question);
List<JSONObject> list = esClientDao.findSearch(filedMapping); List<JSONObject> list = esClientDao.findSearch(filedMapping);
if (CollectionUtils.isEmpty(list)){
return null;
}
// AI回答 // AI回答
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
List<BaseMap> articles = list.stream().map(Tools::getBaseFromEsMap).collect(Collectors.toList()); List<BaseMap> articles = list.stream().map(Tools::getBaseFromEsMap).collect(Collectors.toList());
...@@ -3990,9 +4038,8 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -3990,9 +4038,8 @@ public class MarkDataServiceImpl implements MarkDataService {
sb.append(count++).append("、").append(text).append(";"); sb.append(count++).append("、").append(text).append(";");
} }
String sbContent = sb.toString(); String sbContent = sb.toString();
result = standardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question); result = streamStandardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question);
Object resultContent = result.getChoices().get(0).getMessage().getContent(); String[] splits = result.split("\\r?\\n");
String[] splits = String.valueOf(resultContent).split("\\r?\\n");
List<JSONObject> answers = new ArrayList<>(); List<JSONObject> answers = new ArrayList<>();
for (int i = 0; i < splits.length; i++) { for (int i = 0; i < splits.length; i++) {
JSONObject answer = new JSONObject(); JSONObject answer = new JSONObject();
...@@ -4012,17 +4059,37 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4012,17 +4059,37 @@ public class MarkDataServiceImpl implements MarkDataService {
} }
res.put("answers", answers); res.put("answers", answers);
res.put("articles", articles); res.put("articles", articles);
// 记录返回成功的提问
aiSearchQuestionRecordDao.insertOne(new AISearchQuestionRecord(question, UserThreadLocal.getProjectId(), System.currentTimeMillis()));
return res;
}catch (Exception e){ }catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e); ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e);
} }
return res; return null;
} }
private ChatCompletionResult standardRequest(String content, String modelName) { private String streamStandardRequest(String content, String modelName, String prompt) {
return standardRequest(content, modelName, QUESTION_PROMPT); AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName);
StringBuilder result = new StringBuilder();
try {
final List<ChatMessage> streamMessages = new ArrayList<>();
final ChatMessage streamSystemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(prompt).build();
final ChatMessage streamUserMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
streamMessages.add(streamSystemMessage);
streamMessages.add(streamUserMessage);
ChatCompletionRequest streamChatCompletionRequest = ChatCompletionRequest.builder().model(model.getModelId()).messages(streamMessages).build();
DoubaoAIAccountFactor.arkService.streamChatCompletion(streamChatCompletionRequest).doOnError(Throwable::printStackTrace).blockingForEach(choice -> {
if (choice.getChoices().size() > 0) {
result.append(choice.getChoices().get(0).getMessage().getContent());
}
});
} catch (Exception e) {
log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(result), e);
}
return result.toString();
} }
private ChatCompletionResult standardRequest(String content, String modelName, String prompt) { private String standardRequest(String content, String modelName, String prompt) {
AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName); AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName);
ChatCompletionResult chatCompletion = null; ChatCompletionResult chatCompletion = null;
try { try {
...@@ -4037,10 +4104,11 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4037,10 +4104,11 @@ public class MarkDataServiceImpl implements MarkDataService {
log.error("异常chatCompletion:{}", JSON.toJSONString(chatCompletion)); log.error("异常chatCompletion:{}", JSON.toJSONString(chatCompletion));
return null; return null;
} }
return String.valueOf(chatCompletion.getChoices().get(0).getMessage().getContent());
} catch (Exception e) { } catch (Exception e) {
log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(chatCompletion), e); log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(chatCompletion), e);
} }
return chatCompletion; return null;
} }
/** /**
......
...@@ -89,6 +89,9 @@ public class TaskServiceImpl implements TaskService { ...@@ -89,6 +89,9 @@ public class TaskServiceImpl implements TaskService {
@Resource(name = "channelRecordRefreshTaskDao") @Resource(name = "channelRecordRefreshTaskDao")
private ChannelRecordRefreshTaskDao channelRecordRefreshTaskDao; private ChannelRecordRefreshTaskDao channelRecordRefreshTaskDao;
@Resource(name = "aiSearchQuestionRecordDao")
private AISearchQuestionRecordDao aiSearchQuestionRecordDao;
@Resource(name = "brandkbsTaskServiceImpl") @Resource(name = "brandkbsTaskServiceImpl")
BrandkbsTaskService brandkbsTaskService; BrandkbsTaskService brandkbsTaskService;
...@@ -486,6 +489,17 @@ public class TaskServiceImpl implements TaskService { ...@@ -486,6 +489,17 @@ public class TaskServiceImpl implements TaskService {
log.info("更新渠道库记录完成-taskId:{}", task.getId()); log.info("更新渠道库记录完成-taskId:{}", task.getId());
} }
@Override
public void cacheAIQuestion() {
AtomicInteger total = new AtomicInteger();
CompletableFuture.allOf(GlobalPojo.PROJECT_MAP.values().stream().map(project -> CompletableFuture.supplyAsync(() -> {
UserThreadLocal.set(new UserInfo().setProjectId(project.getId()));
markDataService.getAIReferenceQuestionCache(false);
log.info("项目:{}-{}-AI参考问题缓存完成:{}个", project.getProjectName(), project.getId(), total.incrementAndGet());
return null;
}, cacheServiceExecutor)).toArray(CompletableFuture[]::new)).join();
}
private void updateRefreshTask(String id, String status){ private void updateRefreshTask(String id, String status){
Update update = new Update(); Update update = new Update();
update.set("status", status); update.set("status", status);
......
...@@ -47,6 +47,7 @@ public class ControlCenter { ...@@ -47,6 +47,7 @@ public class ControlCenter {
// taskService.customEventCache(); // taskService.customEventCache();
taskService.eventAggTitleCache(); taskService.eventAggTitleCache();
taskService.yuqingAnalyzeHighWordCache(); taskService.yuqingAnalyzeHighWordCache();
taskService.cacheAIQuestion();
} catch (Exception e) { } catch (Exception e) {
log.error("定时按天缓存数据-出错", e); log.error("定时按天缓存数据-出错", e);
} finally { } finally {
......
...@@ -130,6 +130,10 @@ public class RedisUtil { ...@@ -130,6 +130,10 @@ public class RedisUtil {
return RedisKeyPrefix.SEARCH_KEYWORD + Tools.concat(projectId, userId, searchType); return RedisKeyPrefix.SEARCH_KEYWORD + Tools.concat(projectId, userId, searchType);
} }
public static String getAISearchQuestionCacheKey(String projectId){
return RedisKeyPrefix.AI_SEARCH_QUESTION + projectId;
}
public void setExpire(String key, String value, long timeout, TimeUnit unit) { public void setExpire(String key, String value, long timeout, TimeUnit unit) {
stringRedisTemplate.opsForValue().set(key, value, timeout, unit); stringRedisTemplate.opsForValue().set(key, value, timeout, unit);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment