基于 Kashgari 2 的短文本分类: 多标签分类

本文介绍如何使用 Kashgari 训练多标签分类模型和如何部署。

数据集准备

首先我们要预处理数据集为 Kashgari 接受的格式,对于多标签,Kashgari 数据集中每个样本为以下格式。

1
2
x = ['谢', '娜', '为', '李', '浩', '菲', '澄', '清', '网', '络', '谣', '言']
y = ['新闻', '娱乐']

按照上述格式准备好数据集 train_x, train_y, test_x, test_y 后就可以训练了。

模型训练

多标签模型的训练和评估方法和 基于 Kashgari 2 的短文本分类: 训练模型和调 一文中的一致,只是需要在模型初始化的时候多传一个参数。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 所有的 classification 模型支持多标签
from kashgari.tasks.classification import BiLSTM_Model

# 初始化模型时候多传一个参数 multi_label=True 即可处理多标签数据集
base_model = BiLSTM_Model(multi_label=True)

base_history = base_model.fit(train_x,
train_y,
batch_size=128,
epochs=10)

# 评估方法和以前一致
base_report = base_model.evaluate(test_x, test_y)

# 预测方法也和以前一致
labels = base_model.predict(test_x)

# 保存模型
base_model.save("multi_label_model")

模型部署

模型部署方式和 基于 Kashgari 2 的短文本分类: 模型部署 步骤一致,以 TF Serving + FastAPI 为例,具体步骤如下。

转换模型

首先转换模型为 tf-serving 模型。

1
2
3
4
5
from kashgari.utils import convert_to_saved_model

model = BiLSTM_Model.load_model('multi_label_model')
# 保存转换后模型到 tf_serving_model/multi_label 目录下,版本号为 1
convert_to_saved_model(model, 'tf_serving_model/multi_label', version=1)

启动 TF-Serving 接口

用 docker 方式启动 TF serving,注意需要更新 MODEL_NAME 参数为 multi_label

1
docker run -t --rm -p 8501:8501 -v "`pwd`/tf_serving_model:/models" -e MODEL_NAME=multi_label tensorflow/serving

启动完成后,我们通过 http://localhost:8501/v1/models/multi_label 确定运行状态。

注意:由于模型名称不一样,所以我们所有请求 url 中模型名字和上一篇文章不一样

FastAPI 接口

接下来配置 FastAPI 接口,新建一个 tf-api.py 文件,定义接口提供对外的预测服务。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import requests
from fastapi import FastAPI
from kashgari.processors import load_processors_from_model

app = FastAPI()
# 从保存模型中加载输入和输出处理类
text_processor, label_processor = load_processors_from_model('tf_serving_model/news/1')


@app.get("/predict_with_tf")
def predict_with_tf(sentence: str):
# 这个实验我们没有用 jieba,直接按照字进行 tokenize,所以直接转数组即可。
seg_sentence = list(sentence)
# 此步骤将分词后的输入转换成张量
tensor = text_processor.transform([seg_sentence])
# 张量格式转换为 TF-Serving 接受的格式
instances = [i.tolist() for i in tensor]
# 使用 requests 框架发送请求进行推理
r = requests.post("http://localhost:8501/v1/models/news:predict",
json={"instances": instances})
# 获取推理结果概率
predictions = r.json()['predictions']

# 使用标签处理器将标签 index 转换成具体的标签
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
return {"label": labels[0]}

接下来终端使用命令 uvicorn tf-api:app --port 4002 启动 API 即可。