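ai-search Search Implementation (WebSocket Version)
This document describes the code structure of the WebSocket-based ai-search system. The system consists of the following modules:
- Entity classes: data structures for requests, responses, messages, history records, and web search results.
- WebSocket handler: handles client connections, the handshake, sending and receiving messages, and closing connections.
- Business services: message dispatch, search, question rewriting, context recording, inference, and returning results.
- Search services: call third-party search APIs (such as Google Custom Search and SearxNG) and post-process the results.
- Prediction services: call Gemini or other model APIs to generate a streamed or one-shot response.
- Callback classes: handle asynchronous responses from third-party services and return the results to the client over the WebSocket.
The sections below walk through the implementation of each part.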
1. Entity Classes
The entity classes define the system's basic data structures. Their source code and descriptions follow.
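1.1 ChatParamVo
Holds the per-conversation parameters: the rewritten question, the input prompt, the chat history, the source site (`from`), the answer message ID, and the list of web pages found by search.
```java
package com.litongjava.perplexica.vo;

import java.util.List;

import com.litongjava.model.web.WebPageContent;
import com.litongjava.openai.chat.ChatMessage;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

// Lombok generates getters/setters, equals/hashCode/toString and both constructors
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ChatParamVo {
  private String rewrited;              // rewritten question, if any
  private String inputPrompt;           // prompt passed to the model
  private List<ChatMessage> history;    // previous messages of the session
  private String from;                  // source site (the FROM channel attribute)
  private long answerMessageId;         // ID of the answer message being generated
  private List<WebPageContent> sources; // web pages found by search
}
```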
1.2 ChatReqMessage
Holds the basic data of a chat request message: message ID, chat ID, and message content.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@AllArgsConstructor
@Data
public class ChatReqMessage {
  private Long messageId; // ID of the user's question message
  private Long chatId;    // session (chat) ID
  private String content; // question text
}
```
1.3 ChatSignalVo
Holds a WebSocket signal type and its data. For example, when the connection is established the server sends a signal with type `signal` and data `open`.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class ChatSignalVo {
  private String type;
  private String data;
}
```
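1.4 ChatWsReqMessageVo
Describes a WebSocket request message: the message type, user ID, the message body itself, attached files, focus mode, whether copilot is enabled, optimization mode, and message history.
```java
package com.litongjava.perplexica.vo;

import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@AllArgsConstructor
@Data
public class ChatWsReqMessageVo {
  private String type;                 // request type, e.g. "message"
  private Long userId;
  private ChatReqMessage message;      // the actual question
  private List<String> files;          // attached files
  private String focusMode;            // webSearch, translator, deepSeek, ...
  private Boolean copilotEnabled;      // whether web search runs before answering
  private String optimizationMode;     // balanced, quality, ...
  private List<List<String>> history;  // history sent by the client
}
```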
1.5 ChatWsRespVo
Wraps a WebSocket response message: type, data, message ID, and key (an error code for `error` messages). It also provides static factory methods for building error, progress, data, keep-alive, message, and message-end responses.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;

@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true) // setters return this, enabling the chained calls below
public class ChatWsRespVo<T> {
  private String type;
  private T data;
  private Long messageId;
  private String key;

  // Error message; key carries an error code
  public static ChatWsRespVo<String> error(String key, String message) {
    return new ChatWsRespVo<String>().setType("error").setKey(key).setData(message);
  }

  // Progress notification
  public static ChatWsRespVo<String> progress(String message) {
    return new ChatWsRespVo<String>().setType("progress").setData(message);
  }

  // Arbitrary typed payload
  public static <T> ChatWsRespVo<T> data(String type, T data) {
    return new ChatWsRespVo<T>().setType(type).setData(data);
  }

  // Heartbeat while the model is still working
  public static <T> ChatWsRespVo<T> keepAlive(Long answerMessageId) {
    return new ChatWsRespVo<T>().setType("keep-alive").setMessageId(answerMessageId);
  }

  // One chunk of the answer
  public static <T> ChatWsRespVo<T> message(Long answerMessageId, T data) {
    return new ChatWsRespVo<T>().setType("message").setMessageId(answerMessageId).setData(data);
  }

  // The answer is complete
  public static ChatWsRespVo<Void> messageEnd(Long answerMessageId) {
    return new ChatWsRespVo<Void>().setType("messageEnd").setMessageId(answerMessageId);
  }
}
```
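To make the wire format concrete, here is a dependency-free sketch that mimics the factory methods above and prints the JSON frames they would produce. `WsFrame` is a hypothetical stand-in that hand-rolls serialization and skips null fields. Whether the real `WebSocketResponse.fromJson` omits nulls depends on the JSON library configuration, so treat the exact output as an assumption.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Dependency-free stand-in for ChatWsRespVo<String>; field names mirror the original class.
class WsFrame {
  String type;
  String data;
  Long messageId;
  String key;

  private static WsFrame of(String type, Long messageId, String data, String key) {
    WsFrame f = new WsFrame();
    f.type = type;
    f.messageId = messageId;
    f.data = data;
    f.key = key;
    return f;
  }

  static WsFrame error(String key, String message) { return of("error", null, message, key); }
  static WsFrame progress(String message) { return of("progress", null, message, null); }
  static WsFrame keepAlive(long answerMessageId) { return of("keep-alive", answerMessageId, null, null); }
  static WsFrame message(long answerMessageId, String chunk) { return of("message", answerMessageId, chunk, null); }
  static WsFrame messageEnd(long answerMessageId) { return of("messageEnd", answerMessageId, null, null); }

  // Serializes non-null fields in declaration order (assumes the real JSON library skips nulls).
  String toJson() {
    Map<String, Object> fields = new LinkedHashMap<>();
    fields.put("type", type);
    fields.put("data", data);
    fields.put("messageId", messageId);
    fields.put("key", key);
    StringBuilder sb = new StringBuilder("{");
    for (Map.Entry<String, Object> e : fields.entrySet()) {
      Object v = e.getValue();
      if (v == null) {
        continue;
      }
      if (sb.length() > 1) {
        sb.append(',');
      }
      sb.append('"').append(e.getKey()).append("\":");
      if (v instanceof String) {
        sb.append('"').append(escape((String) v)).append('"');
      } else {
        sb.append(v);
      }
    }
    return sb.append('}').toString();
  }

  // Escapes only backslash, quote and newline: enough for this demo, not a full JSON encoder.
  private static String escape(String s) {
    return s.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n");
  }
}
```

In the services below, a streamed answer arrives in three steps: first an empty `message` frame, then a series of `message` chunks that share the same `messageId`, and finally a `messageEnd`. A client can therefore key its rendering on `messageId`.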
1.6 CitationsVo
Holds a citation's title and link. Besides the all-arguments constructor, a link-only constructor is provided.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class CitationsVo {
  private String title;
  private String link;

  // Citation known only by its link
  public CitationsVo(String link) {
    this.link = link;
  }
}
```
1.7 SearchResultVo
Holds the search results formatted as Markdown, together with the final answer text.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class SearchResultVo {
  private String markdown;
  private String answer;
}
```
1.8 WebPageMetadata
Holds web page metadata: the title and URL.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class WebPageMetadata {
  private String title;
  private String url;
}
```
1.9 WebPageSource
Holds a web page's content together with its metadata. Two extra constructors create an instance from a title and URL, with or without page content.
```java
package com.litongjava.perplexica.vo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@NoArgsConstructor
@AllArgsConstructor
@Data
public class WebPageSource {
  private String pageContent;
  private WebPageMetadata metadata;

  // Page with content
  public WebPageSource(String title, String url, String content) {
    this.pageContent = content;
    this.metadata = new WebPageMetadata(title, url);
  }

  // Link-only entry (e.g. "All Sources")
  public WebPageSource(String title, String url) {
    this.metadata = new WebPageMetadata(title, url);
  }
}
```
2. ChatWebSocketHandler
This class implements the WebSocket handler interface. It covers the handshake, the initial signal sent once the connection is established, text and binary messages, and connection close. After the handshake, it selects the search engine ID (`CSE_ID`) and source identifier (`FROM`) based on the request's origin. It then sends the client a `signal` message to confirm the connection.
```java
package com.litongjava.perplexica.handler;

import com.alibaba.fastjson2.JSONObject;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.perplexica.consts.WebSiteNames;
import com.litongjava.perplexica.services.WsChatService;
import com.litongjava.perplexica.vo.ChatSignalVo;
import com.litongjava.perplexica.vo.ChatWsReqMessageVo;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.http.common.HttpRequest;
import com.litongjava.tio.http.common.HttpResponse;
import com.litongjava.tio.http.common.RequestHeaderKey;
import com.litongjava.tio.utils.environment.EnvUtils;
import com.litongjava.tio.utils.json.FastJson2Utils;
import com.litongjava.tio.utils.json.JsonUtils;
import com.litongjava.tio.websocket.common.WebSocketRequest;
import com.litongjava.tio.websocket.common.WebSocketResponse;
import com.litongjava.tio.websocket.server.handler.IWebSocketHandler;

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class ChatWebSocketHandler implements IWebSocketHandler {
  public static final String CHARSET = "utf-8";

  /**
   * Handshake: the response is returned unchanged, so no extra checks are done here.
   */
  public HttpResponse handshake(HttpRequest httpRequest, HttpResponse response, ChannelContext channelContext) throws Exception {
    return response;
  }

  /**
   * Runs after a successful handshake: picks the search engine ID and site from the Origin,
   * stores them on the channel and sends a "signal" message to the client.
   */
  public void onAfterHandshaked(HttpRequest httpRequest, HttpResponse httpResponse, ChannelContext channelContext) throws Exception {
    String origin = httpRequest.getOrigin();
    String host = httpRequest.getHost();
    String cseId;
    String from;
    // Select the Google Custom Search JSON API engine ID for the requesting site
    if ("https://sjsu.mycounsellor.ai".equals(origin)) {
      cseId = EnvUtils.getStr("SJSU_CSE_ID");
      from = WebSiteNames.SJSU;
    } else if ("https://hawaii.mycounsellor.ai".equals(origin)) {
      cseId = EnvUtils.getStr("HAWAII_CSE_ID");
      from = WebSiteNames.HAWAII;
    } else if ("https://stanford.mycounsellor.ai".equals(origin)) {
      cseId = EnvUtils.getStr("STANFORD_CSE_ID");
      from = WebSiteNames.STANFORD; // assumes WebSiteNames defines a STANFORD constant
    } else if ("https://berkeley.mycounsellor.ai".equals(origin)) {
      cseId = EnvUtils.getStr("BERKELEY_CSE_ID");
      from = WebSiteNames.BERKELEY;
    } else {
      cseId = EnvUtils.getStr("CSE_ID");
      from = WebSiteNames.ALL;
    }
    channelContext.setAttribute("CSE_ID", cseId);
    channelContext.setAttribute("FROM", from);
    channelContext.setAttribute(RequestHeaderKey.Origin, origin);
    channelContext.setAttribute(RequestHeaderKey.Host, host);
    log.info("open:{},{},{}", channelContext.getClientIpAndPort(), from, cseId);

    // Tell the client the connection is ready
    String json = JsonUtils.toJson(new ChatSignalVo("signal", "open"));
    Tio.send(channelContext, WebSocketResponse.fromText(json, CHARSET));
  }

  /**
   * Connection close: remove the channel and release its resources.
   */
  public Object onClose(WebSocketRequest wsRequest, byte[] bytes, ChannelContext channelContext) throws Exception {
    Tio.remove(channelContext, "client closed the connection");
    return null;
  }

  /**
   * Binary frames are only logged (size in bytes).
   */
  public Object onBytes(WebSocketRequest wsRequest, byte[] bytes, ChannelContext channelContext) throws Exception {
    log.info("size:{}", bytes.length);
    return null;
  }

  /**
   * Text frames: "message" requests are forwarded to WsChatService.dispatch; other types are ignored.
   */
  public Object onText(WebSocketRequest wsRequest, String text, ChannelContext channelContext) throws Exception {
    JSONObject reqJsonObject = FastJson2Utils.parseObject(text);
    String type = reqJsonObject.getString("type");
    if ("message".equals(type)) {
      log.info("message:{}", text);
      ChatWsReqMessageVo vo = FastJson2Utils.parse(text, ChatWsReqMessageVo.class);
      try {
        Aop.get(WsChatService.class).dispatch(channelContext, vo);
      } catch (Exception e) {
        log.error(e.getMessage(), e);
        // Report the failure to the client as an "error" message
        ChatWsRespVo<String> error = ChatWsRespVo.error(e.getClass().toGenericString(), e.getMessage());
        Tio.bSend(channelContext, WebSocketResponse.fromJson(error));
      }
    }
    return null; // nothing to send back automatically
  }
}
```
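Notes
- In `onAfterHandshaked`, the request's `origin` determines the search engine ID (`CSE_ID`) and the site identifier (`FROM`). Both values, together with the origin and host, are stored as channel attributes. A `signal` message is then sent to the client to confirm the connection.
- In `onText`, frames whose `type` is `message` are parsed into a `ChatWsReqMessageVo` and passed to `WsChatService.dispatch`, which handles routing and the business logic. Any exception is returned to the client as an `error` message.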
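The `if/else` chain in `onAfterHandshaked` can also be written as a lookup table. A table makes a copy-paste slip, such as assigning the Hawaii site to the Stanford origin, easier to spot. The sketch below is hypothetical: `SiteConfig`, the placeholder site strings, and the `env` map stand in for `WebSiteNames` and `EnvUtils.getStr`, whose actual values are not shown in this document.

```java
import java.util.Map;

// Hypothetical replacement for the origin if/else chain in onAfterHandshaked.
final class SiteConfig {
  final String cseIdEnvKey; // environment variable holding the Google CSE id
  final String from;        // placeholder for a WebSiteNames constant (real values not shown)

  SiteConfig(String cseIdEnvKey, String from) {
    this.cseIdEnvKey = cseIdEnvKey;
    this.from = from;
  }

  private static final Map<String, SiteConfig> BY_ORIGIN = Map.of(
      "https://sjsu.mycounsellor.ai", new SiteConfig("SJSU_CSE_ID", "sjsu"),
      "https://hawaii.mycounsellor.ai", new SiteConfig("HAWAII_CSE_ID", "hawaii"),
      "https://stanford.mycounsellor.ai", new SiteConfig("STANFORD_CSE_ID", "stanford"),
      "https://berkeley.mycounsellor.ai", new SiteConfig("BERKELEY_CSE_ID", "berkeley"));

  private static final SiteConfig DEFAULT = new SiteConfig("CSE_ID", "all");

  // Unknown origins fall back to DEFAULT; null is checked first because Map.of maps reject null keys.
  static SiteConfig forOrigin(String origin) {
    if (origin == null) {
      return DEFAULT;
    }
    return BY_ORIGIN.getOrDefault(origin, DEFAULT);
  }

  // Resolves the CSE id from an environment-like map (EnvUtils.getStr in the handler).
  String cseId(Map<String, String> env) {
    return env.get(cseIdEnvKey);
  }
}
```

Adding a site then becomes a single map entry, and each origin's environment key and site name sit on the same line.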
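3. WsChatService
This service handles incoming WebSocket messages. It unpacks the request, queries the history, rewrites the question, saves messages, and calls the appropriate prediction or search service. The response is returned to the client over the WebSocket.
3.1 WsChatService (main part)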
```java
package com.litongjava.perplexica.services;

import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.Lock;

import com.google.common.util.concurrent.Striped;
import com.jfinal.kit.Kv;
import com.litongjava.db.activerecord.Db;
import com.litongjava.gemini.GoogleGeminiModels;
import com.litongjava.google.search.GoogleCustomSearchResponse;
import com.litongjava.google.search.SearchResultItem;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.model.web.WebPageContent;
import com.litongjava.openai.chat.ChatMessage;
import com.litongjava.openai.chat.OpenAiChatMessage;
import com.litongjava.openai.chat.OpenAiChatRequestVo;
import com.litongjava.openai.client.OpenAiClient;
import com.litongjava.openai.constants.PerplexityConstants;
import com.litongjava.openai.constants.PerplexityModels;
import com.litongjava.perplexica.callback.PerplexiticySeeCallback;
import com.litongjava.perplexica.callback.SearchGeminiSseCallback;
import com.litongjava.perplexica.can.ChatWsStreamCallCan;
import com.litongjava.perplexica.consts.FocusMode;
import com.litongjava.perplexica.consts.PerTableNames;
import com.litongjava.perplexica.model.PerplexicaChatMessage;
import com.litongjava.perplexica.model.PerplexicaChatSession;
import com.litongjava.perplexica.vo.ChatParamVo;
import com.litongjava.perplexica.vo.ChatReqMessage;
import com.litongjava.perplexica.vo.ChatWsReqMessageVo;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.perplexica.vo.CitationsVo;
import com.litongjava.perplexica.vo.WebPageSource;
import com.litongjava.template.PromptEngine;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.utils.environment.EnvUtils;
import com.litongjava.tio.utils.json.JsonUtils;
import com.litongjava.tio.utils.snowflake.SnowflakeIdUtils;
import com.litongjava.tio.utils.tag.TagUtils;
import com.litongjava.tio.utils.thread.TioThreadUtils;
import com.litongjava.tio.websocket.common.WebSocketResponse;

import lombok.extern.slf4j.Slf4j;
import okhttp3.Call;
import okhttp3.Callback;

@Slf4j
public class WsChatService {
  // 1024 lock stripes keyed by session ID, guarding session creation
  private static final Striped<Lock> sessionLocks = Striped.lock(1024);
  private GeminiPredictService geminiPredictService = Aop.get(GeminiPredictService.class);
  private AiSearchService aiSearchService = Aop.get(AiSearchService.class);

  /**
   * Handles a "message" request: ensures the session exists, loads history, optionally
   * rewrites the question, saves the user message and routes by focusMode.
   */
  public void dispatch(ChannelContext channelContext, ChatWsReqMessageVo reqMessageVo) {
    ChatReqMessage message = reqMessageVo.getMessage();
    Long userId = reqMessageVo.getUserId();
    Long sessionId = message.getChatId();
    Long messageQuestionId = message.getMessageId();
    String content = message.getContent();
    String focusMode = reqMessageVo.getFocusMode();
    ChatParamVo chatParamVo = new ChatParamVo();

    // Create the session on its first message. The title summary calls a model, so it runs
    // asynchronously; the check is repeated under the lock so it cannot be created twice.
    if (!Db.exists(PerTableNames.max_search_chat_session, "id", sessionId)) {
      TioThreadUtils.execute(() -> {
        Lock lock = sessionLocks.get(sessionId);
        lock.lock();
        try {
          if (!Db.exists(PerTableNames.max_search_chat_session, "id", sessionId)) {
            String summary = Aop.get(SummaryQuestionService.class).summary(content);
            new PerplexicaChatSession().setId(sessionId).setUserId(userId).setTitle(summary)
                .setFocusMode(focusMode).save();
          }
        } finally {
          lock.unlock();
        }
      });
    }

    // Load history; rewrite long questions and follow-ups into a standalone question
    List<ChatMessage> history = Aop.get(ChatMessgeService.class).getHistoryById(sessionId);
    chatParamVo.setHistory(history);
    if (content.length() > 30 || (history != null && !history.isEmpty())) {
      String rewrited = Aop.get(RewriteQuestionService.class).rewrite(content, history);
      chatParamVo.setRewrited(rewrited);
      if (channelContext != null) {
        Kv rewritedMsg = Kv.by("type", "rewrited").set("content", rewrited);
        Tio.bSend(channelContext, WebSocketResponse.fromJson(rewritedMsg));
      }
    }

    // Save the user's question
    new PerplexicaChatMessage().setId(messageQuestionId).setChatId(sessionId)
        .setRole("user").setContent(content).save();

    String from = channelContext != null ? channelContext.getString("FROM") : null;
    chatParamVo.setFrom(from);
    Boolean copilotEnabled = reqMessageVo.getCopilotEnabled();
    long answerMessageId = SnowflakeIdUtils.id();
    chatParamVo.setAnswerMessageId(answerMessageId);
    log.info("userId:{}, focusMode:{}", userId, focusMode);

    Call call = null;
    if (FocusMode.webSearch.equals(focusMode)) {
      call = aiSearchService.search(channelContext, reqMessageVo, chatParamVo);
    } else if (FocusMode.translator.equals(focusMode)) {
      // Arguments: question message ID, then answer message ID (assumed parameter order)
      String inputPrompt = Aop.get(TranslatorPromptService.class).genInputPrompt(channelContext, content,
          copilotEnabled, messageQuestionId, answerMessageId, from);
      chatParamVo.setInputPrompt(inputPrompt);
      call = geminiPredictService.predict(channelContext, reqMessageVo, chatParamVo);
    } else if (FocusMode.deepSeek.equals(focusMode)) {
      Aop.get(DeepSeekPredictService.class).predict(channelContext, reqMessageVo, sessionId, messageQuestionId,
          answerMessageId, content, null);
    } else if (FocusMode.mathAssistant.equals(focusMode)) {
      String inputPrompt = PromptEngine.renderToString("math_assistant_prompt.txt");
      Aop.get(DeepSeekPredictService.class).predict(channelContext, reqMessageVo, sessionId, messageQuestionId,
          answerMessageId, content, inputPrompt);
    } else if (FocusMode.writingAssistant.equals(focusMode)) {
      String inputPrompt = PromptEngine.renderToString("writing_assistant_prompt.txt");
      Aop.get(DeepSeekPredictService.class).predict(channelContext, reqMessageVo, sessionId, messageQuestionId,
          answerMessageId, content, inputPrompt);
    } else if (channelContext != null) {
      // Unsupported mode: an empty message marks "search finished, inference started",
      // followed by a placeholder answer and messageEnd
      Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "")));
      Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "Sorry Developing")));
      Kv end = Kv.by("type", "messageEnd").set("messageId", answerMessageId);
      Tio.bSend(channelContext, WebSocketResponse.fromJson(end));
    }

    // Register the streaming call under the session ID so it can be cancelled later
    if (call != null) {
      ChatWsStreamCallCan.put(sessionId.toString(), call);
    }
  }

  /**
   * Alternative search path using Google Custom Search: search, let Gemini select the
   * relevant results, fetch the pages, then stream the answer.
   */
  public Call google(ChannelContext channelContext, Long sessionId, Long messageId, String content) {
    String cseId = channelContext.getString("CSE_ID");
    long answerMessageId = SnowflakeIdUtils.id();
    // 1. Question rewriting (omitted)
    // 2. Search
    GoogleCustomSearchResponse search = Aop.get(GoogleCustomSearchService.class).search(cseId, content);
    List<SearchResultItem> items = search.getItems();
    List<WebPageContent> results = new ArrayList<>();
    if (items != null) { // items may be null when the search returns no results
      for (SearchResultItem item : items) {
        results.add(new WebPageContent(item.getTitle(), item.getLink(), item.getSnippet()));
      }
    }

    // 3. Selection: render the prompt and let Gemini pick the relevant results
    // ("quesiton" must match the variable name used in WebSearchSelectPrompt.txt)
    Kv kv = Kv.by("quesiton", content).set("search_result", JsonUtils.toJson(results));
    String prompt = PromptEngine.renderToString("WebSearchSelectPrompt.txt", kv);
    log.info("WebSearchSelectPrompt:{}", prompt);
    String selectResultContent = Aop.get(GeminiService.class).generate(prompt);
    List<String> outputs = TagUtils.extractOutput(selectResultContent);
    String titleAndLinks = (outputs == null || outputs.isEmpty()) ? "not_found" : outputs.get(0).trim();
    if ("not_found".equals(titleAndLinks)) {
      log.info("not found:{}", content);
      if (channelContext != null) {
        Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "")));
        Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "Sorry, not found")));
        Kv end = Kv.by("type", "messageEnd").set("messageId", answerMessageId);
        Tio.bSend(channelContext, WebSocketResponse.fromJson(end));
      }
      return null;
    }

    // 4. Turn "title~~link" lines into citations and send them to the client as sources
    List<CitationsVo> citationList = new ArrayList<>();
    for (String line : titleAndLinks.split("\n")) {
      String[] parts = line.split("~~");
      if (parts.length >= 2) { // skip malformed lines
        citationList.add(new CitationsVo(parts[0].trim(), parts[1].trim()));
      }
    }
    if (!citationList.isEmpty()) {
      List<WebPageSource> sources = Aop.get(WebpageSourceService.class).getListWithCitationsVo(citationList);
      ChatWsRespVo<List<WebPageSource>> chatRespVo = new ChatWsRespVo<>();
      chatRespVo.setType("sources").setData(sources).setMessageId(answerMessageId);
      if (channelContext != null) {
        Tio.bSend(channelContext, WebSocketResponse.fromJson(chatRespVo));
      }
    }

    // 5. Tell the client that search is done and inference starts, then fetch the pages
    if (channelContext != null) {
      Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "")));
    }
    StringBuffer pageContents = Aop.get(SpiderService.class).spiderAsync(channelContext, answerMessageId, citationList);

    // 6. Inference: render the response prompt and stream the answer from Gemini
    String isoTimeStr = DateTimeFormatter.ISO_INSTANT.format(Instant.now());
    kv = Kv.by("date", isoTimeStr).set("context", pageContents.toString());
    String webSearchResponsePrompt = PromptEngine.renderToString("WebSearchResponsePrompt.txt", kv);
    log.info("webSearchResponsePrompt:{}", webSearchResponsePrompt);
    List<OpenAiChatMessage> messages = new ArrayList<>();
    messages.add(new OpenAiChatMessage("assistant", webSearchResponsePrompt));
    messages.add(new OpenAiChatMessage(content));
    OpenAiChatRequestVo chatRequestVo = new OpenAiChatRequestVo().setModel(GoogleGeminiModels.GEMINI_2_0_FLASH_EXP)
        .setMessages(messages).setMax_tokens(3000);
    chatRequestVo.setStream(true);
    long start = System.currentTimeMillis();
    Callback callback = new SearchGeminiSseCallback(channelContext, sessionId, messageId, answerMessageId, start);
    return Aop.get(GeminiService.class).stream(chatRequestVo, callback);
  }

  // Currently unused: answer directly with a Perplexity online model
  @SuppressWarnings("unused")
  private Call ppl(ChannelContext channelContext, String sessionId, String messageId, List<OpenAiChatMessage> messages) {
    OpenAiChatRequestVo chatRequestVo = new OpenAiChatRequestVo().setModel(PerplexityModels.LLAMA_3_1_SONAR_LARGE_128K_ONLINE)
        .setMessages(messages).setMax_tokens(3000).setStream(true);
    log.info("chatRequestVo:{}", JsonUtils.toJson(chatRequestVo));
    String pplApiKey = EnvUtils.get("PERPLEXITY_API_KEY");
    long start = System.currentTimeMillis();
    Callback callback = new PerplexiticySeeCallback(channelContext, sessionId, messageId, start);
    return OpenAiClient.chatCompletions(PerplexityConstants.SERVER_URL, pplApiKey, chatRequestVo, callback);
  }
}
```
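The `google` method expects the selection prompt to return either `not_found` or one `title~~link` pair per line. Model output is not always that tidy, so a parser that skips blank or malformed lines is a reasonable safeguard. This is a sketch only. `Citation` and `CitationParser` are hypothetical stand-ins for `CitationsVo` and the inline loop, and the exact output format depends on `WebSearchSelectPrompt.txt`, which is not shown here.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for CitationsVo (title + link).
final class Citation {
  final String title;
  final String link;

  Citation(String title, String link) {
    this.title = title;
    this.link = link;
  }
}

final class CitationParser {
  private static final String SEPARATOR = "~~";

  // Parses "title~~link" lines; returns an empty list for "not_found" or null input.
  static List<Citation> parse(String output) {
    List<Citation> citations = new ArrayList<>();
    if (output == null || "not_found".equals(output.trim())) {
      return citations;
    }
    for (String rawLine : output.split("\\R")) { // \R matches \n, \r\n and other line breaks
      String line = rawLine.trim();
      int sep = line.lastIndexOf(SEPARATOR);
      // Skip lines with no separator, an empty title, or nothing after the separator.
      if (sep <= 0 || sep + SEPARATOR.length() >= line.length()) {
        continue;
      }
      String title = line.substring(0, sep).trim();
      String link = line.substring(sep + SEPARATOR.length()).trim();
      citations.add(new Citation(title, link));
    }
    return citations;
  }
}
```

Splitting on the last `~~` keeps a title that happens to contain `~~` intact, on the assumption that links do not contain that sequence.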
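Notes
- `dispatch` first creates the session (with a model-generated title) if it does not exist yet. It then loads the history, rewrites the question when it is longer than 30 characters or there is prior history, and saves the user message.
- It then routes on `focusMode`. `webSearch` goes to `AiSearchService.search` and `translator` to `GeminiPredictService.predict`. `deepSeek`, `mathAssistant`, and `writingAssistant` go to `DeepSeekPredictService`, the last two with their own prompt templates. Any other mode receives a placeholder reply. A returned streaming `Call` is registered in `ChatWsStreamCallCan` under the session ID.
- The `google` method is an alternative search path. It queries Google Custom Search, lets Gemini select the relevant results, and returns them to the client as citations (`sources`). It then fetches the pages and streams the answer.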
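Session creation relies on a "check, lock, re-check" pattern: a session ID hashes onto one of a fixed pool of locks, which is what Guava's `Striped.lock(1024)` provides, and existence is checked again while the lock is held. The point is that two concurrent first messages must not both create the session. Below is a minimal JDK-only sketch of that pattern. `SessionCreator`, its `ConcurrentHashMap` "table", and the `saves` counter are hypothetical stand-ins for `Db.exists` and `PerplexicaChatSession.save()`. These locks are in-process only, so several server instances would still need a unique constraint in the database.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// JDK-only analogue of Striped<Lock>: a fixed array of locks selected by key hash.
final class SessionCreator {
  private final Lock[] stripes;
  // Stand-in for the max_search_chat_session table (id -> title).
  final Map<Long, String> sessions = new ConcurrentHashMap<>();
  // Counts how many times the "save" actually ran.
  final AtomicInteger saves = new AtomicInteger();

  SessionCreator(int stripeCount) {
    stripes = new Lock[stripeCount];
    for (int i = 0; i < stripeCount; i++) {
      stripes[i] = new ReentrantLock();
    }
  }

  // Same session ID always maps to the same lock; different IDs may share one.
  private Lock lockFor(long sessionId) {
    return stripes[Math.floorMod(Long.hashCode(sessionId), stripes.length)];
  }

  // Check, lock, re-check: only the first caller for a given ID performs the save.
  void createIfAbsent(long sessionId, String title) {
    if (sessions.containsKey(sessionId)) {
      return; // fast path without locking
    }
    Lock lock = lockFor(sessionId);
    lock.lock();
    try {
      if (!sessions.containsKey(sessionId)) {
        saves.incrementAndGet();
        sessions.put(sessionId, title);
      }
    } finally {
      lock.unlock();
    }
  }
}
```

In `dispatch`, the whole sequence would sit inside the asynchronous task, so the request thread is never blocked by the title-summary model call.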
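4. AiSearchService
This service handles the search logic. It calls the SearxNG search API, ranks and filters the results, and returns them to the client. It also builds the prompt used for the subsequent inference.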
```java
package com.litongjava.perplexica.services;

import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;

import org.postgresql.util.PGobject;

import com.jfinal.kit.Kv;
import com.litongjava.db.activerecord.Db;
import com.litongjava.jfinal.aop.Aop;
import com.litongjava.kit.PgObjectUtils;
import com.litongjava.model.web.WebPageContent;
import com.litongjava.perplexica.consts.OptimizationMode;
import com.litongjava.perplexica.vo.ChatParamVo;
import com.litongjava.perplexica.vo.ChatWsReqMessageVo;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.perplexica.vo.WebPageSource;
import com.litongjava.searxng.SearxngResult;
import com.litongjava.searxng.SearxngSearchClient;
import com.litongjava.searxng.SearxngSearchParam;
import com.litongjava.searxng.SearxngSearchResponse;
import com.litongjava.template.PromptEngine;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.http.common.RequestHeaderKey;
import com.litongjava.tio.websocket.common.WebSocketResponse;

import okhttp3.Call;

public class AiSearchService {
  GeminiPredictService geminiPredictService = Aop.get(GeminiPredictService.class);
  public boolean spped = true;

  /**
   * Web search mode: with copilot enabled, search with SearxNG, rank or enrich the results
   * by optimizationMode, send them as sources and build the prompt; then call Gemini.
   */
  public Call search(ChannelContext channelContext, ChatWsReqMessageVo reqMessageVo, ChatParamVo chatParamVo) {
    String optimizationMode = reqMessageVo.getOptimizationMode();
    Boolean copilotEnabled = reqMessageVo.getCopilotEnabled();
    String content = reqMessageVo.getMessage().getContent();
    Long questionMessageId = reqMessageVo.getMessage().getMessageId();
    long answerMessageId = chatParamVo.getAnswerMessageId();
    String inputPrompt = null;
    if (Boolean.TRUE.equals(copilotEnabled)) {
      // Prefer the rewritten question when there is one
      String question = chatParamVo.getRewrited() != null ? chatParamVo.getRewrited() : content;

      // 1. Search with SearxNG
      SearxngSearchParam searxngSearchParam = new SearxngSearchParam();
      searxngSearchParam.setFormat("json");
      searxngSearchParam.setQ(question);
      SearxngSearchResponse searchResponse = SearxngSearchClient.search(searxngSearchParam);
      List<SearxngResult> results = searchResponse.getResults();
      List<WebPageContent> webPageContents = new ArrayList<>();
      if (results != null) {
        for (SearxngResult searxngResult : results) {
          WebPageContent webPageContent = new WebPageContent(searxngResult.getTitle(), searxngResult.getUrl());
          webPageContent.setContent(searxngResult.getContent());
          webPageContents.add(webPageContent);
        }
      }

      // 2. Rank / enrich depending on the optimization mode
      if (OptimizationMode.balanced.equals(optimizationMode) && !webPageContents.isEmpty()) {
        // Pick the best-matching result, fetch its full page and put it first
        List<WebPageContent> ranked = Aop.get(VectorRankerService.class).filter(webPageContents, question, 1);
        List<WebPageContent> fetched = Aop.get(JinaReaderService.class).spider(ranked);
        if (fetched != null && !fetched.isEmpty()) {
          webPageContents.set(0, fetched.get(0));
        }
      } else if (OptimizationMode.quality.equals(optimizationMode)) {
        // Keep the 6 most relevant results (AI ranking) and read their pages via the Jina Reader API
        webPageContents = Aop.get(AiRankerService.class).filter(webPageContents, question, 6);
        webPageContents = Aop.get(JinaReaderService.class).spiderAsync(webPageContents);
      }
      chatParamVo.setSources(webPageContents);

      // Save the sources on the user's question message
      PGobject pgObject = PgObjectUtils.json(webPageContents);
      Db.updateBySql("update max_search_chat_message set sources=? where id=?", pgObject, questionMessageId);

      if (channelContext != null) {
        List<WebPageSource> sources = new ArrayList<>();
        for (WebPageContent webPageContent : webPageContents) {
          sources.add(new WebPageSource(webPageContent.getTitle(), webPageContent.getUrl(), webPageContent.getContent()));
        }
        // Extra "All Sources" entry linking to /sources/{questionMessageId} (protocol-relative URL)
        String host = channelContext.getString(RequestHeaderKey.Host);
        host = host == null ? "//127.0.0.1" : "//" + host;
        sources.add(new WebPageSource("All Sources", host + "/sources/" + questionMessageId));
        // Send the sources to the client
        ChatWsRespVo<List<WebPageSource>> chatRespVo = new ChatWsRespVo<>();
        chatRespVo.setType("sources").setData(sources).setMessageId(answerMessageId);
        Tio.bSend(channelContext, WebSocketResponse.fromJson(chatRespVo));
      }

      // 3. Build the context: numbered sources separated by blank lines
      StringBuilder markdown = new StringBuilder();
      for (int i = 0; i < webPageContents.size(); i++) {
        String pageContent = webPageContents.get(i).getContent();
        markdown.append("source ").append(i + 1).append(" ").append(pageContent == null ? "" : pageContent).append("\n\n");
      }

      // 4. Render the prompt with the PromptEngine template
      String isoTimeStr = DateTimeFormatter.ISO_INSTANT.format(Instant.now());
      Kv kv = Kv.by("date", isoTimeStr).set("context", markdown.toString());
      inputPrompt = PromptEngine.renderToString("WebSearchResponsePrompt.txt", kv);
    }
    chatParamVo.setInputPrompt(inputPrompt);
    return geminiPredictService.predict(channelContext, reqMessageVo, chatParamVo);
  }
}
```
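Notes
- When copilot is enabled, the service searches with SearxNG and then ranks or filters the results according to `optimizationMode`. In `balanced` mode, vector ranking picks the top result and its full page is fetched. In `quality` mode, AI ranking keeps the top 6 results and their pages are read with Jina Reader. The results are saved and sent to the client over the WebSocket.
- The service then renders the inference prompt with the template engine and calls `GeminiPredictService` to generate the answer. When copilot is disabled, no search runs and the question is sent to Gemini without a search prompt.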
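The prompt context is a numbered list of source contents whose numbers follow the order of the `sources` frame sent to the client. The sketch below builds that list with an explicit blank line between entries. It is hypothetical in two ways: `Source` stands in for `WebPageContent`, and falling back to the page title when no content was extracted is a design choice of this sketch, not the service's behavior.

```java
import java.util.List;

// Hypothetical stand-in for WebPageContent (title, url, content).
final class Source {
  final String title;
  final String url;
  final String content;

  Source(String title, String url, String content) {
    this.title = title;
    this.url = url;
    this.content = content;
  }
}

final class ContextBuilder {
  // Numbers sources from 1 so they line up with the order of the "sources" frame.
  static String build(List<Source> sources) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < sources.size(); i++) {
      Source s = sources.get(i);
      // Fall back to the title when a page has no extracted content.
      String body = (s.content == null || s.content.isBlank()) ? s.title : s.content.trim();
      if (sb.length() > 0) {
        sb.append("\n\n"); // blank line between sources
      }
      sb.append("source ").append(i + 1).append(": ").append(body);
    }
    return sb.toString();
  }
}
```

The "All Sources" link is appended to the end of the client-side list. The first N numbers therefore still match between the prompt and the UI.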
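5. GeminiPredictService
This service calls the Gemini model to generate the answer. It builds the request from the conversation history, the prompt, and the user's question. The request then goes to the Gemini client for either a streamed or a one-shot response.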
```java
package com.litongjava.perplexica.services;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import com.litongjava.gemini.GeminiChatRequestVo;
import com.litongjava.gemini.GeminiChatResponseVo;
import com.litongjava.gemini.GeminiClient;
import com.litongjava.gemini.GeminiContentVo;
import com.litongjava.gemini.GeminiPartVo;
import com.litongjava.gemini.GeminiSystemInstructionVo;
import com.litongjava.gemini.GoogleGeminiModels;
import com.litongjava.openai.chat.ChatMessage;
import com.litongjava.perplexica.callback.SearchGeminiSseCallback;
import com.litongjava.perplexica.consts.FocusMode;
import com.litongjava.perplexica.vo.ChatParamVo;
import com.litongjava.perplexica.vo.ChatWsReqMessageVo;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.utils.json.JsonUtils;
import com.litongjava.tio.websocket.common.WebSocketResponse;

import lombok.extern.slf4j.Slf4j;
import okhttp3.Call;
import okhttp3.Callback;

@Slf4j
public class GeminiPredictService {

  public Call predict(ChannelContext channelContext, ChatWsReqMessageVo reqMessageVo, ChatParamVo chatParamVo) {
    Long sessionId = reqMessageVo.getMessage().getChatId();
    Long questionMessageId = reqMessageVo.getMessage().getMessageId();
    String content = reqMessageVo.getMessage().getContent();
    Long answerMessageId = chatParamVo.getAnswerMessageId();
    String inputPrompt = chatParamVo.getInputPrompt();
    List<GeminiContentVo> contents = new ArrayList<>();

    // 1. Map history to Gemini roles: "human"/"user" -> "user", everything else -> "model"
    List<ChatMessage> history = chatParamVo.getHistory();
    if (history != null) {
      for (ChatMessage chatMessage : history) {
        String role = chatMessage.getRole();
        String geminiRole = ("human".equals(role) || "user".equals(role)) ? "user" : "model";
        contents.add(new GeminiContentVo(geminiRole, chatMessage.getContent()));
      }
    }
    // The request keeps a reference to contents, so later additions are included
    GeminiChatRequestVo reqVo = new GeminiChatRequestVo(contents);

    // 2. Add the prompt and the question according to focusMode
    String focusMode = reqMessageVo.getFocusMode();
    if (FocusMode.webSearch.equals(focusMode)) {
      // The search prompt is sent as a "model" turn
      if (inputPrompt != null) {
        GeminiPartVo part = new GeminiPartVo(inputPrompt);
        contents.add(new GeminiContentVo("model", Collections.singletonList(part)));
      }
      // Then the question as a "user" turn, asking for a reply in the question's language
      contents.add(new GeminiContentVo("user", content + ". You must reply in the same language as this message."));
    } else if (FocusMode.translator.equals(focusMode)) {
      // The translator prompt goes into system_instruction
      GeminiSystemInstructionVo systemInstruction = new GeminiSystemInstructionVo();
      systemInstruction.setParts(new GeminiPartVo(inputPrompt));
      reqVo.setSystem_instruction(systemInstruction);
      contents.add(new GeminiContentVo("user", content));
      log.info("json:{}", JsonUtils.toSkipNullJson(reqVo));
    }

    // 3. Send an empty message to signal that search is done and inference starts
    if (channelContext != null) {
      Tio.bSend(channelContext, WebSocketResponse.fromJson(ChatWsRespVo.message(answerMessageId, "")));
    }

    // 4. Stream over the WebSocket, or generate in one shot when there is no channel
    long start = System.currentTimeMillis();
    Call call = null;
    if (channelContext != null) {
      Callback callback = new SearchGeminiSseCallback(channelContext, sessionId, questionMessageId, answerMessageId, start);
      call = GeminiClient.stream(GoogleGeminiModels.GEMINI_2_0_FLASH_EXP, reqVo, callback);
    } else {
      GeminiChatResponseVo vo = GeminiClient.generate(GoogleGeminiModels.GEMINI_2_0_FLASH_EXP, reqVo);
      log.info(vo.getCandidates().get(0).getContent().getParts().get(0).getText());
    }
    return call;
  }
}
```
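Notes
- The conversation history is converted into Gemini contents with the roles `user` and `model`.
- The prompt is added according to `focusMode`. For `webSearch` it is inserted as a `model` turn before the question; for `translator` it is set as the system instruction.
- Before the request is sent, an empty message is pushed over the WebSocket to tell the client that search has finished and inference has started.
- Finally, `GeminiClient.stream` issues a streaming request, and the supplied callback handles the response data.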
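Gemini content turns use the roles `user` and `model`, so stored history has to be mapped. The sketch below treats both `human`, the label checked in `predict`, and `user`, the role `WsChatService` writes when it saves a question, as `user`. Every other role becomes `model`. As an extra precaution it also merges consecutive turns that end up with the same role. `Turn` and `GeminiHistory` are hypothetical stand-ins for `ChatMessage` / `GeminiContentVo`, and the merging step is not part of the service above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for ChatMessage / GeminiContentVo (role + single text part).
final class Turn {
  final String role;
  final String text;

  Turn(String role, String text) {
    this.role = role;
    this.text = text;
  }
}

final class GeminiHistory {
  // Maps stored roles to Gemini roles: "human"/"user" -> "user", anything else -> "model".
  static String toGeminiRole(String role) {
    return ("human".equals(role) || "user".equals(role)) ? "user" : "model";
  }

  // Converts history and merges consecutive turns that end up with the same Gemini role.
  static List<Turn> toContents(List<Turn> history) {
    List<Turn> contents = new ArrayList<>();
    if (history == null) {
      return contents;
    }
    for (Turn t : history) {
      String role = toGeminiRole(t.role);
      int last = contents.size() - 1;
      if (last >= 0 && contents.get(last).role.equals(role)) {
        Turn prev = contents.get(last);
        contents.set(last, new Turn(role, prev.text + "\n\n" + t.text));
      } else {
        contents.add(new Turn(role, t.text));
      }
    }
    return contents;
  }
}
```

Merging keeps the turns strictly alternating. That matters here because the `webSearch` path can append its prompt as a `model` turn right after a `model` turn from history.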
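6. Callback Classes
The callback classes handle asynchronous responses from third-party model APIs (such as Gemini and DeepSeek) and forward the streamed data to the client in real time.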
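The callbacks read an OpenAI-compatible Server-Sent Events stream line by line (see `DeepSeekSseCallback.onResponseSuccess` in 6.1). In the SSE format, a line starting with `:` is a comment, and this is how `: keep-alive` heartbeats arrive. A `data:` line carries the payload, with one optional space after the colon. OpenAI-style APIs end the stream with `data: [DONE]`. The sketch below classifies lines by matching the `data:` prefix explicitly rather than assuming a fixed six-character prefix. JSON parsing is left out (the callback uses FastJson2), and `SseLine` with its `Kind` values is a hypothetical name for this illustration.

```java
// Classifies one line of an OpenAI-compatible SSE stream (JSON parsing omitted).
final class SseLine {
  enum Kind { EMPTY, COMMENT, DATA, DONE, OTHER }

  final Kind kind;
  final String payload; // JSON text for DATA, comment text for COMMENT, otherwise null

  private SseLine(Kind kind, String payload) {
    this.kind = kind;
    this.payload = payload;
  }

  static SseLine parse(String line) {
    if (line == null || line.isEmpty()) {
      return new SseLine(Kind.EMPTY, null); // blank line = event boundary
    }
    if (line.startsWith(":")) {
      // Comment line such as ": keep-alive", used to keep the connection open.
      return new SseLine(Kind.COMMENT, line.substring(1).trim());
    }
    if (line.startsWith("data:")) {
      String value = line.substring("data:".length());
      // A single space after the colon is not part of the value.
      if (value.startsWith(" ")) {
        value = value.substring(1);
      }
      if ("[DONE]".equals(value)) {
        return new SseLine(Kind.DONE, null);
      }
      return new SseLine(Kind.DATA, value);
    }
    return new SseLine(Kind.OTHER, null); // e.g. "event:" or "id:" fields, ignored here
  }
}
```

A caller would pass `DATA` payloads to the JSON parser, send a `keep-alive` frame for `COMMENT` lines, and stop reading on `DONE`.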
6.1 DeepSeekSseCallback
This callback handles the SSE (Server-Sent Events) response of the DeepSeek model.
package com.litongjava.perplexica.callback;
import java.io.IOException;
import java.util.List;
import com.jfinal.kit.Kv;
import com.litongjava.openai.chat.ChatResponseDelta;
import com.litongjava.openai.chat.Choice;
import com.litongjava.openai.chat.OpenAiChatResponseVo;
import com.litongjava.perplexica.can.ChatWsStreamCallCan;
import com.litongjava.perplexica.model.PerplexicaChatMessage;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.http.server.util.SseEmitter;
import com.litongjava.tio.utils.json.FastJson2Utils;
import com.litongjava.tio.websocket.common.WebSocketResponse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.BufferedSource;
@Slf4j
public class DeepSeekSseCallback implements Callback {
private ChannelContext channelContext;
private Long sessionId;
private Long messageId;
private Long answerMessageId;
private Long start;
public DeepSeekSseCallback(ChannelContext channelContext, Long sessionId, Long messageId, Long answerMessageId, Long start) {
this.channelContext = channelContext;
this.sessionId = sessionId;
this.messageId = messageId;
this.answerMessageId = answerMessageId;
this.start=start;
}
@Override
public void onFailure(Call call, IOException e) {
ChatWsRespVo<String> error = ChatWsRespVo.error("CHAT_ERROR", e.getMessage());
WebSocketResponse packet = WebSocketResponse.fromJson(error);
Tio.bSend(channelContext, packet);
ChatWsStreamCallCan.remove(sessionId + "");
SseEmitter.closeSeeConnection(channelContext);
}
@Override
public void onResponse(Call call, Response response) throws IOException {
try {
if (!response.isSuccessful()) {
ResponseBody errorBody = response.body();
String string = errorBody != null ? errorBody.string() : "";
String message = "Chat model returned an unsuccessful response: " + string;
log.error("message:{}", message);
ChatWsRespVo<String> data = ChatWsRespVo.error("STREAM_ERROR", message);
WebSocketResponse webSocketResponse = WebSocketResponse.fromJson(data);
Tio.bSend(channelContext, webSocketResponse);
return;
}
try (ResponseBody responseBody = response.body()) {
if (responseBody == null) {
String message = "response body is null";
log.error(message);
ChatWsRespVo<String> data = ChatWsRespVo.error("STREAM_ERROR", message);
WebSocketResponse webSocketResponse = WebSocketResponse.fromJson(data);
Tio.bSend(channelContext, webSocketResponse);
return;
}
StringBuffer completionContent = onResponseSuccess(channelContext, answerMessageId, start, responseBody);
// Save the generated answer message
new PerplexicaChatMessage().setId(answerMessageId).setChatId(sessionId)
.setRole("assistant").setContent(completionContent.toString())
.save();
// Tell the client the answer is complete
Kv end = Kv.by("type", "messageEnd").set("messageId", answerMessageId);
Tio.bSend(channelContext, WebSocketResponse.fromJson(end));
long endTime = System.currentTimeMillis();
log.info("finish llm in {} (ms)", (endTime - start));
}
} finally {
// Unregister the call on every path, including the early error returns
ChatWsStreamCallCan.remove(sessionId + "");
}
}
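/**
* Handles a successful streaming response made of OpenAI-compatible chat completion chunks
*
* @param channelContext channel context of the client WebSocket connection
* @param answerMessageId ID of the answer message being generated
* @param start request start time in milliseconds (not used by this method)
* @param responseBody response body
* @return the complete generated answer content
* @throws IOException if reading the stream fails
*/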
public StringBuffer onResponseSuccess(ChannelContext channelContext, Long answerMessageId, Long start, ResponseBody responseBody) throws IOException {
StringBuffer completionContent = new StringBuffer();
BufferedSource source = responseBody.source();
String line;
while ((line = source.readUtf8Line()) != null) {
if (line.length() < 1) {
continue;
}
// Process data lines: strip the 6-character "data: " prefix
if (line.length() > 6) {
String data = line.substring(6);
if (data.endsWith("}")) {
OpenAiChatResponseVo chatResponse = FastJson2Utils.parse(data, OpenAiChatResponseVo.class);
List<Choice> choices = chatResponse.getChoices();
if (!choices.isEmpty()) {
ChatResponseDelta delta = choices.get(0).getDelta();
String part = delta.getContent();
if (part != null && !part.isEmpty()) {
completionContent.append(part);
ChatWsRespVo<String> vo = ChatWsRespVo.message(answerMessageId, part);
Tio.bSend(channelContext, WebSocketResponse.fromJson(vo));
}
String reasoning_content = delta.getReasoning_content();
if (reasoning_content != null && !reasoning_content.isEmpty()) {
ChatWsRespVo<String> vo = ChatWsRespVo.message(answerMessageId, reasoning_content);
Tio.bSend(channelContext, WebSocketResponse.fromJson(vo));
}
}
} else if (": keep-alive".equals(line)) {
ChatWsRespVo<String> vo = ChatWsRespVo.keepAlive(answerMessageId);
WebSocketResponse websocketResponse = WebSocketResponse.fromJson(vo);
if (channelContext != null) {
Tio.bSend(channelContext, websocketResponse);
}
} else {
// Data lines that do not end with } (e.g. the terminating "data: [DONE]") are only logged
log.info("Data does not end with }:{}", line);
}
}
}
return completionContent;
}
}
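Notes
- In onResponse, this callback reads the response stream line by line (via onResponseSuccess), parses out the generated text (both the answer content and the reasoning content), and sends each fragment to the client over WebSocket as soon as it is parsed.
- When the stream ends, the complete answer is saved to the database and the client is notified that the message has ended (messageEnd).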
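The callback above strips the first six characters ("data: ") of each line, treats any payload ending in } as a JSON chunk, and recognizes heartbeats only by an exact match on ": keep-alive". The sketch below expresses the same line handling more explicitly, following the Server-Sent Events rules that a line starting with : is a comment and that a single optional space after data: is dropped. SseLineKind and SseLines are hypothetical names, not part of the project; the sketch only classifies lines and leaves JSON parsing to FastJson2Utils as in the original.

```java
import java.util.List;

// Hypothetical line categories for an OpenAI-compatible SSE stream.
enum SseLineKind { BLANK, DATA, DONE, COMMENT, OTHER }

// Minimal, dependency-free sketch; not the code used by DeepSeekSseCallback.
final class SseLines {
  private SseLines() {}

  /** Classifies one raw line read from the stream. */
  static SseLineKind classify(String line) {
    if (line == null || line.isEmpty()) {
      return SseLineKind.BLANK;            // event separator
    }
    if (line.startsWith(":")) {
      return SseLineKind.COMMENT;          // e.g. ": keep-alive"
    }
    if (line.startsWith("data:")) {
      String payload = payload(line);
      return "[DONE]".equals(payload) ? SseLineKind.DONE : SseLineKind.DATA;
    }
    return SseLineKind.OTHER;              // "event:", "id:", "retry:" ...
  }

  /** Returns the text after "data:", dropping one optional leading space. */
  static String payload(String line) {
    String rest = line.substring("data:".length());
    return rest.startsWith(" ") ? rest.substring(1) : rest;
  }

  public static void main(String[] args) {
    List<String> lines = List.of(
        ": keep-alive",
        "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}",
        "",
        "data: [DONE]");
    for (String line : lines) {
      System.out.println(classify(line) + " -> " + line);
    }
  }
}
```

Matching comments by prefix rather than by exact text means any heartbeat comment the provider sends is recognized, not only ": keep-alive".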
6.2 SearchGeminiSseCallback
This callback handles SSE responses from the Gemini model. It works like DeepSeekSseCallback, but parses the data format that Gemini returns.
package com.litongjava.perplexica.callback;
import java.io.IOException;
import java.util.List;
import com.jfinal.kit.Kv;
import com.litongjava.gemini.GeminiCandidateVo;
import com.litongjava.gemini.GeminiChatResponseVo;
import com.litongjava.gemini.GeminiContentResponseVo;
import com.litongjava.gemini.GeminiPartVo;
import com.litongjava.perplexica.can.ChatWsStreamCallCan;
import com.litongjava.perplexica.model.PerplexicaChatMessage;
import com.litongjava.perplexica.vo.ChatWsRespVo;
import com.litongjava.tio.core.ChannelContext;
import com.litongjava.tio.core.Tio;
import com.litongjava.tio.http.server.util.SseEmitter;
import com.litongjava.tio.utils.json.FastJson2Utils;
import com.litongjava.tio.websocket.common.WebSocketResponse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.BufferedSource;
@Slf4j
public class SearchGeminiSseCallback implements Callback {
private ChannelContext channelContext;
private Long chatId;
private Long questionMessageId;
private Long answerMessageId;
private long start;
public SearchGeminiSseCallback(ChannelContext channelContext, Long sessionId, Long messageId, Long answerMessageId, long start) {
this.channelContext = channelContext;
this.chatId = sessionId;
this.questionMessageId = messageId;
this.answerMessageId = answerMessageId;
this.start = start;
}
@Override
public void onFailure(Call call, IOException e) {
ChatWsRespVo<String> error = ChatWsRespVo.error("CHAT_ERROR", e.getMessage());
WebSocketResponse packet = WebSocketResponse.fromJson(error);
Tio.bSend(channelContext, packet);
ChatWsStreamCallCan.remove(chatId.toString());
SseEmitter.closeSeeConnection(channelContext);
}
@Override
public void onResponse(Call call, Response response) throws IOException {
try {
if (!response.isSuccessful()) {
String message = "Chat model returned an unsuccessful response: " + response.body().string();
log.error("message:{}", message);
ChatWsRespVo<String> data = ChatWsRespVo.error("STREAM_ERROR", message);
WebSocketResponse webSocketResponse = WebSocketResponse.fromJson(data);
Tio.bSend(channelContext, webSocketResponse);
return;
}
try (ResponseBody responseBody = response.body()) {
if (responseBody == null) {
String message = "response body is null";
log.error(message);
// Report a missing body as an error rather than as a progress update
ChatWsRespVo<String> data = ChatWsRespVo.error("STREAM_ERROR", message);
WebSocketResponse webSocketResponse = WebSocketResponse.fromJson(data);
Tio.bSend(channelContext, webSocketResponse);
return;
}
StringBuffer completionContent = onResponse(channelContext, answerMessageId, start, responseBody);
// Save the generated answer message to the database
new PerplexicaChatMessage().setId(answerMessageId).setChatId(chatId)
.setRole("assistant").setContent(completionContent.toString())
.save();
// Tell the client that this answer is complete
Kv end = Kv.by("type", "messageEnd").set("messageId", answerMessageId);
Tio.bSend(channelContext, WebSocketResponse.fromJson(end));
long endTime = System.currentTimeMillis();
log.info("finish llm in {} (ms)", (endTime - start));
if (completionContent.length() > 0) {
// Post-processing of the completed answer can be done here
}
}
} finally {
// Always unregister the in-flight call, including on the early-return error paths
ChatWsStreamCallCan.remove(chatId.toString());
}
}
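/**
* Handles a successful streaming response from the Gemini API
*
* @param channelContext channel context of the client WebSocket connection
* @param answerMessageId ID of the answer message being generated
* @param start request start time in milliseconds (not used by this method)
* @param responseBody response body
* @return the complete generated answer content
* @throws IOException if reading the stream fails
*/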
public StringBuffer onResponse(ChannelContext channelContext, Long answerMessageId, Long start, ResponseBody responseBody) throws IOException {
StringBuffer completionContent = new StringBuffer();
BufferedSource source = responseBody.source();
String line;
while ((line = source.readUtf8Line()) != null) {
if (line.length() < 1) {
continue;
}
// Process data lines: strip the 6-character "data: " prefix
if (line.length() > 6) {
String data = line.substring(6);
if (data.endsWith("}")) {
GeminiChatResponseVo chatResponse = FastJson2Utils.parse(data, GeminiChatResponseVo.class);
List<GeminiCandidateVo> candidates = chatResponse.getCandidates();
if (candidates != null && !candidates.isEmpty()) {
GeminiContentResponseVo content = candidates.get(0).getContent();
// Guard against a candidate that arrives without content or parts
if (content == null || content.getParts() == null) {
continue;
}
// Forward every text part, not only the first one
for (GeminiPartVo geminiPartVo : content.getParts()) {
String text = geminiPartVo.getText();
if (text != null && !text.isEmpty()) {
completionContent.append(text);
ChatWsRespVo<String> vo = ChatWsRespVo.message(answerMessageId, text);
Tio.bSend(channelContext, WebSocketResponse.fromJson(vo));
}
}
}
} else {
// Data lines that do not end with } are only logged
log.info("Data does not end with }:{}", line);
}
}
}
return completionContent;
}
}
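Notes
- SearchGeminiSseCallback parses the SSE data returned by the Gemini model: it reads the stream line by line, parses the JSON payload of each data line, and extracts the generated text fragments.
- Each fragment is sent to the client over WebSocket as soon as it is parsed, so the answer appears in real time.
- Finally, the complete answer is saved to the database and the client is notified that the message has ended.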
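The two callbacks differ mainly in where the text sits in each chunk. DeepSeekSseCallback reads choices[0].delta.content, while SearchGeminiSseCallback reads the candidates[0].content.parts path. The sketch below isolates the Gemini extraction step and makes it null-safe, concatenating every part of the first candidate. The record types are hypothetical, simplified stand-ins for GeminiChatResponseVo, GeminiCandidateVo, GeminiContentResponseVo and GeminiPartVo, which the project fills in with FastJson2Utils.parse. Records require Java 16 or later.

```java
import java.util.List;

// Hypothetical, simplified mirrors of GeminiChatResponseVo, GeminiCandidateVo,
// GeminiContentResponseVo and GeminiPartVo (only the fields used here).
record GeminiPart(String text) {}
record GeminiContent(List<GeminiPart> parts) {}
record GeminiCandidate(GeminiContent content) {}
record GeminiChunk(List<GeminiCandidate> candidates) {}

final class GeminiText {
  private GeminiText() {}

  /**
   * Concatenates the text of every part of the first candidate.
   * Returns "" instead of throwing when any level of the structure is missing.
   */
  static String firstCandidateText(GeminiChunk chunk) {
    if (chunk == null || chunk.candidates() == null || chunk.candidates().isEmpty()) {
      return "";
    }
    GeminiCandidate first = chunk.candidates().get(0);
    if (first == null || first.content() == null || first.content().parts() == null) {
      return "";
    }
    StringBuilder text = new StringBuilder();
    for (GeminiPart part : first.content().parts()) {
      if (part != null && part.text() != null) {
        text.append(part.text());
      }
    }
    return text.toString();
  }

  public static void main(String[] args) {
    GeminiChunk chunk = new GeminiChunk(List.of(
        new GeminiCandidate(new GeminiContent(List.of(new GeminiPart("Hello, "), new GeminiPart("world"))))));
    System.out.println(firstCandidateText(chunk)); // Hello, world
  }
}
```

Returning an empty string for a malformed chunk lets the read loop simply skip it, instead of aborting the whole stream with a NullPointerException or IndexOutOfBoundsException.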
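Summary
This document has walked through the complete implementation of the ai-search code. Built on WebSocket, it provides the following main components:
- Entity classes: define the data structures for messages, responses, history records, and search results.
- WebSocket handler: manages client connections, message sending, and connection closing.
- Message dispatch service (WsChatService): implements message handling, history management, question rewriting, search invocation, and answer generation.
- Search service (AiSearchService): calls third-party search engines to obtain search results and builds the prompt.
- Prediction service (GeminiPredictService): calls the Gemini model to generate the answer and streams it back.
- Callback handling: the two callback classes DeepSeekSseCallback and SearchGeminiSseCallback parse third-party streaming responses in real time and notify the client.
Working together, these modules let the system take a user request, rewrite the question, run the search, filter the results, build the prompt, and generate the final answer, streaming the answer back to the client as it is produced.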
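From the client's side, the streaming protocol amounts to a series of answer fragments tagged with the answer's messageId, terminated by a messageEnd frame. The sketch below reassembles an answer from already-decoded frames. WsFrame and AnswerAssembler are hypothetical. The "messageEnd" type and the messageId field come from the callbacks above, while the "message" type name and the data field are assumptions standing in for whatever ChatWsRespVo.message(...) actually serializes to.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical decoded WebSocket frame. "messageEnd" and messageId match the
// server callbacks; the "message" type and the data field are assumptions.
record WsFrame(String type, long messageId, String data) {}

// Collects streamed fragments per answer ID until the matching messageEnd.
final class AnswerAssembler {
  private final Map<Long, StringBuilder> pending = new HashMap<>();

  /** Returns the full answer when a messageEnd arrives, otherwise null. */
  String accept(WsFrame frame) {
    if ("message".equals(frame.type())) {
      pending.computeIfAbsent(frame.messageId(), id -> new StringBuilder())
          .append(frame.data());
      return null;
    }
    if ("messageEnd".equals(frame.type())) {
      StringBuilder sb = pending.remove(frame.messageId());
      return sb == null ? "" : sb.toString();
    }
    // Other frame types (keep-alive, progress, errors) are ignored in this sketch
    return null;
  }

  /** Number of answers that have started streaming but not yet ended. */
  int inFlight() {
    return pending.size();
  }
}
```

Keying the buffers by messageId keeps fragments of different answers apart if one connection ever carries more than one stream.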