이번 튜토리얼에서는 modified U-Net를 사용하여 이미지 세그멘테이션(segmentation) 작업에 대해 알아보겠습니다.
!pip install git+https://github.com/tensorflow/examples.git
import warnings
warnings.simplefilter('ignore')
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
지금까지 네트워크의 작업은 입력 이미지에 레이블 또는 클래스를 할당하는 이미지 분류를 살펴보았습니다. 그러나 이미지에서 객체가 어디에 있는지, 해당 객체의 모양을 알고싶거나 어떤 픽셀이 어떤 객체에 속하는지 등을 알고 싶다고 가정해봅시다.
이런 경우에는 이미지의 각 픽셀에 레이블을 지정하는 이미지 세그멘테이션을 합니다. 이미지 세그멘테이션 작업은 이미지의 픽셀을 기준으로 만든 마스크를 출력하도록 뉴럴 네트워크를 학습시키는 것입니다. 이렇게 하면 이미지를 훨씬 낮은 수준, 즉 픽셀 수준에서 이해하는 데 도움이 됩니다. 이미지 세그멘테이션은 의료 이미징(imaging), 자율 주행 차량과 위성 이미징 등 많은 애플리케이션을 제공합니다.
이번 튜토리얼에서 사용할 데이터셋은 Parkhi et al에서 만든 Oxford-IIIT Pet Dataset입니다. 데이터셋는 이미지와 이에 해당 레이블, 그리고 픽셀로 만든 마스크로 구성됩니다. 마스크는 기본적으로 각 픽셀에 대한 레이블입니다. 각 픽셀에는 다음 세 가지 범주 중 하나가 지정됩니다.
데이터셋은 이미 텐서플로우 데이터셋에 포함되어 있으며, 이를 다운로드하기만 하면 됩니다. 세그멘테이션 마스크는 3.0 이상의 버전에서 이용할 수 있습니다.
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
다음 코드는 이미지를 플립(flip)하는 간단한 어그멘테이션(augmentation)을 실행합니다. 또한 이미지는 [0,1]로 정규화됩니다. 마지막으로, 위에서 언급한 바와 같이 세그멘테이션 마스크의 픽셀에는 {1, 2, 3} 라벨이 지정됩니다. 편의상 분할 마스크에서 1을 빼서 레이블을 만듭니다. {0, 1, 2}.
def normalize(input_image, input_mask):
input_image = tf.cast(input_image, tf.float32) / 255.0
input_mask -= 1
return input_image, input_mask
@tf.function
def load_image_train(datapoint):
input_image = tf.image.resize(datapoint['image'], (128, 128))
input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
if tf.random.uniform(()) > 0.5:
input_image = tf.image.flip_left_right(input_image)
input_mask = tf.image.flip_left_right(input_mask)
input_image, input_mask = normalize(input_image, input_mask)
return input_image, input_mask
def load_image_test(datapoint):
input_image = tf.image.resize(datapoint['image'], (128, 128))
input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
input_image, input_mask = normalize(input_image, input_mask)
return input_image, input_mask
데이터셋에는 이미 테스트와 학습 데이터가 나누어져 있으므로 이렇게 나눠진 데이터를 계속 사용하겠습니다.
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)
이제 예시 이미지와 그에 해당하는 마스크를 살펴보겠습니다.
def display(display_list):
plt.figure(figsize=(15, 15))
title = ['Input Image', 'True Mask', 'Predicted Mask']
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i+1)
plt.title(title[i])
if display_list[i].shape[2]>1:
plt.imshow(display_list[i])
else:
plt.imshow(np.array(display_list[i]).reshape(128,128))
plt.axis('off')
plt.show()
for image, mask in train.take(1):
sample_image, sample_mask = image, mask
display([sample_image, sample_mask])
사용된 모델은 modified U-Net입니다. U-Net은 인코더(다운샘플러)와 디코더(업샘플러)로 구성됩니다. 로버스트(robust)한 피쳐를 학습하고 학습 가능한 매개 변수의 수를 줄이기 위해 학습된 모델을 인코더로 사용하겠습니다. 따라서 이 작업의 인코더는 학습된 MobileNetV2 모델로서 중간 출력값이 사용되며, 디코더는 Pix2pix에서 이미 구현된 업샘플 블록을 사용할 것입니다.
세 개의 채널이 출력 값인 이유는 각 픽셀에 세 개의 레이블이 있을 수 있기 때문입니다. 각 픽셀이 세 가지 클래스로 분류되는 다중 분류라고 생각하면됩니다.
OUTPUT_CHANNELS = 3
앞서 언급한 바와 같이, 인코더는 tf.keras.applications에 있는 훈련된 MobileNetV2 모델을 사용할 것입니다. 인코더는 모델의 중간 레이어의 특정 출력 값으로 구성됩니다. 인코더는 학습 과정 중에 학습하지 않습니다.
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
# 다음과 같은 레이어 활성화 함수를 사용합니다.
layer_names = [
'block_1_expand_relu', # 64x64
'block_3_expand_relu', # 32x32
'block_6_expand_relu', # 16x16
'block_13_expand_relu', # 8x8
'block_16_project', # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]
# 피쳐를 추출하는 모델을 생성합니다.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False
디코더/업샘플러는 pix2pix에서 구현된 업샘플 블록을 사용합니다.
up_stack = [
pix2pix.upsample(512, 3), # 4x4 -> 8x8
pix2pix.upsample(256, 3), # 8x8 -> 16x16
pix2pix.upsample(128, 3), # 16x16 -> 32x32
pix2pix.upsample(64, 3), # 32x32 -> 64x64
]
def unet_model(output_channels):
# 모델의 마지막 레이어입니다.
last = tf.keras.layers.Conv2DTranspose(
output_channels, 3, strides=2,
padding='same', activation='softmax') #64x64 -> 128x128
inputs = tf.keras.layers.Input(shape=[128, 128, 3])
x = inputs
# 모델을 통해 다운샘플링합니다.
skips = down_stack(x)
x = skips[-1]
skips = reversed(skips[:-1])
# 업샘플링과 스킵 커넥션(skip conection)을 합니다.
for up, skip in zip(up_stack, skips):
x = up(x)
concat = tf.keras.layers.Concatenate()
x = concat([x, skip])
x = last(x)
return tf.keras.Model(inputs=inputs, outputs=x)
이제, 이제 남은 것은 모델을 컴파일하고 학습시키는 것입니다. 여기서 사용되고 있는 손실(loss) 함수는 loss.sparse_categorical_crossentropy
입니다. 이 손실 함수를 사용하는 이유는 네트워크가 다중 클래스를 예측하는 것처럼 각 픽셀에도 레이블을 지정하려고 하기 때문입니다. 실제 세그멘테이션 마스크에서의 각 픽셀은 {0,1,2}의 값을 가지고 있습니다. 여기 있는 네트워크는 세 개의 채널을 출력하고 있습니다. 기본적으로 각 채널은 클래스 예측 기법을 학습하고 loss.sparse_categorical_crossentropy
는 이러한 상황에 잘 맞는 손실 함수입니다. 픽셀에 할당된 레이블은 네트워크의 출력 값이 가장 높은 채널입니다. 이것이 create_mask 함수의 기능입니다.
model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
모델의 결과 구조를 빠르게 확인해봅시다:
tf.keras.utils.plot_model(model, show_shapes=True)
이 모델을 사용해보고 학습하기 전에 예측한 것을 살펴보겠습니다.
def create_mask(pred_mask):
pred_mask = tf.argmax(pred_mask, axis=-1)
pred_mask = pred_mask[..., tf.newaxis]
return pred_mask[0]
def show_predictions(dataset=None, num=1):
if dataset:
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
display([image[0], mask[0], create_mask(pred_mask)])
else:
display([sample_image, sample_mask,
create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()
학습하는 동안 모델이 어떻게 개선되는지 보겠습니다. 이 작업을 수행하기 위한 콜백(callback) 함수가 아래에 정의되어 있습니다.
class DisplayCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
clear_output(wait=True)
show_predictions()
print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 5
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
model_history = model.fit(train_dataset, epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_STEPS,
validation_data=test_dataset,
callbacks=[DisplayCallback()])
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
epochs = range(EPOCHS)
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss', color = 'b')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()
몇 가지 데이터를 예측해 봅시다. 시간을 절약하기 위해, 에포크 수는 작게 설정되었지만, 보다 정확한 결과를 얻기 위해 이 값을 더 높게 설정해야할 것입니다.
show_predictions(test_dataset, 3)
Licensed under the Apache License, Version 2.0 (the “License”);
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.