使用 Keras 进行文本分类
上节课我们学习了如何进行图片分类,在此过程之中我们学习到了如何对图片数据进行处理;而对于文本数据我们应该如何处理与训练呢?与图片数据相比,文本数据有以下几个特点:
- 长度不确定;
- 语言之间的差异较大,编码方式各不相同;
- 同一种语言的处理方式也不尽相同;
- 特征提取方式不统一。
因为文本数据的不确定性,因此我们这节课采用最常用的数据处理方式(单词嵌入)与最常用的文本分类数据集( IMBD® 评价数据集)。
1. 数据集合概览
IMDB® 数据集合一共包含 50000 条数据,每条数据都是从 IMDB® 电影的评价中选取,同时每个评论都被归类为**“正面评价”或“负面评价”**。比如:
x: [1, 778, 128, 74, 12, 630, 163, 15, 4, 1766, 7982, 1051, 2, 32, 85, 156, 45, 40, 148, 139, 121, 664, 665, 10, 10, 1361, 173, 4, 749, 2, 16, 3804, 8, 4, 226, 65, 12, 43, 127, 24, 2, 10, 10]
y: 0
其中评论是被编码之后所得到的数组,每个英文单词对应一个固定的数字。而标签用 0 和 1 来表示“负面评价”和“证明评价”。
将上述例子还原一下就是:
x: "begins better than it ends funny that the russian submarine crew <UNK> all other actors it's like those scenes where documentary shots br br spoiler part the message <UNK> was contrary to the whole story it just does not <UNK> br br"
y: "Negative"
这 50000 条数据它们具体的分布如下:
- 训练集包含 25000 条训练数据,其中正负数据各 12500 条;
- 测试集包含 25000 条测试数据,其中正负数据各 12500 条。
换句话说,该数据集合上面的数据是**“平衡的”**,因为它包含的正样本与负样本的数目相同。
在 TensorFlow 之中,我们可以直接通过调用内部 API 的方式来获取该数据集:
(train_data, train_labels), (test_data, test_labels) = \
tf.keras.datasets.imdb.load_data(num_words=words_num)
2. 如何对文本数据进行处理
在机器学习之中,我们对于文本数据的处理大致分为以下几步:
- 数据清洗,清理掉无用的数据;
- 文本编码,将每一个单词转化为一个数字来表示;
- 将编码后的文本转化为定长表示;
- 将文本提取为特征向量进行下一步的训练。
其中在这个例子之中,我们加载的数据集合已经由 TensorFlow 进行过数据清洗与文本编码了,因此我们只需要将其转化为定长表示并且提取其特征向量即可。
2.1 如何将文本数组填充到定长
在 TensorFlow 之中我们可以采用预处理的方式来将编码后的文本转化为定长:
train_data = tf.keras.preprocessing.sequence.pad_sequences(
train_data,
value=0,
padding='post',
maxlen=10
)
其中的各个参数的解释如下:
- trian_data:我们要处理的、编码后的数据;
- maxlen:将每个文本样本处理后的长度,如果原长度不足 maxlen ,那么便会使用 value 进行填充;如果原长度超过了 maxlen ,那么便会将文本截断;
- value:用来填充文本的数字,一般我们使用0即可;
- padding:填充的模式,post 表示填充的 value 位置在原文之后。
我们举个简单的例子,如果处理前的文本数组为:
[1, 2, 3]
当我们使用上述方式填充之后的数据就会变为:
[1, 2, 3, 0, 0 ,0, 0, 0, 0, 0]
2.2 如何将文本数组进行嵌入并提取特征向向量
在 TensorFlow 之中,我们最常用的提取文本特征的网络层是:
tf.keras.layers.Embedding(vocab_size, dim),
其中 vocab_size 表示的是词汇量的总数,dim 表示特征向量的维度。
通过输入编码后的文本数组,我们可以得到该文本的特征向量(embedding vector)。
3. 模型的完整表示
当我们知道了如何对文本数据进行处理之后,我们便可以编写我们的文本分类模型的程序了。
具体的程序如下:
import tensorflow as tf
import numpy as np
# 定义基本参数
words_num = 10000
val_num = 12500
EPOCHS = 30
pad_max_length = 256
BATCH_SIZE = 64
# 获取数据
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.imdb.load_data(num_words=words_num)
word_index = tf.keras.datasets.imdb.get_word_index()
# 添加特殊字符
word_index = {k:(v+3) for k,v in word_index.items()}
word_index["<pad>"] = 0
word_index["<start>"] = 1
word_index["<unknown>"] = 2
word_index["<unused>"] = 3
# 数据预处理
train_data = tf.keras.preprocessing.sequence.pad_sequences(train_data, value=0, padding='post', maxlen=pad_max_length)
test_data = tf.keras.preprocessing.sequence.pad_sequences(test_data, value=0, padding='post', maxlen=pad_max_length)
# 划分训练集合与验证集合
x_val, x_train = train_data[:val_num], train_data[val_num:]
y_val, y_train = train_labels[:val_num], train_labels[val_num:]
# 模型构建
model = tf.keras.Sequential([
tf.keras.layers.Embedding(words_num, 32),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.summary()
# 编译模型
model.compile(optimizer='adam', loss='binary\_crossentropy', metrics=['accuracy'])
# 训练
history = model.fit(x_train, y_train, epochs=EPOCHS,
batch_size=BATCH_SIZE, validation_data=(x_val, y_val))
# 测试
results = model.evaluate(test_data, test_labels)
print(results)
在该程序之中有几个需要注意的地方:
- 在添加特殊字符字符处我们添加了四个特殊字符,其中
- 0 表示填充所使用的字符;
- 1 表示句子的开始;
- 2 表示未知单词,因为我们规定只使用 10000 个最常用的单词;
- 3 表示未使用的单词。
- 在划分验证集合的时候,我们按照 50% 的比例划分训练集合与验证集合;
- 在模型的第二层,我们采用了一维全局池化,该层没有可训练的参数,该层是为了降低训练所需要数据量,输出是一个固定长度的向量;
- 模型的最后一层的激活函数为 “Sigmoid” ,这个激活函数将输出分为 0 或者 1 ,通常用于二分类的任务。
- 在编译过程之中我们采用了**“二元交叉熵”(binary_crossentropy)**的损失函数,该损失函数通常用作二元分类问题
- 因为在数据处理过程中我们没有划分 Batch ,因此我们要在训练(fit)的过程之中来定义 Batch_Size。
4. 程序的结果
运行上面的程序,我们可以得到如下的输出:
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_4 (Embedding) (None, None, 32) 320000
_________________________________________________________________
global_average_pooling1d_3 ( (None, 32) 0
_________________________________________________________________
dense_8 (Dense) (None, 64) 2112
_________________________________________________________________
dense_9 (Dense) (None, 1) 65
=================================================================
Total params: 322,177
Trainable params: 322,177
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
196/196 [==============================] - 2s 10ms/step - loss: 0.6428 - accuracy: 0.6598 - val_loss: 0.5054 - val_accuracy: 0.8246
Epoch 2/30
196/196 [==============================] - 2s 10ms/step - loss: 0.3655 - accuracy: 0.8654 - val_loss: 0.3217 - val_accuracy: 0.8741
Epoch 3/30
196/196 [==============================] - 2s 10ms/step - loss: 0.2429 - accuracy: 0.9084 - val_loss: 0.2956 - val_accuracy: 0.8763
Epoch 4/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1869 - accuracy: 0.9322 - val_loss: 0.2870 - val_accuracy: 0.8842
Epoch 5/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1468 - accuracy: 0.9498 - val_loss: 0.2978 - val_accuracy: 0.8820
Epoch 6/30
196/196 [==============================] - 2s 10ms/step - loss: 0.1167 - accuracy: 0.9622 - val_loss: 0.3121 - val_accuracy: 0.8835
Epoch 7/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0915 - accuracy: 0.9737 - val_loss: 0.3375 - val_accuracy: 0.8786
Epoch 8/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0720 - accuracy: 0.9805 - val_loss: 0.3668 - val_accuracy: 0.8784
Epoch 9/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0558 - accuracy: 0.9870 - val_loss: 0.3917 - val_accuracy: 0.8747
Epoch 10/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0437 - accuracy: 0.9924 - val_loss: 0.4241 - val_accuracy: 0.8729
Epoch 11/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0345 - accuracy: 0.9946 - val_loss: 0.4539 - val_accuracy: 0.8696
Epoch 12/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0272 - accuracy: 0.9956 - val_loss: 0.4948 - val_accuracy: 0.8703
Epoch 13/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0201 - accuracy: 0.9974 - val_loss: 0.5199 - val_accuracy: 0.8679
Epoch 14/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0150 - accuracy: 0.9984 - val_loss: 0.5517 - val_accuracy: 0.8662
Epoch 15/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0122 - accuracy: 0.9987 - val_loss: 0.5818 - val_accuracy: 0.8646
Epoch 16/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0098 - accuracy: 0.9991 - val_loss: 0.6114 - val_accuracy: 0.8642
Epoch 17/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0080 - accuracy: 0.9993 - val_loss: 0.6514 - val_accuracy: 0.8632
Epoch 18/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0063 - accuracy: 0.9996 - val_loss: 0.6680 - val_accuracy: 0.8621
Epoch 19/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0047 - accuracy: 0.9997 - val_loss: 0.6967 - val_accuracy: 0.8620
Epoch 20/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0039 - accuracy: 0.9998 - val_loss: 0.7308 - val_accuracy: 0.8611
Epoch 21/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0027 - accuracy: 1.0000 - val_loss: 0.7511 - val_accuracy: 0.8608
Epoch 22/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0023 - accuracy: 0.9999 - val_loss: 0.7780 - val_accuracy: 0.8601
Epoch 23/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.8057 - val_accuracy: 0.8590
Epoch 24/30
196/196 [==============================] - 2s 10ms/step - loss: 0.0016 - accuracy: 0.9999 - val_loss: 0.8214 - val_accuracy: 0.8606
Epoch 25/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.8376 - val_accuracy: 0.8602
Epoch 26/30
196/196 [==============================] - 2s 11ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.8689 - val_accuracy: 0.8592
Epoch 27/30
196/196 [==============================] - 2s 12ms/step - loss: 8.3966e-04 - accuracy: 1.0000 - val_loss: 0.8716 - val_accuracy: 0.8592
Epoch 28/30
196/196 [==============================] - 2s 10ms/step - loss: 7.2445e-04 - accuracy: 1.0000 - val_loss: 0.8918 - val_accuracy: 0.8588
Epoch 29/30
196/196 [==============================] - 2s 12ms/step - loss: 6.1936e-04 - accuracy: 1.0000 - val_loss: 0.9143 - val_accuracy: 0.8591
Epoch 30/30
196/196 [==============================] - 2s 10ms/step - loss: 5.2330e-04 - accuracy: 1.0000 - val_loss: 0.9336 - val_accuracy: 0.8596
782/782 [==============================] - 1s 2ms/step - loss: 0.9893 - accuracy: 0.8468
[0.9892528653144836, 0.8467599749565125]
由此可以看到,我们的网络最终在测试集合上达到了 84.68% 的准确率,同时它的损失为 0.9893 。
5. 小结
在这节课之中,我们学会了如何在机器学习之中处理文本数据,同时了解了对文本进行分类的基本步骤。
通过自己的动手实现,我们实现了一个分类准确率接近 85% 的文本分类器。