Commit b60e3109 by shentao

Merge branch 'feature' into 'dev'

2024/8/16 ai搜索调整 test

See merge request !571
parents f1aade7e 2e355acb
......@@ -144,7 +144,7 @@ public class AopLogRecord {
String[] value = method.getAnnotation(LogRecord.class).values();
// 注解value为空字符串(value使用默认值未设置)
if (1 == value.length && StringUtils.isEmpty(value[0])){
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix, userInfo.getRoleId(), now, now);
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix, userInfo.getRoleId(), now, now, null);
}
// 获取接口传参(value为获取传参的具体字段)并与操作描述description拼接返回,传参值为实体
if (method.getAnnotation(LogRecord.class).arguments() && method.getAnnotation(LogRecord.class).entity()) {
......@@ -176,7 +176,7 @@ public class AopLogRecord {
}
}
String suffix = CollectionUtils.isNotEmpty(res) ? "-" + Tools.concatWithMinus(res) : "";
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now);
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now, null);
}
// 获取接口传参(value为获取传参的具体字段)并与操作描述description拼接返回,传参值不为实体
if (method.getAnnotation(LogRecord.class).arguments() && !method.getAnnotation(LogRecord.class).entity()) {
......@@ -190,7 +190,7 @@ public class AopLogRecord {
}
}
String suffix = CollectionUtils.isNotEmpty(res) ? "-" + Tools.concatWithMinus(res) : "";
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now);
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now, null);
}
// 获取接口返回值(value为获取返回值的具体字段)并与操作描述description拼接返回,返回值为实体
if (!method.getAnnotation(LogRecord.class).arguments() && method.getAnnotation(LogRecord.class).entity() && Objects.nonNull(responseResult)) {
......@@ -204,9 +204,9 @@ public class AopLogRecord {
}
}
String suffix = CollectionUtils.isNotEmpty(res) ? "-" + Tools.concatWithMinus(res) : "";
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now);
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix + suffix, userInfo.getRoleId(), now, now, null);
}
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix, userInfo.getRoleId(), now, now);
return new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), prefix, userInfo.getRoleId(), now, now, null);
}
/**
......
......@@ -340,11 +340,6 @@ public class EsClientDao {
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
String[] fetchSource = {"id", GenericAttribute.ES_TIME, GenericAttribute.ES_IND_TITLE};
SearchHit[] hits = searchHits(getIndexes(), query, null, fetchSource, null, 0, 10000, null).getHits();
// Map<String, JSONObject> idBaseMap = Arrays.stream(hits).map(hit -> new JSONObject(hit.getSourceAsMap())).collect(Collectors.toMap(json -> json.getString("id"), o -> o));
// Map<String, String> idTitle = Arrays.stream(hits)
// .map(hit -> new JSONObject(hit.getSourceAsMap()))
// .filter(json -> Objects.nonNull(json.getString(GenericAttribute.ES_IND_TITLE)))
// .collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString(GenericAttribute.ES_IND_TITLE)));
Pair<Map<String, JSONObject>, Map<String, String>> searchProcess = findSearchResultProcess(hits);
Map<String, JSONObject> idBaseMap = searchProcess.getLeft();
Map<String, String> idTitle = searchProcess.getRight();
......@@ -378,7 +373,7 @@ public class EsClientDao {
Map<String, JSONObject> idBaseMap = Arrays.stream(hits).map(hit -> new JSONObject(hit.getSourceAsMap())).collect(Collectors.toMap(json -> json.getString("id"), o -> o));
Map<String, String> idTitle = Arrays.stream(hits)
.map(hit -> new JSONObject(hit.getSourceAsMap()))
.filter(json -> Objects.nonNull(json.getString(GenericAttribute.ES_IND_TITLE)))
.filter(json -> Objects.nonNull(json.getString(GenericAttribute.ES_IND_TITLE)) || Tools.filterUselessTitle(GenericAttribute.ES_IND_TITLE))
.collect(Collectors.toMap(json -> json.getString("id"), json -> json.getString(GenericAttribute.ES_IND_TITLE)));
return Pair.of(idBaseMap, idTitle);
}
......@@ -387,7 +382,7 @@ public class EsClientDao {
fieldMappings.stream().filter(fieldMapping -> Objects.equals(FieldMapping.FieldMap.IND_FULL_TEXT, fieldMapping.getFieldMap()))
.findFirst().ifPresent(fieldMapping -> {
String value = String.valueOf(fieldMapping.getValue());
String newValue = HanLP.segment(Tools.filterSpecialCharacter(value)).stream().map(s -> s.word).distinct().collect(Collectors.joining("|"));
String newValue = HanLP.segment(Tools.filterSpecialCharacter(value)).stream().map(s -> s.word).distinct().collect(Collectors.joining(" "));
fieldMapping.setValue(newValue);
});
BoolQueryBuilder query = getBoolQueryBuilder(fieldMappings);
......
package com.zhiwei.brandkbs2.pojo;
import com.zhiwei.brandkbs2.auth.UserThreadLocal;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.springframework.data.mongodb.core.mapping.Document;
/**
* @ClassName: UserLogRecord
......@@ -14,6 +15,7 @@ import org.springframework.data.mongodb.core.mapping.Document;
@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
public class UserLogRecord extends AbstractBaseMongo{
/**
* 项目id
......@@ -43,4 +45,21 @@ public class UserLogRecord extends AbstractBaseMongo{
* 更新时间
*/
private Long updateTime;
/**
* 本次调用费用(ai搜索)
*/
private Double cost;
public static UserLogRecord userLogRecordCost(String description, double cost){
UserLogRecord record = new UserLogRecord();
record.setProjectId(UserThreadLocal.getProjectId());
record.setUserId(UserThreadLocal.getUserId());
record.setNickname(UserThreadLocal.getNickname());
record.setDescription(description);
record.setRoleId(UserThreadLocal.getRoleId());
record.setCTime(System.currentTimeMillis());
record.setUpdateTime(System.currentTimeMillis());
record.setCost(cost);
return record;
}
}
package com.zhiwei.brandkbs2.pojo.ai;
import com.zhiwei.brandkbs2.auth.UserThreadLocal;
import com.zhiwei.brandkbs2.common.GlobalPojo;
import com.zhiwei.brandkbs2.config.Constant;
import com.zhiwei.brandkbs2.es.EsQueryTools;
import com.zhiwei.brandkbs2.pojo.AbstractProject;
import com.zhiwei.brandkbs2.pojo.Contend;
import com.zhiwei.brandkbs2.pojo.Project;
import lombok.Data;
import lombok.Getter;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.*;
import java.util.*;
import java.util.stream.Collectors;
/**
* @ClassName: FieldMap
......@@ -35,7 +28,7 @@ public class FieldMapping {
public QueryBuilder buildQuery(FieldMapping fieldMapping) {
boolean existsAnd = null != fieldMapping;
RangeQueryBuilder timeRangeBuilder;
String contendId = "0";
String[] contendIds = {"0"};
// 项目组需绑定查询
switch (fieldMap) {
case START_TIME:
......@@ -52,12 +45,17 @@ public class FieldMapping {
return timeRangeBuilder;
case PROJECT:
if (existsAnd && fieldMapping.fieldMap == FieldMap.BRAND) {
contendId = (String) fieldMapping.value;
contendIds = ((String) fieldMapping.value).split("\\|");
}
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
for (String contendId : contendIds) {
BoolQueryBuilder nestedBoolBuilder = QueryBuilders.boolQuery();
// 必要条件
nestedBoolBuilder.must(QueryBuilders.termQuery(fieldMap.databaseName, value + "_" + contendId));
return new NestedQueryBuilder("brandkbs_cache_maps", nestedBoolBuilder, ScoreMode.None);
boolQueryBuilder.should(new NestedQueryBuilder("brandkbs_cache_maps", nestedBoolBuilder, ScoreMode.None));
}
boolQueryBuilder.minimumShouldMatch(1);
return boolQueryBuilder;
case BRAND:
if (!existsAnd || fieldMapping.fieldMap != FieldMap.PROJECT) {
throw new IllegalStateException("项目条件缺失");
......@@ -65,7 +63,6 @@ public class FieldMapping {
return fieldMapping.buildQuery(this);
case IND_FULL_TEXT:
return EsQueryTools.assembleNormalKeywordQuery(String.valueOf(value), new String[]{fieldMap.databaseName});
// return QueryBuilders.matchPhraseQuery(fieldMap.databaseName, value);
case SOURCE:
case MTAG:
return QueryBuilders.termQuery(fieldMap.databaseName, value);
......@@ -97,41 +94,59 @@ public class FieldMapping {
}
}
public static FieldMapping createFromNameAndValue(String name, Object value, List<FieldMapping> fieldMappings) {
public static FieldMapping createFromNameAndValue(String name, Object value, String question) {
FieldMap fieldMap = null;
// String projectId = UserThreadLocal.getProjectId();
// TODO 字段转换待完善,引入数据库
for (FieldMap f : FieldMap.values()) {
if (name.equals(f.getName())) {
// 项目名需要转成id
if (FieldMap.PROJECT == f) {
Map<String, Project> projectMap = GlobalPojo.PROJECT_MAP.values().stream().collect(Collectors.toMap(AbstractProject::getProjectName, o -> o));
if (projectMap.containsKey(String.valueOf(value))) {
value = projectMap.get(String.valueOf(value)).getId();
}else {
value = UserThreadLocal.getProjectId();
}
}
// // 项目名需要转成id
// if (FieldMap.PROJECT == f) {
// Map<String, Project> projectMap = GlobalPojo.PROJECT_MAP.values().stream().collect(Collectors.toMap(AbstractProject::getProjectName, o -> o));
// if (projectMap.containsKey(String.valueOf(value))) {
// value = projectMap.get(String.valueOf(value)).getId();
// }else {
// value = projectId;
// }
// }
// 品牌需要转换
if (FieldMap.BRAND == f) {
if ("主品牌".equals(value)) {
value = Constant.PRIMARY_CONTEND_ID;
} else {
// 寻找对应的竞品id
Optional<FieldMapping> project = fieldMappings.stream().filter(field -> Objects.equals(FieldMap.PROJECT, field.getFieldMap())).findFirst();
if (project.isPresent()){
List<Contend> contendList = GlobalPojo.PROJECT_MAP.get(String.valueOf(project.get().getValue())).getContendList();
Object finalValue = value;
Optional<Contend> contendOptional = contendList.stream().filter(contend -> Objects.equals(contend.getBrandName(), finalValue)).findFirst();
if (contendOptional.isPresent()){
value = contendOptional.get().getId();
}else {
value = Constant.PRIMARY_CONTEND_ID;
}
}else {
value = Constant.PRIMARY_CONTEND_ID;
}
}
}
// if (FieldMap.BRAND == f) {
// Project project = GlobalPojo.PROJECT_MAP.get(projectId);
// if (CollectionUtils.isNotEmpty(project.getContendList())){
// List<String> contends = new ArrayList<>();
// List<String> contendNames = project.getContendList().stream().map(AbstractProject::getBrandName).collect(Collectors.toList());
// for (String contendName : contendNames) {
// if (question.contains(contendName)) {
// contends.add(contendName);
// }
// }
// if (CollectionUtils.isNotEmpty(contends)){
// value = String.join("|", contends);
// }else {
// value = Constant.PRIMARY_CONTEND_ID;
// }
// }else {
// value = Constant.PRIMARY_CONTEND_ID;
// }
// if ("主品牌".equals(value)) {
// value = Constant.PRIMARY_CONTEND_ID;
// } else {
// // 寻找对应的竞品id
// Optional<FieldMapping> project = fieldMappings.stream().filter(field -> Objects.equals(FieldMap.PROJECT, field.getFieldMap())).findFirst();
// if (project.isPresent()){
// List<Contend> contendList = GlobalPojo.PROJECT_MAP.get(String.valueOf(project.get().getValue())).getContendList();
// Object finalValue = value;
// Optional<Contend> contendOptional = contendList.stream().filter(contend -> Objects.equals(contend.getBrandName(), finalValue)).findFirst();
// if (contendOptional.isPresent()){
// value = contendOptional.get().getId();
// }else {
// value = Constant.PRIMARY_CONTEND_ID;
// }
// }else {
// value = Constant.PRIMARY_CONTEND_ID;
// }
// }
// }
// 标签只包含正负中
if (FieldMap.MTAG == f) {
if (!Arrays.asList("正面", "中性", "负面").contains(String.valueOf(value))) {
......
......@@ -174,7 +174,7 @@ public class BehaviorServiceImpl implements BehaviorService {
if (null == userInfo) {
return;
}
UserLogRecord userLogRecord = new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), description, userInfo.getRoleId(), now, now);
UserLogRecord userLogRecord = new UserLogRecord(projectId, userInfo.getUserId(), userInfo.getNickname(), description, userInfo.getRoleId(), now, now, null);
String collectionName = userLogRecordDao.generateCollectionName();
userLogRecordDao.insertOne(userLogRecord, collectionName);
}
......
......@@ -6,10 +6,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.HanLP;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionResult;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.model.completion.chat.*;
import com.zhiwei.base.category.ClassB;
import com.zhiwei.base.entity.subclass.mark.MarkInfo;
import com.zhiwei.brandkbs2.auth.UserThreadLocal;
......@@ -115,19 +112,18 @@ public class MarkDataServiceImpl implements MarkDataService {
private static final String QUESTION_PROMPT = "###\n" +
"假如你是专业的问题提炼人员,你将根据用户提供的内容,来提炼问题要素和条件。根据以下规则一步步执行:\n" +
"1.提及到年、月、日、天、礼拜的定义为时间要素,未提及到默认条件算近一周,条件给到具体的起始时间和结束时间(结束时间为当前时间则不用返回)的时间戳。\n" +
"2.提及到 XX 项目 XX 品牌的定义为项目及品牌要素,条件给到具体的项目及品牌(未提及品牌默认主品牌)。\n" +
"3.提及到 XX 渠道的定义为渠道要素,条件给到该渠道名。\n" +
"4.提及到正面、中性、负面的定义为标签要素,条件给到该标签名。\n" +
"5.提及到针对 XX ,针对 XX 相关或 XX 相关的定义为搜索条件要素(必须包含 针对/相关 字样),条件给到具体值\n" +
"6.时间和项目要素为必需要素,若不满足则返回“无法回答”。\n" +
"2.提及到 XX 渠道的定义为渠道要素,条件给到该渠道名。\n" +
"3.提及到正面、中性、负面的定义为标签要素,条件给到该标签名。\n" +
"4.提及到针对 XX ,针对 XX 相关或 XX 相关的定义为搜索条件要素(包含 针对/相关 字样)。未提及时按照文义在文段内容中提取的一个或多个词作为该文段的关键词,将其定义为搜索条件要素。多个关键词时每个关键词之间严格用” “(空格)作为分隔符进行分隔,单个关键词则无需分隔,关键词必须在给到的文段中出现,条件给到具体值\n" +
"5.时间和关键词要素为必需要素,若不满足则返回“无法回答”。\n" +
"\n" +
"参考例子:\n" +
"示例 1:\n" +
"{用户:今年 7 月腾讯项目张清相关的正面数据}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000,\"结束时间\":1722355200000},\"项目\":\"腾讯\",\"品牌\":\"主品牌\",\"标签\":\"正面\",\"搜索条件\":\"张清\"}\n" +
"'{用户:今年7月腾讯项目张清相关的正面数据}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000,\"结束时间\":1722355200000},\"标签\":\"正面\",\"搜索条件\":\"张清\"}\n" +
"示例 2:\n" +
"{用户:近一个月老乡鸡竞品1品牌新浪网渠道数据}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000},\"项目\":\"老乡鸡\",\"品牌\":\"竞品1\",\"渠道\":\"新浪网\"}\n" +
"输出:{\"时间\":{\"起始时间\":1719763200000},\"渠道\":\"新浪网\"}\n" +
"示例 3:\n" +
"{用户:近一年数据}\n" +
"输出:无法回答\n" +
......@@ -187,6 +183,9 @@ public class MarkDataServiceImpl implements MarkDataService {
@Resource(name = "xiaohongshuWordDao")
private XiaohongshuWordDao xiaohongshuWordDao;
@Resource(name = "UserLogRecordDao")
private UserLogRecordDao userLogRecordDao;
@Resource(name = "commonServiceImpl")
private CommonService commonService;
......@@ -3962,11 +3961,20 @@ public class MarkDataServiceImpl implements MarkDataService {
@Override
public List<String> getAIReferenceQuestion(String question, int size) {
try {
String projectId = UserThreadLocal.getProjectId();
// 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
String projectName = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId()).getProjectName();
String resultContent = standardRequest(question, modelName, MessageFormat.format(REFERENCE_QUESTION_PROMPT, size, projectName));
AccessModel.Model model = AccessModel.Model.DOUBAO_PRO_32K;
String modelName = model.getModelName();
String projectName = GlobalPojo.PROJECT_MAP.get(projectId).getProjectName();
Pair<String, long[]> pair = standardRequest(question, modelName, MessageFormat.format(REFERENCE_QUESTION_PROMPT, size, projectName));
if (Objects.isNull(pair)){
return getAIReferenceQuestionTemplate(projectName);
}
String resultContent = pair.getLeft();
String[] splits = resultContent.split("\\|");
// 需记录本次耗费
double cost = calculateCost(pair.getRight(), model);
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-生成推荐提问-" + question, cost), userLogRecordDao.generateCollectionName());
return new ArrayList<>(Arrays.asList(splits)).stream().filter(StringUtils::isNoneBlank).map(String::trim).collect(Collectors.toList());
}catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "获取ai参考提问异常-", e);
......@@ -3995,10 +4003,14 @@ public class MarkDataServiceImpl implements MarkDataService {
for (String question : questionList) {
sb.append(count++).append("、").append(question).append(";");
}
String resultContent = standardRequest(sb.toString(), modelName, MessageFormat.format(CACHE_REFERENCE_QUESTION_PROMPT, projectName));
if (Objects.isNull(resultContent)){
Pair<String, long[]> pair = standardRequest(sb.toString(), modelName, MessageFormat.format(CACHE_REFERENCE_QUESTION_PROMPT, projectName));
if (Objects.isNull(pair)){
return getAIReferenceQuestionTemplate(projectName);
}
// 需记录耗费
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-生成参考提问",
calculateCost(pair.getRight(), AccessModel.Model.DOUBAO_PRO_32K)), userLogRecordDao.generateCollectionName());
String resultContent = pair.getLeft();
String[] splits = resultContent.split("\\|");
List<String> result = new ArrayList<>(Arrays.asList(splits)).stream().filter(StringUtils::isNoneBlank).map(String::trim).collect(Collectors.toList());
redisUtil.setExpire(key, JSONObject.toJSONString(result));
......@@ -4021,12 +4033,20 @@ public class MarkDataServiceImpl implements MarkDataService {
// 选用的模型名称
String modelName = AccessModel.Model.DOUBAO_PRO_32K.getModelName();
// 根据AI生成条件
String result = standardRequest(question, modelName, QUESTION_PROMPT);
JSONObject json = JSON.parseObject(result);
Pair<String, long[]> questionPair = standardRequest(question, modelName, QUESTION_PROMPT);
if (Objects.isNull(questionPair)){
return null;
}
JSONObject json = JSON.parseObject(questionPair.getLeft());
// 数据条件
List<FieldMapping> filedMapping = getFiledMapping(json, question);
addDefaultFiledMapping(filedMapping, question);
List<JSONObject> list = esClientDao.findSearch(filedMapping);
String collectionName = userLogRecordDao.generateCollectionName();
if (CollectionUtils.isEmpty(list)){
// 需记录耗费
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-无结果-提取搜索条件-" + question,
calculateCost(questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K)), collectionName);
return null;
}
// AI回答
......@@ -4038,8 +4058,8 @@ public class MarkDataServiceImpl implements MarkDataService {
sb.append(count++).append("、").append(text).append(";");
}
String sbContent = sb.toString();
result = streamStandardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question);
String[] splits = result.split("\\r?\\n");
Pair<String, long[]> answerPair = streamStandardRequest(sbContent, modelName, MessageFormat.format(RESULT_PROMPT, list.size()) + question);
String[] splits = answerPair.getLeft().split("\\r?\\n");
List<JSONObject> answers = new ArrayList<>();
for (int i = 0; i < splits.length; i++) {
JSONObject answer = new JSONObject();
......@@ -4059,8 +4079,12 @@ public class MarkDataServiceImpl implements MarkDataService {
}
res.put("answers", answers);
res.put("articles", articles);
res.put("searchCriteria", questionPair.getLeft());
// 记录返回成功的提问
aiSearchQuestionRecordDao.insertOne(new AISearchQuestionRecord(question, UserThreadLocal.getProjectId(), System.currentTimeMillis()));
// 需记录耗费
double cost = calculateCost(questionPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K) + calculateCost(answerPair.getRight(), AccessModel.Model.DOUBAO_PRO_32K);
userLogRecordDao.insertOne(UserLogRecord.userLogRecordCost("AI搜索-结果-" + question, cost), collectionName);
return res;
}catch (Exception e){
ExceptionCast.cast(CommonCodeEnum.FAIL, "ai搜索异常-", e);
......@@ -4068,17 +4092,31 @@ public class MarkDataServiceImpl implements MarkDataService {
return null;
}
private String streamStandardRequest(String content, String modelName, String prompt) {
private double calculateCost(long[] tokens, AccessModel.Model model){
double inputCost = tokens[0] / 1000d * model.getInputPrice();
double outputCost = tokens[1] / 1000d * model.getOutputPrice();
return inputCost + outputCost;
}
private Pair<String, long[]> streamStandardRequest(String content, String modelName, String prompt) {
AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName);
StringBuilder result = new StringBuilder();
AtomicLong promptTokens = new AtomicLong();
AtomicLong completionTokens = new AtomicLong();
try {
final List<ChatMessage> streamMessages = new ArrayList<>();
final ChatMessage streamSystemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(prompt).build();
final ChatMessage streamUserMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
streamMessages.add(streamSystemMessage);
streamMessages.add(streamUserMessage);
ChatCompletionRequest streamChatCompletionRequest = ChatCompletionRequest.builder().model(model.getModelId()).messages(streamMessages).build();
ChatCompletionRequest streamChatCompletionRequest = ChatCompletionRequest.builder().stream(true)
.streamOptions(ChatCompletionRequest.ChatCompletionRequestStreamOptions.of(true)).model(model.getModelId()).messages(streamMessages).build();
DoubaoAIAccountFactor.arkService.streamChatCompletion(streamChatCompletionRequest).doOnError(Throwable::printStackTrace).blockingForEach(choice -> {
if (Objects.nonNull(choice.getUsage())){
// 本次调用输入、输出使用tokens量
promptTokens.set(choice.getUsage().getPromptTokens());
completionTokens.set(choice.getUsage().getCompletionTokens());
}
if (choice.getChoices().size() > 0) {
result.append(choice.getChoices().get(0).getMessage().getContent());
}
......@@ -4086,10 +4124,11 @@ public class MarkDataServiceImpl implements MarkDataService {
} catch (Exception e) {
log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(result), e);
}
return result.toString();
long[] tokens = {promptTokens.get(), completionTokens.get()};
return Pair.of(result.toString(), tokens);
}
private String standardRequest(String content, String modelName, String prompt) {
private Pair<String, long[]> standardRequest(String content, String modelName, String prompt) {
AccessModel model = DoubaoAIAccountFactor.getCompanyAccount().getModelList().stream().collect(Collectors.toMap(AccessModel::getModelName, m -> m)).get(modelName);
ChatCompletionResult chatCompletion = null;
try {
......@@ -4104,7 +4143,10 @@ public class MarkDataServiceImpl implements MarkDataService {
log.error("异常chatCompletion:{}", JSON.toJSONString(chatCompletion));
return null;
}
return String.valueOf(chatCompletion.getChoices().get(0).getMessage().getContent());
// 本次调用输入、输出使用tokens量
long[] tokens = {chatCompletion.getUsage().getPromptTokens(), chatCompletion.getUsage().getCompletionTokens()};
String resultContent = String.valueOf(chatCompletion.getChoices().get(0).getMessage().getContent());
return Pair.of(resultContent, tokens);
} catch (Exception e) {
log.error("standardRequest,chatCompletion:{}", JSON.toJSONString(chatCompletion), e);
}
......@@ -4124,12 +4166,12 @@ public class MarkDataServiceImpl implements MarkDataService {
res.addAll(getFiledMapping((JSONObject) entry.getValue(), content));
} else {
// 文本限定关键字
if (entry.getKey().equals("搜索条件")) {
if (!(content.contains("针对") || content.contains("相关"))) {
continue;
}
}
FieldMapping fieldMapping = FieldMapping.createFromNameAndValue(entry.getKey(), entry.getValue(), res);
// if (entry.getKey().equals("搜索条件")) {
// if (!(content.contains("针对") || content.contains("相关"))) {
// continue;
// }
// }
FieldMapping fieldMapping = FieldMapping.createFromNameAndValue(entry.getKey(), entry.getValue(), content);
if (null != fieldMapping) {
res.add(fieldMapping);
}
......@@ -4138,6 +4180,26 @@ public class MarkDataServiceImpl implements MarkDataService {
return res;
}
private void addDefaultFiledMapping(List<FieldMapping> filedMapping, String question){
Project project = GlobalPojo.PROJECT_MAP.get(UserThreadLocal.getProjectId());
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.PROJECT, UserThreadLocal.getProjectId()));
if (CollectionUtils.isNotEmpty(project.getContendList())){
List<String> contends = new ArrayList<>();
for (Contend contend : project.getContendList()) {
if (question.contains(contend.getBrandName())) {
contends.add(contend.getId());
}
}
if (CollectionUtils.isNotEmpty(contends)){
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, String.join("|", contends)));
}else {
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID));
}
}else {
filedMapping.add(new FieldMapping(FieldMapping.FieldMap.BRAND, Constant.PRIMARY_CONTEND_ID));
}
}
/**
* 原发溯源大库es查询
* @param dto
......
......@@ -493,7 +493,11 @@ public class TaskServiceImpl implements TaskService {
public void cacheAIQuestion() {
AtomicInteger total = new AtomicInteger();
CompletableFuture.allOf(GlobalPojo.PROJECT_MAP.values().stream().map(project -> CompletableFuture.supplyAsync(() -> {
UserThreadLocal.set(new UserInfo().setProjectId(project.getId()));
UserInfo userInfo = new UserInfo().setProjectId(project.getId());
userInfo.setUserId("0");
userInfo.setNickname("系统");
userInfo.setRoleId(1);
UserThreadLocal.set(userInfo);
markDataService.getAIReferenceQuestionCache(false);
log.info("项目:{}-{}-AI参考问题缓存完成:{}个", project.getProjectName(), project.getId(), total.incrementAndGet());
return null;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment