跳到主要内容

在 TensorFlow 之中自定义训练

在前面两节的学习之中,我们学习了如何进行自定义微分操作,又学习了如何自定义模型,那么接下来这一小节我们便来学习如何进行自定义的最后一步 —— 自定义训练。

在之前的学习之中,当我们进行训练的时候,我们采用的都是 fit () 函数。虽然说 fit () 函数给我们提供了很多的选项,但是如果想要更加深入的定制我们的寻来你过程,那么我们便需要自己编写训练循环。

在这节课之中,我们会采用一个对 mnist 数据集进行分类的简单模型作为一个简单的示例来进行演示,以此来帮助大家理解自定义训练的过程。因此该课程主要可以分为以下几个部分:

  • 自定义模型
  • 编写自定义循环
  • 如何在自定义循环中进行模型的优化工作。

1. 数据的准备工作

首先我们要准备好相应的 mnist 数据集,这里采用往常的处理方式:使用内置的 API 来获取数据集:

import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# 数据归一化
train_images = train_images / 255.0
test_images = test_images / 255.0


train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

valid_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
valid_dataset = valid_dataset.batch(64)

在这里,不仅仅构建了数据集,我们同样对图片数据进行了归一化的操作。同时,我们对数据进行了分批次的处理,批次的大小维 64 。于此同时,我们对训练数据进行了乱序处理。

2. 自定义模型

由于进行 Mnist 图像分类的任务比较简单,因此我们可以定义一个较为简单的模型,这里的模型的结构包含四层:

  • Flattern 层:对二维数据进行展开;
  • 第一个 Dense 层:包含 128 个神经元;
  • 第二个 Dense 层:包含 64 个神经元;
  • 最后一个 Dense 分类层;包含 10 个神经元,对应于我们的十个分类。
class MyModel(tf.keras.Model):
def \_\_init\_\_(self):
super(MyModel, self).__init__()
self.l1 = tf.keras.layers.Flatten()
self.l2 = tf.keras.layers.Dense(128, activation='relu')
self.l3 = tf.keras.layers.Dense(64, activation='relu')
self.l4 = tf.keras.layers.Dense(10, activation='softmax')

def call(self, inputs, training=True):
x = self.l1(inputs)
x = self.l2(x)
x = self.l3(x)
y = self.l4(x)
return y
model = MyModel()

3. 定义训练循环

在做好准备工作之后,我们便来到了我们的最重要的部分,也就是如何进行自定义循环的构建。

在自定义循环之前,我们要先做好准备工作,分为如下几步:

  • 自定义损失函数:在大多数情况之下,内置的损失函数以及足够我们使用,比如交叉熵等损失函数;
  • 自定义优化器:优化器决定了按照如何的策略进行优化,我们最常用的优化器就是 Adam ,因此这里我们使用内置的 Adam 优化器;
  • (可选)定义变量监视器:用于监视我们的训练过程的各种参数,在这里我们只使用一个来监视我们的验证集合上的效果。

因此我们的代码可以如下所示:

# 损失函数
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
# 优化器
optimizer = tf.keras.optimizers.Adam()
# 监控验证机上的准确率
val_acc = tf.keras.metrics.SparseCategoricalAccuracy()

然后我们便可以构建自定义循环,自定义循环大致分为以下几步:

  • 编写一个循环 Epoch 次的循环,Epoch 为训练的循环数;
  • 在循环内部对于数据集读取每一个 Batch,因为这里的 train_dataset 为可枚举的,因此我们直接使用枚举即可获得每一个批次的训练样本;
  • 定义 tf.GradientTape () 梯度带;
  • 在梯度带内进行模型的输出,以及损失的求取
  • 梯度带外使用梯度带求得模型所有参数的梯度,在这里我们可以使用 model.trainable_weights 来获取所有可训练的参数;
  • 使用优化器按照求得的梯度对模型的参数进行优化,这里直接使用 optimizer.apply_gradients 函数即可完成优化;
  • (可选)进行 Log 处理,打印出日志便于我们查看;
  • (可选)在每个 Epoch 的训练集的训练结束后,我们可以在测试集上查看结果,这里我们只查看了准确率。
epochs = 3
for epoch in range(epochs):
print("Start Training epoch " + str(epoch))

# 取出每一个批次的数据
for batch_i, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# 在梯度带内进行操作
with tf.GradientTape() as tape:
outputs = model(x_batch_train, training=True)
loss_value = loss_fn(y_batch_train, outputs)

# 求取梯度
grads = tape.gradient(loss_value, model.trainable_weights)
# 使用Optimizer进行优化
optimizer.apply_gradients(zip(grads, model.trainable_weights))

# Log
if batch_i % 100 == 0:
print("Loss at batch %d: %.4f" % (batch_i, float(loss_value)))

# 在验证集合上测试
for batch_i, (x_batch_train, y_batch_train) in enumerate(valid_dataset):
outputs = model(x_batch_train, training=False)
# 更新追踪器的状态
val_acc.update_state(y_batch_train, outputs)
print("Validation acc: %.4f" % (float(val_acc.result()),))

# 重置追踪器
val_acc.reset_states()

最终,我们可以得到如下结果:

Start Training epoch 0
Loss at batch 0: 0.1494
Loss at batch 100: 0.2155
Loss at batch 200: 0.1080
Loss at batch 300: 0.0231
Loss at batch 400: 0.1955
Loss at batch 500: 0.2019
Loss at batch 600: 0.0567
Loss at batch 700: 0.1099
Loss at batch 800: 0.0714
Loss at batch 900: 0.0364
Validation acc: 0.9691
Start Training epoch 1
Loss at batch 0: 0.0702
Loss at batch 100: 0.0615
Loss at batch 200: 0.0208
Loss at batch 300: 0.0158
Loss at batch 400: 0.0304
Loss at batch 500: 0.1193
Loss at batch 600: 0.0130
Loss at batch 700: 0.1353
Loss at batch 800: 0.1300
Loss at batch 900: 0.0056
Validation acc: 0.9715
Start Training epoch 2
Loss at batch 0: 0.0714
Loss at batch 100: 0.0066
Loss at batch 200: 0.0177
Loss at batch 300: 0.0086
Loss at batch 400: 0.0099
Loss at batch 500: 0.1621
Loss at batch 600: 0.1103
Loss at batch 700: 0.0049
Loss at batch 800: 0.0139
Loss at batch 900: 0.0111
Validation acc: 0.9754

大家可以发现,我们的模型在测试集合上达到了 97.54% 的准确率。

同时我们可以发现,其实在第一个 Epoch 之后,模型已经达到了很好地效果,这是因为我们的任务比较简单,而且我们的模型拟合能力比较强。

4. 小结

这节课我们采用案例驱动的方式,在对图片进行分类的过程之中学会了如何自定义循环,以及如何在自定义循环的时候进行优化。