基于 Kashgari 2 的短文本分类: 模型部署

文本分类是自然语言处理核心任务之一,常见用文本审核、广告过滤、情感分析、语音控制和反黄识别等NLP领域。本文如何使用 FastAPI 和 TensorFlow Serving 部署训练好的模型。

使用 FastAPI 部署

FastAPI 安装和使用

FastAPI 是一个用于构建API 的现代、快速(高性能)的web 框架。使用方法和大名鼎鼎的 Flask 很相似,但提供了更好的性能,更加完善的文档。

首先我们安装 FastAPI 和 uvicorn(用于运行服务)。

1
pip install fastapi uvicorn

创建 Hello World API

安装完成后,我们创建 app.py 文件,定义一个 GET say_hello 接口。代码如下:

1
2
3
4
5
6
7
8
9
from fastapi import FastAPI

app = FastAPI()

# 定义 say_hello 方法,接受一个 name 参数
# 然后用 app.get 装饰器注册一个 GET /say_hello 接口
@app.get("/say_hello")
def say_hello(name: str):
return {"message": f"Hello, {name}"}

定义完成后,我们在终端执行以下命令 uvicorn app:app --reload 启动 API。可以看到以下提示,说明 API 启动了。

1
2
3
4
5
6
$ uvicorn app:app --reload
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [4868] using statreload
INFO: Started server process [4870]
INFO: Waiting for application startup.
INFO: Application startup complete.

启动完成后,我们可以在浏览器打开 http://127.0.0.1:8000/docs 看到自动生成的在线文档。这个在线文档可以用于查看 API 定义和在线调试。

在线文档

使用 FastAPI 部署模型

使用 FastAPI 部署模型也很简单,一样定义一个方法,输入一句话,输出其类别即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Dict, Any

import jieba
from fastapi import FastAPI
from kashgari.layers import L
from kashgari.tasks.classification.abc_model import ABCClassificationModel
from tensorflow import keras

app = FastAPI()


# ----------- 由于我们加载的是自定义模型,Serve 的时候还需要再引入一次 ----------------
class Double_BiLSTM_Model(ABCClassificationModel):
@classmethod
def default_hyper_parameters(cls) -> Dict[str, Any]:
# 定义超参,单独列出来是这样定义后,后续可以在不改变模型定义的情况下更新超参数
return {
'lstm1_units': 256,
'lstm2_units': 128,
'dropout_rate': 0.5
}

def build_model_arc(self) -> None:
config = self.hyper_parameters
output_dim = self.label_processor.vocab_size
embed_model = self.embedding.embed_model

# 定义模型架构
self.tf_model = keras.Sequential([
embed_model,
L.Bidirectional(L.LSTM(config['lstm1_units'], return_sequences=True)),
L.Bidirectional(L.LSTM(config['lstm2_units'], return_sequences=False)),
L.Dropout(config['dropout_rate']),
L.Dense(output_dim),
self._activation_layer()
])
# ----------- 由于我们加载的是自定义模型,Serve 的时候还需要再引入一次 ----------------

# 一定要先加载模型,不能每个请求在初始化一次
model = Double_BiLSTM_Model.load_model('best_model')


@app.get("/predict")
def predict(sentence: str):
# 需要和此前逻辑一样初始化模型
seg_sentence = list(jieba.cut(sentence))
# 模型输入多个 sample,但我们此处只用一个,所以要取出
results = model.predict([seg_sentence])
return {"label": results[0]}

再次使用启动命令 uvicorn app:app 启动接口,由于模型比较大加载慢,此时我们去掉自动重载功能,避免处于反复重新加载。

启动接口后,我们继续浏览器打开 http://127.0.0.1:8000/docs 进行调试。第一个请求会比较慢,第一个请求处理完毕后,后续请求都可以很快响应。以下是测试结果。

