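Preparation: set up a Python environment and download a BERT language model.
- Python 3.6 environment
Kashgari must be installed. Pick the version that matches your backend: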
| Backend | pypi version | desc |
| --- | --- | --- |
| TensorFlow 2.x | pip install 'kashgari>=2.0.0' | coming soon |
| TensorFlow 1.14+ | pip install 'kashgari>=1.0.0,<2.0.0' | current version |
| Keras | pip install 'kashgari<1.0.0' | legacy version |
- BERT, Chinese model
I chose the BERT-wwm-ext model released by HIT (Harbin Institute of Technology).
Many thanks to the authors of the projects above.
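As a quick reference, the backend table above can be written as a small lookup. This is only a sketch: the helper name `kashgari_requirement` and the backend keys are made up for illustration and are not part of Kashgari.

```python
def kashgari_requirement(backend: str) -> str:
    """Return the pip requirement string for a backend, following the table above.

    The keys below are hypothetical labels chosen for this sketch.
    """
    requirements = {
        "tensorflow2": "kashgari>=2.0.0",         # TensorFlow 2.x (coming soon)
        "tensorflow1": "kashgari>=1.0.0,<2.0.0",  # TensorFlow 1.14+ (current version)
        "keras": "kashgari<1.0.0",                # Keras (legacy version)
    }
    return requirements[backend]


# Quote the requirement so the shell does not treat '<' and '>' as redirections
print("pip install '{}'".format(kashgari_requirement("tensorflow1")))
```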
Preparing the dataset
from kashgari.corpus import ChineseDailyNerCorpus
train_x, train_y = ChineseDailyNerCorpus.load_data('train')
valid_x, valid_y = ChineseDailyNerCorpus.load_data('validate')
test_x, test_y = ChineseDailyNerCorpus.load_data('test')
print(f"train data count: {len(train_x)}")
print(f"validate data count: {len(valid_x)}")
print(f"test data count: {len(test_x)}")
train data count: 20864
validate data count: 2318
test data count: 4636
The data is the annotated People's Daily corpus, with one character and its tag per line:
海 O
钓 O
比 O
赛 O
地 O
点 O
在 O
厦 B-LOC
门 I-LOC
与 O
金 B-LOC
门 I-LOC
之 O
间 O
的 O
海 O
域 O
。 O
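If you want to feed your own data in this format, the lines have to be turned into parallel lists of characters and tags, which is the shape `load_data` returns above. A minimal sketch follows. The function name `parse_bio_lines` is hypothetical, and it assumes sentences are separated by blank lines, which the excerpt above does not show.

```python
def parse_bio_lines(lines):
    """Turn 'char tag' lines into parallel character and tag lists, one pair per sentence."""
    sentences_x, sentences_y = [], []
    chars, tags = [], []
    for line in lines:
        line = line.strip()
        if not line:
            # Assumed convention: a blank line ends the current sentence
            if chars:
                sentences_x.append(chars)
                sentences_y.append(tags)
                chars, tags = [], []
            continue
        char, tag = line.split()
        chars.append(char)
        tags.append(tag)
    # Flush the last sentence if the input does not end with a blank line
    if chars:
        sentences_x.append(chars)
        sentences_y.append(tags)
    return sentences_x, sentences_y


sample = """厦 B-LOC
门 I-LOC
与 O
金 B-LOC
门 I-LOC"""
x, y = parse_bio_lines(sample.splitlines())
print(x)
print(y)
```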
Creating the BERT embedding
import kashgari
from kashgari.embeddings import BERTEmbedding

# The first argument is the folder containing the unpacked BERT-wwm-ext model
bert_embed = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
                           task=kashgari.LABELING,
                           sequence_length=100)
Building and training the model
from kashgari.tasks.labeling import BiLSTM_CRF_Model

# Other options: `CNN_LSTM_Model`, `BiLSTM_Model`, `BiGRU_Model` or `BiGRU_CRF_Model`
model = BiLSTM_CRF_Model(bert_embed)
model.fit(train_x,
          train_y,
          x_validate=valid_x,
          y_validate=valid_y,
          epochs=20,
          batch_size=512)
model.save('ner.h5')
Evaluating the model
model.evaluate(test_x, test_y)
The BERT + BiLSTM-CRF model gave the best results. The detailed scores are:
|  | precision | recall | f1-score | support |
| --- | --- | --- | --- | --- |
| LOC | 0.9208 | 0.9324 | 0.9266 |  |
| ORG | 0.8728 | 0.8882 | 0.8804 |  |
| PER | 0.9622 | 0.9633 | 0.9627 |  |
| avg / total | 0.9169 | 0.9271 | 0.9220 |  |
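As a sanity check on the per-class rows, the f1-score is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The short sketch below recomputes it from the reported precision and recall. The helper name `f1_score` is my own and is not taken from Kashgari.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported f1) for each entity type, copied from the scores above
reported = {
    "LOC": (0.9208, 0.9324, 0.9266),
    "ORG": (0.8728, 0.8882, 0.8804),
    "PER": (0.9622, 0.9633, 0.9627),
}

for name, (p, r, f1) in reported.items():
    # Print the recomputed value next to the reported one for comparison
    print(name, round(f1_score(p, r), 4), f1)
```

The recomputed values agree with the reported f1-scores to within 0.0001.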
Using the model
# -*- coding: utf-8 -*-
import kashgari

# Load the model saved in the training step
loaded_model = kashgari.utils.load_model('ner.h5')


def cut_text(text, length):
    """Split text into chunks of at most `length` characters."""
    return [text[i:i + length] for i in range(0, len(text), length)]


def extract_labels(texts, ners):
    """Group the predicted BIO tags into entity strings, keyed by entity type."""
    labels = {}
    for text, tags in zip(texts, ners):
        entity, ner_type = '', None
        # Pair characters with tags chunk by chunk, so a chunk that comes back
        # with fewer tags than characters cannot shift later chunks
        for char, tag in zip(text, tags):
            if tag.startswith('B-'):
                if entity:
                    labels.setdefault(ner_type, []).append(entity)
                entity, ner_type = char, tag[2:]
            elif tag.startswith('I-') and entity and tag[2:] == ner_type:
                entity += char
            else:
                if entity:
                    labels.setdefault(ner_type, []).append(entity)
                entity, ner_type = '', None
        # Flush an entity that runs to the end of the chunk
        if entity:
            labels.setdefault(ner_type, []).append(entity)
    return labels


while True:
    text_input = input('sentence: ')
    texts = cut_text(text_input, 100)
    ners = loaded_model.predict([list(text) for text in texts])
    # Print the raw NER tags predicted by the model
    print(ners)
    labels = extract_labels(texts, ners)
    print(labels)
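To see what the entity grouping step produces without loading a model, the sketch below decodes hand-written tags for the sample sentence from the dataset section. The function `bio_to_spans` is a hypothetical helper, not part of Kashgari. Unlike the script above, it also returns character offsets, which is useful when the entities need to be highlighted in the original text.

```python
def bio_to_spans(chars, tags):
    """Group BIO tags into (entity_type, start, end, text) tuples; end is exclusive."""
    spans = []
    start, ner_type = None, None
    # A trailing sentinel "O" flushes an entity that runs to the end of the input
    for i, tag in enumerate(list(tags) + ["O"]):
        inside = tag.startswith("I-") and start is not None and tag[2:] == ner_type
        if inside:
            continue
        if start is not None:
            spans.append((ner_type, start, i, "".join(chars[start:i])))
            start, ner_type = None, None
        if tag.startswith("B-"):
            start, ner_type = i, tag[2:]
    return spans


sentence = "海钓比赛地点在厦门与金门之间的海域。"
# Tags written by hand to match the dataset excerpt, not model output
tags = ["O"] * 7 + ["B-LOC", "I-LOC", "O", "B-LOC", "I-LOC"] + ["O"] * 6
print(bio_to_spans(list(sentence), tags))
# [('LOC', 7, 9, '厦门'), ('LOC', 10, 12, '金门')]
```

An `I-` tag that does not continue an open entity of the same type is dropped here. That is a design choice for this sketch, and other decoders treat such tags as the start of a new entity.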
References
Chinese-BERT-wwm: https://github.com/ymcui/Chinese-BERT-wwm
Kashgari: https://github.com/BrikerMan/Kashgari