离线运行向量模型
本章节介绍如何离线运行向量模型 jina-embeddings,并编写一个 HTTP API 服务,用于使用 Jina-embeddings-v3 模型生成文本嵌入。该服务基于 ONNX Runtime 和 Robyn 框架构建,兼容 OpenAI embedding 数据格式,并支持在 CPU 上运行。
依赖
- Python: 3.8 及以上版本
- 依赖库: numpy, onnxruntime, robyn, transformers
安装步骤
1. 创建 Conda 环境
首先创建一个 Conda 环境,并安装 Python 3.9(或更高版本):
conda create -n jina-embeddings python=3.9 -y
conda activate jina-embeddings
2. 安装所需依赖库
使用 pip 安装项目所需的所有依赖库:
pip install numpy onnxruntime robyn transformers
编写代码
将以下代码保存为 server.py
:
import argparse
import json
import numpy as np
import onnxruntime as ort
import asyncio
from robyn import Robyn, Request
from transformers import AutoTokenizer, PretrainedConfig
def mean_pooling(model_output: np.ndarray, attention_mask: np.ndarray):
token_embeddings = model_output
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
input_mask_expanded = np.broadcast_to(input_mask_expanded, token_embeddings.shape)
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
return sum_embeddings / sum_mask
# 加载 tokenizer 与模型配置
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v3', trust_remote_code=True)
config = PretrainedConfig.from_pretrained('jinaai/jina-embeddings-v3')
# 全局变量 session,后续在 main() 中进行初始化
session = None
def initialize_session(model_path: str):
global session
session = ort.InferenceSession(model_path)
print(f"Model loaded from: {model_path}")
def encode_single_text(text, task="text-matching", max_length=8192):
# 使用 tokenizer 对输入文本进行编码
input_text = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors='np')
task_index = config.lora_adaptations.index(task)
inputs = {
'input_ids': input_text['input_ids'],
'attention_mask': input_text['attention_mask'],
'task_id': np.array(task_index, dtype=np.int64) # 标量数值
}
# 进行模型推理(阻塞调用)
outputs = session.run(None, inputs)[0]
embeddings = mean_pooling(outputs, input_text["attention_mask"])
norm = np.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
embeddings = embeddings / norm
return embeddings[0] # 返回 1D 向量
async def encode_single_text_async(text, task="text-matching", max_length=8192):
# 在单独的线程中运行 encode_single_text,避免阻塞事件循环
return await asyncio.to_thread(encode_single_text, text, task, max_length)
async def encode_texts_async(texts, task="text-matching", max_length=8192):
# 并发地对多个文本进行编码
tasks = [encode_single_text_async(text, task, max_length) for text in texts]
embeddings = await asyncio.gather(*tasks)
return np.array(embeddings)
# 使用 Robyn 构建 HTTP 服务
app = Robyn(__file__)
@app.get("/")
async def index_endpoint(request):
return "/"
@app.post("/v1/embeddings")
async def embeddings_endpoint(request: Request):
try:
# Robyn 的 request.json() 返回一个 dict,无需 await
data = request.json()
if "input" not in data:
return {"error": "Missing 'input' field."}, {}, 400
texts = data["input"]
task = data.get("task", "text-matching")
max_length = data.get("max_length", 8192)
# 对单个文本输入也支持并发处理
if isinstance(texts, str):
texts = [texts]
# 统计 token 数量
total_tokens = 0
for text in texts:
tokens = tokenizer.tokenize(text)
total_tokens += len(tokens)
embeddings = await encode_texts_async(texts, task=task, max_length=max_length)
response = {
"object": "list",
"model": "jina-embeddings-v3",
"data": [
{
"object": "embedding",
"index": i,
"embedding": emb.tolist()
} for i, emb in enumerate(embeddings)
],
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens
}
}
return response, {}, 200
except Exception as e:
return {"error": str(e)}, {}, 500
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Jina Embeddings v3 HTTP Service")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the ONNX model file."
)
parser.add_argument(
"--port",
type=int,
default=10002,
help="Port number to run the service on. (default: 10002)"
)
args = parser.parse_args()
# 根据提供的模型路径初始化 ONNX session
initialize_session(args.model_path)
# 在指定端口启动 HTTP 服务
app.start(port=args.port)
# 示例 curl 请求:
# curl -X POST http://localhost:10002/v1/embeddings \
# -H "Content-Type: application/json" \
# -d '{"input": ["Hello world", "你好,世界"], "task": "text-matching"}'
使用方法
1. 下载模型
运行以下命令下载模型文件,确保目录结构正确:
mkdir -p ~/models/jina-embeddings-v3/onnx
cd ~/models/jina-embeddings-v3/onnx
wget https://huggingface.co/jinaai/jina-embeddings-v3/resolve/main/onnx/model.onnx
wget https://huggingface.co/jinaai/jina-embeddings-v3/resolve/main/onnx/model.onnx_data
2. 运行服务器
执行以下命令启动服务器(注意:请将 ~/models/jina-embeddings-v3/onnx/model.onnx
替换为实际模型路径):
python server.py --model_path ~/models/jina-embeddings-v3/onnx/model.onnx
3. 向 /v1/embeddings
接口发送 POST 请求
示例请求(支持文本列表与单个文本输入):
curl -X POST http://localhost:10002/v1/embeddings \
-H "Content-Type: application/json" \
-d '{"input": ["Hello world", "你好,世界"]}'
返回示例(注意:实际返回的 "model"
字段可能因模型名称而有所不同,下例仅供参考):
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": []
}
],
"model": "text-embedding-3-large",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}
Docker
为了容器化服务,可以使用以下 Dockerfile
构建 Docker 镜像:
FROM python:3.9-slim
# 安装 wget(用于下载模型)
RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
# 下载模型文件到指定目录
RUN mkdir -p /models/jina-embeddings-v3/onnx && \
cd /models/jina-embeddings-v3/onnx && \
wget https://huggingface.co/jinaai/jina-embeddings-v3/resolve/main/onnx/model.onnx && \
wget https://huggingface.co/jinaai/jina-embeddings-v3/resolve/main/onnx/model.onnx_data
WORKDIR /app
# 将应用代码复制到容器中
COPY . /app
# 安装所需的 Python 包
RUN pip install --no-cache-dir numpy onnxruntime robyn transformers
# 指定默认启动命令及模型路径
CMD ["python", "server.py", "--model_path", "/models/jina-embeddings-v3/onnx/model.onnx"]
构建 Docker 镜像
在项目根目录下执行以下命令构建 Docker 镜像:
docker build -t litongjava/py-jina-embeddings-server:1.0.0 .
运行 Docker 容器
使用以下命令运行容器并映射端口 10002:
docker run -p 10002:10002 litongjava/py-jina-embeddings-server:1.0.0
或者使用以下命令启动带有重启策略的容器:
docker run -dit --restart=always --name=jina-embeddings-server -p 10002:10002 litongjava/py-jina-embeddings-server:1.0.0
Java 程序接入
package com.litongjava.maxkb.embedding;
import org.junit.Test;
import com.litongjava.openai.client.OpenAiClient;
import com.litongjava.openai.embedding.EmbeddingRequestVo;
import com.litongjava.openai.embedding.EmbeddingResponseVo;
import com.litongjava.tio.utils.json.JsonUtils;
public class JinaEmbeddingTest {
@Test
public void testEmbedding() {
String input = "你好世界";
String url = "http://192.168.3.9:10002/v1";
// 因为调用的是本地模型可以随便写
String apiKey = "1234";
EmbeddingRequestVo embeddingRequestVo = new EmbeddingRequestVo(input);
EmbeddingResponseVo embeddings = OpenAiClient.embeddings(url, apiKey, embeddingRequestVo);
System.out.println(JsonUtils.toSkipNullJson(embeddings));
}
}