Commit 1190a500 by 陈健智

AI搜索增加辅助信息、联网搜索

parent 705b54ec
...@@ -396,9 +396,9 @@ public class AppDownloadController extends BaseController { ...@@ -396,9 +396,9 @@ public class AppDownloadController extends BaseController {
@PostMapping(value = "/contend/mark") @PostMapping(value = "/contend/mark")
@DownloadTask(taskName = "竞品库竞品舆情下载", description = "竞品库竞品舆情") @DownloadTask(taskName = "竞品库竞品舆情下载", description = "竞品库竞品舆情")
public ResponseResult exportContendMarkList(@RequestBody MarkSearchDTO markSearchDTO) { public ResponseResult exportContendMarkList(@RequestBody MarkSearchDTO markSearchDTO) {
if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){ // if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
Pair<String, List<ExportAppYuqingDTO>> stringListPair = markDataService.downloadContendMarkList(markSearchDTO); Pair<String, List<ExportAppYuqingDTO>> stringListPair = markDataService.downloadContendMarkList(markSearchDTO);
// excel写入至指定路径 // excel写入至指定路径
String projectName = projectService.getProjectById(UserThreadLocal.getProjectId()).getProjectName(); String projectName = projectService.getProjectById(UserThreadLocal.getProjectId()).getProjectName();
...@@ -412,9 +412,9 @@ public class AppDownloadController extends BaseController { ...@@ -412,9 +412,9 @@ public class AppDownloadController extends BaseController {
@LogRecord(description = "全网搜-舆情导出", values = {"startTime", "endTime", "fans", "filterType", "filterWords", "search", "keyword", "platforms", "sensitiveChannels", "sourceKeyword"}, entity = true, arguments = true) @LogRecord(description = "全网搜-舆情导出", values = {"startTime", "endTime", "fans", "filterType", "filterWords", "search", "keyword", "platforms", "sensitiveChannels", "sourceKeyword"}, entity = true, arguments = true)
@DownloadTask(taskName = "全网搜舆情下载", description = "全网搜舆情") @DownloadTask(taskName = "全网搜舆情下载", description = "全网搜舆情")
public ResponseResult exportSearchWhole(@RequestBody SearchFilterDTO dto) { public ResponseResult exportSearchWhole(@RequestBody SearchFilterDTO dto) {
if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){ // if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
// 针对商业数据库做限制 // 针对商业数据库做限制
if (dto.isExternalDataSource()) { if (dto.isExternalDataSource()) {
long time = DateUtils.addDays(Tools.truncDate(new Date(), Constant.DAY_PATTERN), -89).getTime(); long time = DateUtils.addDays(Tools.truncDate(new Date(), Constant.DAY_PATTERN), -89).getTime();
......
...@@ -161,9 +161,9 @@ public class AppSearchController extends BaseController { ...@@ -161,9 +161,9 @@ public class AppSearchController extends BaseController {
@LogRecord(values = {"fans", "sensitiveChannels:father,son", "keyword", "search"}, description = "全网搜", arguments = true, entity = true) @LogRecord(values = {"fans", "sensitiveChannels:father,son", "keyword", "search"}, description = "全网搜", arguments = true, entity = true)
@PostMapping("/searchWhole") @PostMapping("/searchWhole")
public ResponseResult searchWholeNetwork(@RequestBody SearchFilterDTO dto) { public ResponseResult searchWholeNetwork(@RequestBody SearchFilterDTO dto) {
if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){ // if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
cacheSearchKeyword(dto.getKeyword(), "whole"); cacheSearchKeyword(dto.getKeyword(), "whole");
// 针对商业数据库做限制 // 针对商业数据库做限制
if (dto.isExternalDataSource()) { if (dto.isExternalDataSource()) {
...@@ -200,9 +200,9 @@ public class AppSearchController extends BaseController { ...@@ -200,9 +200,9 @@ public class AppSearchController extends BaseController {
@PostMapping("/exportSearchWhole") @PostMapping("/exportSearchWhole")
@LogRecord(description = "全网搜-舆情导出", values = {"startTime", "endTime", "fans", "filterType", "filterWords", "search", "keyword", "platforms", "sensitiveChannels", "sourceKeyword"}, entity = true, arguments = true) @LogRecord(description = "全网搜-舆情导出", values = {"startTime", "endTime", "fans", "filterType", "filterWords", "search", "keyword", "platforms", "sensitiveChannels", "sourceKeyword"}, entity = true, arguments = true)
public ResponseResult exportSearchWhole(@RequestBody SearchFilterDTO dto) { public ResponseResult exportSearchWhole(@RequestBody SearchFilterDTO dto) {
if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){ // if (StringUtils.isNotEmpty(dto.getKeyword()) && Tools.checkUniteString(dto.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
// 针对商业数据库做限制 // 针对商业数据库做限制
if (dto.isExternalDataSource()) { if (dto.isExternalDataSource()) {
long time = DateUtils.addDays(Tools.truncDate(new Date(), Constant.DAY_PATTERN), -89).getTime(); long time = DateUtils.addDays(Tools.truncDate(new Date(), Constant.DAY_PATTERN), -89).getTime();
...@@ -231,9 +231,9 @@ public class AppSearchController extends BaseController { ...@@ -231,9 +231,9 @@ public class AppSearchController extends BaseController {
@LogRecord(values = {"searchType", "keyword"}, description = "查舆情", arguments = true, entity = true) @LogRecord(values = {"searchType", "keyword"}, description = "查舆情", arguments = true, entity = true)
@PostMapping("/mark/list") @PostMapping("/mark/list")
public ResponseResult getYuqingMarkList(@RequestBody MarkSearchDTO markSearchDTO) { public ResponseResult getYuqingMarkList(@RequestBody MarkSearchDTO markSearchDTO) {
if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){ // if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
cacheSearchKeyword(markSearchDTO.getKeyword(), "yuqing"); cacheSearchKeyword(markSearchDTO.getKeyword(), "yuqing");
PageVO<MarkFlowEntity> yuqingMarkList = markDataService.getYuqingMarkList(markSearchDTO); PageVO<MarkFlowEntity> yuqingMarkList = markDataService.getYuqingMarkList(markSearchDTO);
// 仅第一页增加平台进量(声量)统计 // 仅第一页增加平台进量(声量)统计
...@@ -339,9 +339,9 @@ public class AppSearchController extends BaseController { ...@@ -339,9 +339,9 @@ public class AppSearchController extends BaseController {
@LogRecord(values = "keyword", description = "查竞品",arguments = true, entity = true) @LogRecord(values = "keyword", description = "查竞品",arguments = true, entity = true)
@PostMapping("/contend/list") @PostMapping("/contend/list")
public ResponseResult getContendSearchList(@RequestBody MarkSearchDTO markSearchDTO) { public ResponseResult getContendSearchList(@RequestBody MarkSearchDTO markSearchDTO) {
if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){ // if (StringUtils.isNotEmpty(markSearchDTO.getKeyword()) && Tools.checkUniteString(markSearchDTO.getKeyword())){
return ResponseResult.failure("不支持特殊符号字段查询"); // return ResponseResult.failure("不支持特殊符号字段查询");
} // }
cacheSearchKeyword(markSearchDTO.getKeyword(), "contend"); cacheSearchKeyword(markSearchDTO.getKeyword(), "contend");
return ResponseResult.success(markDataService.getContendSearchList(markSearchDTO)); return ResponseResult.success(markDataService.getContendSearchList(markSearchDTO));
} }
...@@ -354,8 +354,17 @@ public class AppSearchController extends BaseController { ...@@ -354,8 +354,17 @@ public class AppSearchController extends BaseController {
@ApiOperation("搜索-AI搜索") @ApiOperation("搜索-AI搜索")
@GetMapping("/ai/answer") @GetMapping("/ai/answer")
public ResponseResult getAISearchResult(@RequestParam(value = "question") String question) { public ResponseResult getAISearchResult(@RequestParam(value = "question") String question,
return ResponseResult.success(markDataService.getAISearchResult(question)); @RequestParam(value = "keyword", required = false) String keyword,
@RequestParam(value = "startTime", required = false) Long startTime,
@RequestParam(value = "endTime", required = false) Long endTime) {
return ResponseResult.success(markDataService.getAISearchResult(question, keyword, startTime, endTime));
}
@ApiOperation("搜索-AI搜索-联网搜索")
@GetMapping("/ai/answer/online")
public ResponseResult getAIOnlineSearchResult(@RequestParam(value = "question") String question) {
return ResponseResult.success(markDataService.getAIOnlineSearchResult(question));
} }
@ApiOperation("搜索-AI推荐提问") @ApiOperation("搜索-AI推荐提问")
......
...@@ -2,9 +2,13 @@ package com.zhiwei.brandkbs2.es; ...@@ -2,9 +2,13 @@ package com.zhiwei.brandkbs2.es;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.hankcs.hanlp.HanLP; import com.hankcs.hanlp.HanLP;
import com.zhiwei.brandkbs2.auth.UserThreadLocal;
import com.zhiwei.brandkbs2.common.GenericAttribute; import com.zhiwei.brandkbs2.common.GenericAttribute;
import com.zhiwei.brandkbs2.common.GlobalPojo;
import com.zhiwei.brandkbs2.config.Constant; import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.pojo.ChannelIndex; import com.zhiwei.brandkbs2.pojo.ChannelIndex;
import com.zhiwei.brandkbs2.pojo.Contend;
import com.zhiwei.brandkbs2.pojo.Project;
import com.zhiwei.brandkbs2.pojo.ai.FieldMapping; import com.zhiwei.brandkbs2.pojo.ai.FieldMapping;
import com.zhiwei.brandkbs2.util.TextUtil; import com.zhiwei.brandkbs2.util.TextUtil;
import com.zhiwei.brandkbs2.util.Tools; import com.zhiwei.brandkbs2.util.Tools;
...@@ -335,6 +339,63 @@ public class EsClientDao { ...@@ -335,6 +339,63 @@ public class EsClientDao {
return Pair.of(new Long[]{startTime, endTime}, res); return Pair.of(new Long[]{startTime, endTime}, res);
} }
public List<JSONObject> findSearch(String question, String keyword, Long startTime, Long endTime) throws IOException {
List<JSONObject> list = new ArrayList<>();
String projectId = UserThreadLocal.getProjectId();
BoolQueryBuilder query = QueryBuilders.boolQuery();
// 默认一周
if (Objects.isNull(startTime) || Objects.isNull(endTime)){
endTime = System.currentTimeMillis();
startTime = System.currentTimeMillis() - Constant.ONE_WEEK;
}
// time
query.must(QueryBuilders.rangeQuery(GenericAttribute.ES_TIME).gte(startTime).lt(endTime));
// contend
Project project = GlobalPojo.PROJECT_MAP.get(projectId);
if (CollectionUtils.isNotEmpty(project.getContendList())) {
List<Contend> contendList = new ArrayList<>();
for (Contend contend : project.getContendList()) {
if (question.contains(contend.getBrandName())) {
contendList.add(contend);
}
}
if (CollectionUtils.isNotEmpty(contendList)) {
BoolQueryBuilder contendQuery = QueryBuilders.boolQuery();
for (Contend contend : contendList) {
contendQuery.should(EsQueryTools.assembleCacheMapsQuery(projectId, contend.getId()));
}
contendQuery.minimumShouldMatch(1);
query.must(contendQuery);
}else {
query.must(EsQueryTools.assembleCacheMapsQuery(projectId, Constant.PRIMARY_CONTEND_ID));
}
}
// keyword
query.must(EsQueryTools.assembleNormalKeywordQuery(keyword, new String[]{GenericAttribute.ES_IND_FULL_TEXT}));
String[] fetchSource = {"id", GenericAttribute.ES_TIME, GenericAttribute.ES_IND_TITLE};
// hit
SearchHit[] hits = searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits();
Pair<Map<String, JSONObject>, Map<String, String>> searchProcess = findSearchResultProcess(hits);
Map<String, JSONObject> idBaseMap = searchProcess.getLeft();
Map<String, String> idTitle = searchProcess.getRight();
if (idTitle.isEmpty()){
return list;
}
// 按标题聚合,取聚合结果集前9,并取结果集中最新的文章的id
List<String> idList = TextUtil.getKResult(idTitle).stream()
.sorted(Comparator.comparing(List<String>::size, Comparator.reverseOrder()))
.limit(9)
.map(ids -> ids.stream().map(idBaseMap::get).max(Comparator.comparingLong(json -> json.getLongValue(GenericAttribute.ES_TIME))).orElse(null))
.filter(Objects::nonNull)
.map(json -> json.getString("id"))
.collect(Collectors.toList());
// 反查原数据
for (String id : idList) {
list.add(getTopTitleLatest(id));
}
return list;
}
public List<JSONObject> findSearch(List<FieldMapping> fieldMappings) throws IOException { public List<JSONObject> findSearch(List<FieldMapping> fieldMappings) throws IOException {
List<JSONObject> list = new ArrayList<>(); List<JSONObject> list = new ArrayList<>();
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings); BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
...@@ -364,7 +425,7 @@ public class EsClientDao { ...@@ -364,7 +425,7 @@ public class EsClientDao {
.collect(Collectors.toList()); .collect(Collectors.toList());
// 反查原数据 // 反查原数据
for (String id : idList) { for (String id : idList) {
list.add(getTopTitleLatest(getBoolQueryBuilder(fieldMappings), id)); list.add(getTopTitleLatest(id));
} }
return list; return list;
} }
...@@ -389,11 +450,10 @@ public class EsClientDao { ...@@ -389,11 +450,10 @@ public class EsClientDao {
return searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits(); return searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits();
} }
private JSONObject getTopTitleLatest(BoolQueryBuilder query, String id) throws IOException { private JSONObject getTopTitleLatest(String id) throws IOException {
BoolQueryBuilder query = QueryBuilders.boolQuery();
query.must(QueryBuilders.termQuery("id", id)); query.must(QueryBuilders.termQuery("id", id));
FieldSortBuilder sort = new FieldSortBuilder(GenericAttribute.ES_TIME).order(SortOrder.DESC); return searchById(id);
SearchHits searchHits = searchHits(getIndexes(), query, null, null, sort, 0, 1, null);
return new JSONObject(searchHits.getAt(0).getSourceAsMap());
} }
private BoolQueryBuilder getBoolQueryBuilder(List<FieldMapping> fieldMappings) { private BoolQueryBuilder getBoolQueryBuilder(List<FieldMapping> fieldMappings) {
......
...@@ -62,4 +62,16 @@ public class UserLogRecord extends AbstractBaseMongo{ ...@@ -62,4 +62,16 @@ public class UserLogRecord extends AbstractBaseMongo{
record.setCost(cost); record.setCost(cost);
return record; return record;
} }
public static UserLogRecord defaultUserLogRecord(String description){
UserLogRecord record = new UserLogRecord();
record.setProjectId(UserThreadLocal.getProjectId());
record.setUserId(UserThreadLocal.getUserId());
record.setNickname(UserThreadLocal.getNickname());
record.setDescription(description);
record.setRoleId(UserThreadLocal.getRoleId());
record.setCTime(System.currentTimeMillis());
record.setUpdateTime(System.currentTimeMillis());
return record;
}
} }
...@@ -853,5 +853,12 @@ public interface MarkDataService { ...@@ -853,5 +853,12 @@ public interface MarkDataService {
* @param question * @param question
* @return * @return
*/ */
JSONObject getAISearchResult(String question); JSONObject getAISearchResult(String question, String keyword, Long startTime, Long endTime);
/**
* AI搜索-联网搜索
* @param question
* @return
*/
JSONObject getAIOnlineSearchResult(String question);
} }
...@@ -112,7 +112,7 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -112,7 +112,7 @@ public class MarkDataServiceImpl implements MarkDataService {
private static final String QUESTION_PROMPT = "###\n" + private static final String QUESTION_PROMPT = "###\n" +
"假如你是专业的问题提炼人员,你将根据用户提供的内容,来提炼问题要素和条件。根据以下规则一步步执行:\n" + "假如你是专业的问题提炼人员,你将根据用户提供的内容,来提炼问题要素和条件。根据以下规则一步步执行:\n" +
"1.提及到年、月、日、天、礼拜的定义为时间要素,未提及到默认条件算近一周,条件给到具体的起始时间和结束时间(结束时间为当前时间则不用返回)的时间戳。\n" + "1.提及到年、月、日、天、礼拜的定义为时间要素,未提及到默认条件算近一周,条件给到具体的起始时间和结束时间(结束时间为当前时间则不用返回)的时间戳。\n" +
"2.提及到 XX 渠道的定义为渠道要素,条件给到该渠道名。\n" + "2.提及到 XX 渠道的定义为渠道要素,条件给到该渠道名,注意:{0}不可作为渠道名。\n" +
"3.提及到正面、中性、负面的定义为标签要素,条件给到该标签名。\n" + "3.提及到正面、中性、负面的定义为标签要素,条件给到该标签名。\n" +
"4.提及到针对 XX ,针对 XX 相关或 XX 相关的定义为搜索条件要素(包含 针对/相关 字样)。未提及时按照文义在文段内容中提取的一个或多个词作为该文段的关键词,将其定义为搜索条件要素。多个关键词时每个关键词之间严格用” “(空格)作为分隔符进行分隔,单个关键词则无需分隔,关键词必须在给到的文段中出现,条件给到具体值\n" + "4.提及到针对 XX ,针对 XX 相关或 XX 相关的定义为搜索条件要素(包含 针对/相关 字样)。未提及时按照文义在文段内容中提取的一个或多个词作为该文段的关键词,将其定义为搜索条件要素。多个关键词时每个关键词之间严格用” “(空格)作为分隔符进行分隔,单个关键词则无需分隔,关键词必须在给到的文段中出现,条件给到具体值\n" +
"5.时间和关键词要素为必需要素,若不满足则返回“无法回答”。\n" + "5.时间和关键词要素为必需要素,若不满足则返回“无法回答”。\n" +
...@@ -132,8 +132,8 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -132,8 +132,8 @@ public class MarkDataServiceImpl implements MarkDataService {
"1 按照指定输出格式输出。\n" + "1 按照指定输出格式输出。\n" +
"2 严格按照规则进行提炼。\n" + "2 严格按照规则进行提炼。\n" +
"###"; "###";
private static final String RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的内容(无关内容则无需引用),贴合问题给出自己的详细分析和见解。并在每点分析后用数字表示注明1-{0" + private static final String RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的内容(无关内容则无需引用),贴合问题给出自己的一点或多点详细分析和见解,每个观点需同时提炼出一个贴合自己的详细分析和见解的小标题(无需注明“小标题”三个字)。并在每点详细分析和见解后用数字表示注明1-{0" +
"}的参考文章,分析结果和参考文章之间严格用”|“作为分隔符进行分隔,并且多个参考文章之间也严格用”|“作为分隔符进行分隔,若没有对应的参考文章则无需返回,示例:分析结果。|1|2|3" + "}的参考文章,分析结果和参考文章之间严格用”|“作为分隔符进行分隔(小标题与分析结果无需分隔),并且多个参考文章之间也严格用”|“作为分隔符进行分隔,若没有对应的参考文章则无需返回,示例:分析结果。|1|2|3" +
"请回答该问题:"; "请回答该问题:";
private static final String REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,提出自己{0}个关于的{1}参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" + private static final String REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,提出自己{0}个关于的{1}参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" +
...@@ -142,6 +142,9 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -142,6 +142,9 @@ public class MarkDataServiceImpl implements MarkDataService {
private static final String CACHE_REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,请参考给到的问题,提出自己5个类似的的参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" + private static final String CACHE_REFERENCE_QUESTION_PROMPT = "假如你是专业的问题提出人员,请参考给到的问题,提出自己5个类似的的参考问题,每个问题无需给到对应的序号,问题必须包含 针对/相关 字样,每个问题之间严格用”|“作为分隔符进行分隔。" +
"请提出:"; "请提出:";
private static final String ONLINE_RESULT_PROMPT = "假如你是专业的分析报告人员,你将根据用户提供的问题,贴合问题给出自己的一点或多点详细分析和见解,每个观点需同时提炼出一个贴合自己的详细分析和见解的小标题(无需注明“小标题”三个字)。每个观点间需要严格换行" +
"请回答该问题:";
@Value("${istarshine.addIStarShineKSData.url}") @Value("${istarshine.addIStarShineKSData.url}")
private String addIStarShineKSDataUrl; private String addIStarShineKSDataUrl;
...@@ -4028,26 +4031,40 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4028,26 +4031,40 @@ public class MarkDataServiceImpl implements MarkDataService {
} }
@Override @Override
public JSONObject getAISearchResult(String question) { public JSONObject getAISearchResult(String question, String keyword, Long startTime, Long endTime) {
JSONObject res = new JSONObject();
try { try {
// 选用的模型名称 // 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName(); String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
// 根据AI生成条件 List<JSONObject> list;
Pair<String, long[]> questionPair = standardRequest(question, modelName, QUESTION_PROMPT); Pair<String, long[]> questionPair = null;
if (Objects.isNull(questionPair)){ if (StringUtils.isNotBlank(keyword)){ // 已填辅助信息,则只用辅助信息
return null; keyword = Tools.canonicalKeyword(keyword);
list = esClientDao.findSearch(question, keyword, startTime, endTime);
}else { // 未填辅助信息,则根据AI生成条件
Project project = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId());
StringBuilder brandStr = new StringBuilder(project.getProjectName());
if (CollectionUtils.isNotEmpty(project.getContendList())){
project.getContendList().forEach(contend -> brandStr.append("、").append(contend.getBrandName()));
}
questionPair = standardRequest(question, modelName, MessageFormat.format(QUESTION_PROMPT, brandStr));
if (Objects.isNull(questionPair)) {
return null;
}
JSONObject json = JSON.parseObject(questionPair.getLeft());
// 数据条件
List<FieldMapping> filedMapping = getFiledMapping(json, question);
addDefaultFiledMapping(filedMapping, question);
list = esClientDao.findSearch(filedMapping);
} }
JSONObject json = JSON.parseObject(questionPair.getLeft());
// 数据条件
List<FieldMapping> filedMapping = getFiledMapping(json, question);
addDefaultFiledMapping(filedMapping, question);
List<JSONObject> list = esClientDao.findSearch(filedMapping);
String collectionName = userLogRecordDao.generateCollectionName(); String collectionName = userLogRecordDao.generateCollectionName();
if (CollectionUtils.isEmpty(list)){ if (CollectionUtils.isEmpty(list)){
// 需记录耗费 if (Objects.nonNull(questionPair)) {
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-无结果-提取搜索条件-" + question, // 需记录耗费
calculateCost(questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K)), collectionName); userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-无结果-提取搜索条件-" + question,
calculateCost(questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K)), collectionName);
}else {
userLogRecordDao.insertOne(UserLogRecord.defaultUserLogRecord("AI搜索-无结果-"+ question +"-辅助信息:"+ Tools.concatWithMinus(Arrays.asList(keyword, startTime, endTime))), collectionName);
}
return null; return null;
} }
// AI回答 // AI回答
...@@ -4060,32 +4077,17 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4060,32 +4077,17 @@ public class MarkDataServiceImpl implements MarkDataService {
} }
String sbContent = sb.toString(); String sbContent = sb.toString();
Pair<String, long[]> answerPair = streamStandardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question); Pair<String, long[]> answerPair = streamStandardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question);
String[] splits = answerPair.getLeft().split("\\r?\\n"); // 结果处理
List<JSONObject> answers = new ArrayList<>(); JSONObject res = aiResultDataProcess(answerPair, false);
for (int i = 0; i < splits.length; i++) {
JSONObject answer = new JSONObject();
String[] sonSplit = splits[i].split("\\|");
if (0 == sonSplit.length){
continue;
}
if (i == 0){
answer.put("answer", splits[i].trim());
answers.add(answer);
continue;
}
answer.put("answer", sonSplit[0]);
List<String> sonSplitList = new ArrayList<>(Arrays.asList(sonSplit)).stream().filter(StringUtils::isNoneBlank).skip(1).collect(Collectors.toList());
answer.put("referenceArticles", sonSplitList);
answers.add(answer);
}
res.put("answers", answers);
res.put("articles", articles); res.put("articles", articles);
res.put("searchCriteria", questionPair.getLeft()); res.put("searchCriteria", Objects.isNull(questionPair) ? Tools.concat(keyword, startTime, endTime) : questionPair.getLeft());
// 记录返回成功的提问 // 记录返回成功的提问
aiSearchQuestionRecordDao.insertOne(new AISearchQuestionRecord(question, UserThreadLocal.getProjectId(), System.currentTimeMillis())); aiSearchQuestionRecordDao.insertOne(new AISearchQuestionRecord(question, UserThreadLocal.getProjectId(), System.currentTimeMillis()));
// 需记录耗费 // 需记录耗费
double cost = calculateCost(questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K) + calculateCost(answerPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K); double cost = calculateCost(Objects.isNull(questionPair) ? null : questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K) + calculateCost(answerPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K);
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-结果-" + question, cost), collectionName); String description = "AI搜索-结果-" + question;
String extraDescription = "-辅助信息:" + Tools.concatWithMinus(Arrays.asList(keyword, startTime, endTime));
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost(Objects.nonNull(keyword) ? description + extraDescription : description, cost), collectionName);
return res; return res;
}catch (Exception e){ }catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e); ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e);
...@@ -4093,7 +4095,47 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4093,7 +4095,47 @@ public class MarkDataServiceImpl implements MarkDataService {
return null; return null;
} }
private JSONObject aiResultDataProcess(Pair<String, long[]> answerPair, boolean isOnline){
JSONObject res = new JSONObject();
// 结果处理
String[] splits = answerPair.getLeft().split("\\r?\\n");
List<JSONObject> answers = new ArrayList<>();
for (int i = 0; i < splits.length; i++) {
JSONObject answer = new JSONObject();
String[] sonSplit = splits[i].split("\\|");
if (0 == sonSplit.length){
continue;
}
if (i == 0 || isOnline){
answer.put("answer", splits[i].trim());
answers.add(answer);
continue;
}
if (StringUtils.isNotBlank(sonSplit[0])) {
answer.put("answer", sonSplit[0]);
List<String> sonSplitList = new ArrayList<>(Arrays.asList(sonSplit)).stream().skip(1).collect(Collectors.toList());
answer.put("referenceArticles", sonSplitList);
answers.add(answer);
}
}
res.put("answers", answers);
return res;
}
@Override
public JSONObject getAIOnlineSearchResult(String question) {
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
Pair<String, long[]> answerPair = streamStandardRequest(question, modelName, ONLINE_RESULT_PROMPT);
// 需记录耗费
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-联网搜索-" + question,
calculateCost(answerPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K)), userLogRecordDao.generateCollectionName());
return aiResultDataProcess(answerPair, true);
}
private double calculateCost(long[] tokens, AccessModel.Model model){ private double calculateCost(long[] tokens, AccessModel.Model model){
if (Objects.isNull(tokens)){
return 0d;
}
double inputCost = tokens[0] / 1000d * model.getInputPrice(); double inputCost = tokens[0] / 1000d * model.getInputPrice();
double outputCost = tokens[1] / 1000d * model.getOutputPrice(); double outputCost = tokens[1] / 1000d * model.getOutputPrice();
return inputCost + outputCost; return inputCost + outputCost;
...@@ -4166,14 +4208,11 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4166,14 +4208,11 @@ public class MarkDataServiceImpl implements MarkDataService {
if (entry.getValue() instanceof JSONObject) { if (entry.getValue() instanceof JSONObject) {
res.addAll(getFiledMapping((JSONObject) entry.getValue(), content)); res.addAll(getFiledMapping((JSONObject) entry.getValue(), content));
} else { } else {
// 文本限定关键字 if (Objects.isNull(entry.getValue())){
// if (entry.getKey().equals("搜索条件")) { continue;
// if (!(content.contains("针对") || content.contains("相关"))) { }
// continue;
// }
// }
FieldMapping fieldMapping = FieldMapping.createFromNameAndValue(entry.getKey(), entry.getValue(), content); FieldMapping fieldMapping = FieldMapping.createFromNameAndValue(entry.getKey(), entry.getValue(), content);
if (null != fieldMapping) { if (null != fieldMapping && Objects.nonNull(fieldMapping.getValue())) {
res.add(fieldMapping); res.add(fieldMapping);
} }
} }
...@@ -4181,24 +4220,30 @@ public class MarkDataServiceImpl implements MarkDataService { ...@@ -4181,24 +4220,30 @@ public class MarkDataServiceImpl implements MarkDataService {
return res; return res;
} }
private void addDefaultFiledMapping(List<FieldMapping> filedMapping, String question){ private void addDefaultFiledMapping(List<FieldMapping> fieldMappings, String question){
Project project = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId()); Project project = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId());
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.PROJECT, UserThreadLocal.getProjectId())); fieldMappings.add(new FieldMapping(FieldMapping.FieldMap.PROJECT, UserThreadLocal.getProjectId()));
List<String> projectBandNames = new ArrayList<>();
projectBandNames.add(project.getProjectName());
projectBandNames.add(project.getBrandName());
if (CollectionUtils.isNotEmpty(project.getContendList())){ if (CollectionUtils.isNotEmpty(project.getContendList())){
List<String> contends = new ArrayList<>(); List<String> contends = new ArrayList<>();
for (Contend contend : project.getContendList()) { for (Contend contend : project.getContendList()) {
projectBandNames.add(contend.getBrandName());
if (question.contains(contend.getBrandName())) { if (question.contains(contend.getBrandName())) {
contends.add(contend.getId()); contends.add(contend.getId());
} }
} }
if (CollectionUtils.isNotEmpty(contends)){ if (CollectionUtils.isNotEmpty(contends)){
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, String.join("|", contends))); fieldMappings.add(new FieldMapping(FieldMapping.FieldMap.BRAND, String.join("|", contends)));
}else { }else {
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID)); fieldMappings.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID));
} }
}else { }else {
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID)); fieldMappings.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID));
} }
// 防止将项目/品牌作为渠道
fieldMappings.removeIf(fieldMapping -> Objects.equals(fieldMapping.getFieldMap(), FieldMapping.FieldMap.SOURCE) && projectBandNames.contains(String.valueOf(fieldMapping.getValue())));
} }
/** /**
......
...@@ -554,10 +554,13 @@ public class Tools { ...@@ -554,10 +554,13 @@ public class Tools {
String separator = "-"; String separator = "-";
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for (Object obj : objects) { for (Object obj : objects) {
if (Objects.isNull(obj)){
continue;
}
sb.append(obj).append(separator); sb.append(obj).append(separator);
} }
String resultStr = sb.toString(); String resultStr = sb.toString();
return resultStr.substring(0, resultStr.length() - 1); return StringUtils.isBlank(resultStr) ? resultStr : resultStr.substring(0, resultStr.length() - 1);
} }
public static String[] split(String concatStr) { public static String[] split(String concatStr) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment