本文介绍如何使用 Kashgari 训练多标签分类模型和如何部署。
数据集准备
首先我们要预处理数据集为 Kashgari 接受的格式,对于多标签,Kashgari 数据集中每个样本为以下格式。
1 2
| x = ['谢', '娜', '为', '李', '浩', '菲', '澄', '清', '网', '络', '谣', '言'] y = ['新闻', '娱乐']
|
按照上述格式准备好数据集 train_x, train_y, test_x, test_y
后就可以训练了。
模型训练
多标签模型的训练和评估方法和 基于 Kashgari 2 的短文本分类: 训练模型和调 一文中的一致,只是需要在模型初始化的时候多传一个参数。例如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| from kashgari.tasks.classification import BiLSTM_Model
base_model = BiLSTM_Model(multi_label=True)
base_history = base_model.fit(train_x, train_y, batch_size=128, epochs=10)
base_report = base_model.evaluate(test_x, test_y)
labels = base_model.predict(test_x)
base_model.save("multi_label_model")
|
模型部署
模型部署方式和 基于 Kashgari 2 的短文本分类: 模型部署 步骤一致,以 TF Serving + FastAPI 为例,具体步骤如下。
转换模型
首先转换模型为 tf-serving 模型。
1 2 3 4 5
| from kashgari.utils import convert_to_saved_model
model = BiLSTM_Model.load_model('multi_label_model')
convert_to_saved_model(model, 'tf_serving_model/multi_label', version=1)
|
启动 TF-Serving 接口
用 docker 方式启动 TF serving,注意需要更新 MODEL_NAME
参数为 multi_label
。
1
| docker run -t --rm -p 8501:8501 -v "`pwd`/tf_serving_model:/models" -e MODEL_NAME=multi_label tensorflow/serving
|
启动完成后,我们通过 http://localhost:8501/v1/models/multi_label
确定运行状态。
注意:由于模型名称不一样,所以我们所有请求 url 中模型名字和上一篇文章不一样
FastAPI 接口
接下来配置 FastAPI 接口,新建一个 tf-api.py 文件,定义接口提供对外的预测服务。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| import numpy as np import requests from fastapi import FastAPI from kashgari.processors import load_processors_from_model
app = FastAPI()
text_processor, label_processor = load_processors_from_model('tf_serving_model/news/1')
@app.get("/predict_with_tf") def predict_with_tf(sentence: str): seg_sentence = list(sentence) tensor = text_processor.transform([seg_sentence]) instances = [i.tolist() for i in tensor] r = requests.post("http://localhost:8501/v1/models/news:predict", json={"instances": instances}) predictions = r.json()['predictions']
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1)) return {"label": labels[0]}
|
接下来终端使用命令 uvicorn tf-api:app --port 4002
启动 API 即可。