输入 结果 耗时
美国加州抗议特朗普税务疑团游行变暴力冲突 news_world 37ms
热火曾故意隐藏对邓肯-罗宾逊的兴趣 news_sports 38ms
28省份今举行中小学教师资格考试 有啥新变化? news_edu 35ms
苏格兰影星康纳利去世 老照片再现艺术人生高光时刻 news_entertainment 36ms
疫情中的万圣节:新规矩、新点子和老传统 news_world 40ms
油价暴跌、美股熔断:全球经济进入下行周期? news_finance 37ms

使用 TF-serving 部署

直接使用 FastAPI 部署虽然简单快捷,但是性能不是很好,不适合生产环境使用。生产环境通常使用 FastAPI + TF-serving 的部署模式。具体做法是,使用 FastAPI 提供对外的接口,使用 TF-Serving 处理模型推理过程。

转换模型

首先我们需要把模型转换为 SavedModel 格式,Kashgari 内置了转换接口,可以很方便的转换。

1
2
3
4
5
from kashgari.utils import convert_to_saved_model

model = Double_BiLSTM_Model.load_model('best_model')
# 保存转换后模型到 tf_serving_model/news 目录下,版本号为 1
convert_to_saved_model(model, 'tf_serving_model/news', version=1)

转换成功后,新的模型目录结构如下:

1
2
3
4
5
6
7
8
9
10
tf_serving_model/
└── news
└── 1
├── assets
├── model_config.json
├── saved_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index

TF-Serving 接口

TensorFlow Serving是google提供的一种生产环境高性能部署方案,它可以将训练好的机器学习模型部署到线上,提供 gRPC 和 restful 接口。使用 docker 启动 TF-Serving 最方便,在终端执行以下命令启动 tf-serving。

1
docker run -t --rm -p 8501:8501 -v "`pwd`/tf_serving_model:/models" -e MODEL_NAME=news tensorflow/serving

此处我们挂载了本地的 tf_serving_model 到容器内的 models 目录下,然后通过环境变量设定需要加载的模型是 news。

启动完成后,我们通过 http://localhost:8501/v1/models/news 检查模型状态,以下响应表示模型加载成功。

1
2
3
4
5
6
7
8
9
10
11
12
{
"model_version_status": [
{
"version": "1",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}

FastAPI 接口

虽然 TF-Serving 已经启动了模型,但是模型输入输出均为张量,所以我们还需要处理输入为张量,然后再把模型输出转换为具体的标签。这步骤我们将在 FastAPI 接口内完成。

新建一个 tf-api.py 文件,写一个新的接口来处理输入输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import jieba
import numpy as np
import requests
from fastapi import FastAPI
from kashgari.processors import load_processors_from_model

app = FastAPI()
# 从保存模型中加载输入和输出处理类
text_processor, label_processor = load_processors_from_model('tf_serving_model/news/1')


@app.get("/predict_with_tf")
def predict_with_tf(sentence: str):
seg_sentence = list(jieba.cut(sentence))
# 此步骤将分词后的输入转换成张量
tensor = text_processor.transform([seg_sentence])
# 张量格式转换为 TF-Serving 接受的格式
instances = [i.tolist() for i in tensor]
# 使用 requests 框架发送请求进行推理
r = requests.post("http://localhost:8501/v1/models/news:predict",
json={"instances": instances})
# 获取推理结果概率
predictions = r.json()['predictions']

# 使用标签处理器将标签 index 转换成具体的标签
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
return {"label": labels[0]}

接下来终端使用命令 uvicorn tf-api:app --port 4001 启动 API。此时选择 4001 接口是为了防止和之前的 API 冲突。测试后可以看到 TF-Serving 能大幅度提升性能,耗时只有纯 FastAPI 方案的 1/3。

输入 结果 耗时
美国加州抗议特朗普税务疑团游行变暴力冲突 news_world 11ms
热火曾故意隐藏对邓肯-罗宾逊的兴趣 news_sports 12ms
28省份今举行中小学教师资格考试 有啥新变化? news_edu 12ms
苏格兰影星康纳利去世 老照片再现艺术人生高光时刻 news_entertainment 10ms
疫情中的万圣节:新规矩、新点子和老传统 news_world 12ms
油价暴跌、美股熔断:全球经济进入下行周期? news_finance 12ms