跳到主要内容

在 TensorFlow 之中进行图像分割

在之前的学习之中,对于图像数据,我们进行过分类等一些常见的任务;这节课我们便来学习一下对于图像数据的另外一种任务:图像分割。

1. 什么是图像分割

图像分割,顾名思义,就是对图像数据进行分割,而分类的物体一般是我们认为进行指定的。比如物品分割、人脸分割、医学病灶分割等。

举个例子,如下图所示,原来的图像是一个马路的图片,通过图像分割,我们会按照不同的物体进行不同的分割,比如车分为一类、人分为一类、建筑分为一类、马路分为一类等。

图片描述

图像分割是很多任务的前提,有很多的任务只有进行了有效的分割之后才能进行有效的处理,比如:

  • 医学病灶识别;
  • 人脸情绪识别;
  • 路况检测;
  • 自动驾驶;
  • 等等。

2. 如何进行图像分割

图像分割看上去是一个很复杂的任务,但是实现起来的原理却是非常简单,具体来说分为以下几步:

  • 确定要分类的类别,比如,我们可以将图片中所有的物体分割为 10 类,包括车、人等;
  • 对于每个像素点进行数字分类,数字分类的类别数量对应于上述的类别,这里是 10 ;
  • 将每个数字类别对应于分类的类别,比如 0 代表车、1 代表人。

可以看出,图像分割任务其实就是一个分类任务,只不过是对于每个像素点进行分类,也就是确定每个像素点所对应的类别。

在这节课之中,我们会使用图像分割的基础数据集:oxford_iiit_pet 图像分割数据集来进行演示。与此同时,我们也会采用之前学习到的迁移学习的方式来进行模型的构建,从而完成图像分割的任务。

3. 使用 TensorFlow 进行图像分割的程序示例

在 oxford_iiit_pet 之中,所有的图片都是宠物,我们的任务是将图片中的宠物分割出来,所有的像素点都被分为三类

  • 1: 对应于宠物的一部分;
  • 2: 对应于宠物的边界;
  • 3: 不属于宠物的一部分。

在这里,我们使用代码有一部分来自 TensorFlow 官方的一个例子,这个例子非常的简单易懂,作为图像分割任务的入门是再适合不过的了。

我们会逐步进行代码的解释与理解,从而帮助大家学习图像分割的任务的特点。

1. 首先我们获取数据集

import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

dataset, info = tfds.load('oxford\_iiit\_pet:3.\*.\*', with_info=True)

这里会下载数据集,因为是图片数据,因此数据集相对比较大。

2. 定义归一化处理函数

def normalize(input_image, input_mask):
input_image = tf.cast(input_image, tf.float32) / 255.0
return input_image, input_mask

它接收两个参数,第一个参数是图片,我们会将其归一化到 [0, 1] ,第二个参数是图像的标签。

3. 构建数据集

def load\_image\_train(data):
input_image = tf.image.resize(data['image'], (128, 128))
input_mask = tf.image.resize(data['segmentation\_mask'], (128, 128))

input_image, input_mask = normalize(input_image, input_mask)

return input_image, input_mask

def load\_image\_test(data):
input_image = tf.image.resize(data['image'], (128, 128))
input_mask = tf.image.resize(data['segmentation\_mask'], (128, 128))

input_image, input_mask = normalize(input_image, input_mask)

return input_image, input_mask

num_examples = info.splits['train'].num_examples
BATCH = 64
step_per_epch = num_examples // BATCH

train = dataset['train'].map(load_image_train)
test = dataset['test'].map(load_image_test)

train_dataset = train.cache().shuffle(1000).batch(BATCH).repeat()
test_dataset = test.batch(BATCH)

在构建数据集函数之中,我们做了两件事情:

  • 将图像与标签重新调整大小到 [128, 128]
  • 将数据归一化。

然后我们进行了分批的处理,这里取批次的大小为 64 ,大家可以根据自己的内存或现存大小灵活调整。

4. 构建网络模型


output_channels = 3

# 获取基础模型
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# 定义要使用其输出的基础模型网络层
layer_names = [
'block\_1\_expand\_relu', # 64x64
'block\_3\_expand\_relu', # 32x32
'block\_6\_expand\_relu', # 16x16
'block\_13\_expand\_relu', # 8x8
'block\_16\_project', # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]

# 创建特征提取模型
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)

down_stack.trainable = False

# 进行降频采样
up_stack = [
pix2pix.upsample(512, 3), # 4x4 -> 8x8
pix2pix.upsample(256, 3), # 8x8 -> 16x16
pix2pix.upsample(128, 3), # 16x16 -> 32x32
pix2pix.upsample(64, 3), # 32x32 -> 64x64
]

# 定义UNet网络模型
def unet\_model(output_channels):
inputs = tf.keras.layers.Input(shape=[128, 128, 3])
x = inputs

# 在模型中降频取样
skips = down_stack(x)
x = skips[-1]
skips = reversed(skips[:-1])

# 升频取样然后建立跳跃连接
for up, skip in zip(up_stack, skips):
x = up(x)
concat = tf.keras.layers.Concatenate()
x = concat([x, skip])

# 这是模型的最后一层
last = tf.keras.layers.Conv2DTranspose(
output_channels, 3, strides=2,
padding='same') #64x64 -> 128x128

x = last(x)

return tf.keras.Model(inputs=inputs, outputs=x)

model = unet_model(output_channels)
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])


在这里,我们首先得到了一个预训练的 MobileNetV2 用于特征提取,在这里我们并没有包含它的输出层,因为我们要根据自己的任务灵活调节。

然后定义了我们要使用的 MobileNetV2 的网络层的输出,我们使用这些输出来作为我们提取的特征。

然后我们定义了我们的网络模型,这个模型的理解有些困难,大家可能不用详细了解网络的具体原理。大家只需要知道,这个网络大致经过的步骤包括:

  • 先将数据压缩(便于数据的处理)
  • 然后进行数据的处理
  • 最后将数据解压返回到原来的大小,从而完成网络的任务

最后我们编译该模型,我们使用 adam 优化器,交叉熵损失函数(因为图像分割是个分类任务)。

5. 模型的训练

epoch = 20
valid_steps = info.splits['test'].num_examples//BATCH

model_history = model.fit(train_dataset, epochs=epoch,
steps_per_epoch=step_per_epch,
validation_steps=valid_steps,
validation_data=test_dataset)

loss = model_history.history['loss']
val_loss = model_history.history['val\_loss']


这边就是一个简单的训练过程,我们可以得到如下输出:

Epoch 1/20
57/57 [==============================] - 296s 5s/step - loss: 0.4928 - accuracy: 0.7995 - val_loss: 0.6747 - val_accuracy: 0.7758
......
Epoch 20/20
57/57 [==============================] - 276s 5s/step - loss: 0.2586 - accuracy: 0.9218 - val_loss: 0.2821 - val_accuracy: 0.9148

我们可以看到我们最后达到了 91% 的准确率,还是一个可以接受的结果。

感兴趣的同学可以尝试一下进行结果的可视化,从而更加直观的查看到结果。

4. 小结

在这节课之中,我们学习了什么是图像分割,同时了解了图像分割的简单的实现方式,最终我们通过一个示例来了解了如何在 TensorFlow 之中进行图像分割。