ぷろぐ((>ω<))

ぷろぐらみんぐ関係のメモ

Google Colab 上で Keras 入門(SegNet-Basic を実装・学習・推論)

4年前ぐらいに Deep Learning 入門をしてからしばらく触っていなかったが、コロナ外出自粛で時間が余っていることもあって久しぶりにトライ。
Google Colab でやる手順をメモ。試したのはSegNetの軽量版のSegNet-Basic。

  1. SegNet: 画像セグメンテーションニューラルネットワーク - Qiita
  2. GitHub - 0bserver07/Keras-SegNet-Basic: SegNet-Basic with Keras

(1)でSegNet-Basicの情報源として(2)にリンクを貼っているが、(2)のHow-to が書きかけだったので、(2)のページを参考にしながら下の手順でKerasで実装したところ、うまくいった。

ちなみに、何故SegNetではなくてSegNet-Basicかというと、実はSegNetも試してみたが、Google Colabではリソースの制約上クラッシュして仕方がなかったので諦めただけ。

環境確認

以下の記事は、下記の環境で実行した。

import tensorflow as tf
import keras

print("tf.__version__ is", tf.__version__)
print("tf.keras.__version__ is ", tf.keras.__version__)
print("keras.__version__ is ", keras.__version__)
Using TensorFlow backend.
tf.__version__ is 2.2.0-rc4
tf.keras.__version__ is  2.3.0-tf
keras.__version__ is  2.3.1

下準備

ノートブックを新規作成

  • 「ファイル」 > 「ノートブックを新規作成」
  • ここではSegNetBasic.ipynbとした

Google Drive マウント

  1. Google DriveGoogle Colab にマウント
    • 左端の「ファイル」から操作できる
    • ↓こういう状態になるはず

f:id:presan:20200508144101p:plain
Google Drive をマウント

  1. 下記で公開されているSegNetリポジトリ内のデータ(CamVidフォルダ以下)をGoogle Drive にコピー。
    • 今回は drive/My Drive/Colab Notebooks/にSegNetというフォルダを作り、その中にCamVidをコピーした。
  2. SegNetフォルダ下に次のフォルダを作成
    • data
    • models
    • weights
    • valid  (おまけ)

モデル準備

SegNetフォルダ下に下記をmodel.pyとして保存。

import keras.models as models
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Layer, Dense, Dropout, Activation, Flatten, Reshape, Permute
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.callbacks import ModelCheckpoint

def buildSegnetBasicModel(input_shape, n_labels, kernel=3, pool_size=(2, 2), pad=1, output_mode="softmax"):
    # encoder
    inputs = Input(shape=input_shape)

    conv_1 = ZeroPadding2D(padding=(pad,pad))(inputs)
    conv_1 = Convolution2D(64, (kernel, kernel), padding="valid")(conv_1)
    conv_1 = BatchNormalization()(conv_1)
    conv_1 = Activation("relu")(conv_1)

    pool_1 = MaxPooling2D(pool_size)(conv_1)

    conv_2 = ZeroPadding2D(padding=(pad,pad))(pool_1)
    conv_2 = Convolution2D(128, (kernel, kernel), padding="valid")(conv_2)
    conv_2 = BatchNormalization()(conv_2)
    conv_2 = Activation("relu")(conv_2)

    pool_2 = MaxPooling2D(pool_size)(conv_2)

    conv_3 = ZeroPadding2D(padding=(pad,pad))(pool_2)
    conv_3 = Convolution2D(256, (kernel, kernel), padding="valid")(conv_3)
    conv_3 = BatchNormalization()(conv_3)
    conv_3 = Activation("relu")(conv_3)
    
    pool_3 = MaxPooling2D(pool_size)(conv_3)

    conv_4 = ZeroPadding2D(padding=(pad,pad))(pool_3)
    conv_4 = Convolution2D(512, (kernel, kernel), padding="valid")(conv_4)
    conv_4 = BatchNormalization()(conv_4)
    conv_4 = Activation("relu")(conv_4)

    print("Build SegNet-Basic enceder done..")

    # decoder
    conv_5 = ZeroPadding2D(padding=(pad,pad))(conv_4)
    conv_5 = Convolution2D(512, (kernel, kernel), padding="valid")(conv_5)
    conv_5 = BatchNormalization()(conv_5)

    unpool_1 = UpSampling2D(pool_size)(conv_5)

    conv_6 = ZeroPadding2D(padding=(pad,pad))(unpool_1)
    conv_6 = Convolution2D(256, (kernel, kernel), padding="valid")(conv_6)
    conv_6 = BatchNormalization()(conv_6)
    
    unpool_2 = UpSampling2D(pool_size)(conv_6)

    conv_7 = ZeroPadding2D(padding=(pad,pad))(unpool_2)
    conv_7 = Convolution2D(128, (kernel, kernel), padding="valid")(conv_7)
    conv_7 = BatchNormalization()(conv_7)

    unpool_3 = UpSampling2D(pool_size)(conv_7)

    conv_8 = ZeroPadding2D(padding=(pad,pad))(unpool_3)
    conv_8 = Convolution2D(64, (kernel, kernel), padding="valid")(conv_8)
    conv_8 = BatchNormalization()(conv_8)

    conv_9 = Convolution2D(n_labels, (1, 1), padding="valid")(conv_8)
    conv_9 = Reshape(
        (input_shape[0] * input_shape[1], n_labels),
        input_shape=(input_shape[0], input_shape[1], n_labels),
    )(conv_9)

    outputs = Activation(output_mode)(conv_9)
    
    print("Build SegNet-Basic decoder done..")

    model = Model(inputs=inputs, outputs=outputs, name="SegNetBasic")

    return model

データ準備

次を実行して、画像データをnumpyバイナリ形式(.npy)に変換。SegNet/data内にnpyファイルが6個、合計4.5GB程度生成されているはず。

import cv2
import numpy as np

from helper import *

import os
import gc

# Copy the data to this dir here in the SegNet project /CamVid from here:
# https://github.com/alexgkendall/SegNet-Tutorial
RootPath = 'drive/My Drive/Colab Notebooks/SegNet'
DataPath = 'drive/My Drive/Colab Notebooks/SegNet/CamVid/'
data_shape = 360*480

def normalized(rgb):
    return rgb / 255.

def one_hot_it(labels):
    w, h = labels.shape[:2]
    x = np.zeros([w,h,12], dtype=np.uint8)
    for i in range(w):
        for j in range(h):
            x[i,j,labels[i][j]]=1
    return x

def load_data(mode):
    data = []
    label = []
    with open(DataPath + mode +'.txt') as f:
        txt = f.readlines()
        txt = [line.split(' ') for line in txt]
    for i in range(len(txt)):
        datapath = RootPath + txt[i][0][7:]
        print('(', i, '/', len(txt), ') Loading data: ', datapath)
        img = cv2.imread(datapath)
        data.append(np.rollaxis(normalized(img),2))

        labelpath = RootPath + txt[i][1][7:][:-1]
        print('(', i, '/', len(txt), ') Loading label: ', labelpath)
        img = cv2.imread(labelpath)
        label.append(one_hot_it(img[:,:,0]))
    return np.array(data), np.array(label)



train_data, train_label = load_data("train")
train_label = np.reshape(train_label,(367,data_shape,12))
np.save(RootPath + "/data/train_data", train_data)
np.save(RootPath + "/data/train_label", train_label)
del train_data
del train_label
gc.collect()

test_data, test_label = load_data("test")
test_label = np.reshape(test_label,(233,data_shape,12))
np.save(RootPath + "/data/test_data", test_data)
np.save(RootPath + "/data/test_label", test_label)
del test_data
del test_label
gc.collect()

val_data, val_label = load_data("val")
val_label = np.reshape(val_label,(101,data_shape,12))
np.save(RootPath + "/data/val_data", val_data)
np.save(RootPath + "/data/val_label", val_label)
del val_data
del val_label
gc.collect()

学習

モデルを作成

次を実行

# Reference:
# https://qiita.com/cyberailab/items/d11862852eccc17585e8
# https://github.com/0bserver07/Keras-SegNet-Basic

import keras.models as models
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Layer, Dense, Dropout, Activation, Flatten, Reshape, Permute
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSampling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint

import cv2
import numpy as np
import matplotlib.pyplot as plt
import time

# Import SegNet/SegNetBasic modules
import sys
RootPath = '/content/drive/My Drive/Colab Notebooks/SegNet/'
sys.path.append(RootPath)
from model import buildSegnetModel, buildSegnetBasicModel

# Start time
start_time = time.time()

# Fix seed for reproducibility
np.random.seed(0)

# Parameters
class_weighting= [0.2595, 0.1826, 4.5640, 0.1417, 0.5051, 0.3826, 9.6446, 1.8418, 6.6823, 6.2478, 3.0, 7.3614]
input_shape = (360, 480, 3)
n_labels =  12
kernel = 3
pool_size = 2
pad = 1
output_mode = 'softmax'
data_shape = input_shape[0] * input_shape[1]

# load the model:
print("Building model...")
model_segnet = buildSegnetBasicModel(input_shape, n_labels, kernel, pool_size, pad, output_mode)
print("done.")

print("Compiling model...")
model_segnet.compile(loss="categorical_crossentropy", optimizer='adadelta', metrics=["accuracy"])
print("done.")

# Visualize model
model_segnet.summary()

# Calculate erapsed time
model_load_time = time.time()
print('Erapsed time: ', (model_load_time - start_time), '[s]')

npyデータを読み込み

次を実行。5~10分程度時間がかかる。

# load the data
print("Loading data...")
print("- Loading train_data...")
train_data = np.load(RootPath + 'data/train_data.npy').transpose((0, 2, 3, 1)) # NCHW to NHWC
print("  -> shape", train_data.shape)
print("- Loading train_label...")
train_label = np.load(RootPath + 'data/train_label.npy')
print("  -> shape", train_label.shape)
print("- Loading test_data...")
test_data = np.load(RootPath + 'data/test_data.npy').transpose((0, 2, 3, 1)) # NCHW to NHWC
print("  -> shape", test_data.shape)
print("- Loading test_label...")
test_label = np.load(RootPath + 'data/test_label.npy')
print("  -> shape", test_label.shape)
print("done.")

# Calculate erapsed time
data_load_time = time.time()
print('Erapsed time: ', (data_load_time - model_load_time), '[s]')

学習開始

ランタイム選択

  • 「ランタイム」 > 「ランタイムのタイプを変更」 で GPU or TPU を選ぶ

学習開始

次を実行。

  • 途中、標準割り当てメモリ(12.5[GB])ではメモリ不足でクラッシュするかも。
    • その場合は2倍の25[GB]をGoogleから有難く割り当ててもらって再実行。
    • 25[GB]あれば大丈夫
  • そのときの割り当てリソースにもよるが、ある日のGPU選択時は1epoch 110[s]程度だった。
    • つまり100epochの学習に2[h]程度
    • batch_sizeを2倍の12にすると1epoch 30[s]で100epoch成功するときもあれば、ResourceExhaustedErrorでクラッシュすることもあった
  • 過去最高の結果になるたび、SegNet/weightsフォルダに重みを上書き保存。(checkpoint)
# Parameter
nb_epoch = 100
batch_size = 6

# checkpoint
print("Deifining callbacks...")
filepath = RootPath + "weights/segnet_weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
print("done.")

# Calculate erapsed time
def_cb_time = time.time()
print('Erapsed time: ', (def_cb_time - data_load_time), '[s]')
print('-----------------')

# Fit the model
print("Fitting model...")
hist = model_segnet.fit(train_data, train_label, callbacks=callbacks_list, batch_size=batch_size, epochs=nb_epoch,
                    verbose=2, class_weight=class_weighting, validation_data=(test_data, test_label), shuffle=True) # validation_split=0.33
print("done.")

# Calculate erapsed time
fit_model_time = time.time()
print('Erapsed time: ', (fit_model_time - def_cb_time), '[s]')
print('-----------------')

モデル・重み保存

  • 次を実行して、推論時に使うためにモデルと100epoch目の重みをファイルに保存。SegNetフォルダ下のmodels, weightsフォルダにそれぞれ保存。
  • ただ、重みは100epoch目よりも、過去最高(checkpoint)の結果を使うべき
    • checkpointと100epoch目の重みファイルのサイズがかなり違うのは何故だろう?
# This save the trained model weights to this file with number of epochs
print("Saving model and weights...")
model_segnet.save(RootPath + 'models/segnet_model.hdf5')
model_segnet.save_weights(RootPath + 'weights/segnet_weight_{}.hdf5'.format(nb_epoch))
print("done.")

# Calculate erapsed time
fit_model_time = time.time()
print('Erapsed time: ', (fit_model_time - data_load_time), '[s]')
print('-----------------')

** loss/accuracyの可視化
次を実行してloss/accuracyの推移をグラフ描画。

>|python|
# Visualize
epochs = range(1, len(hist.history['accuracy']) + 1)

plt.plot(epochs, hist.history['loss'], label='Training loss', ls='-')
plt.plot(epochs, hist.history['val_loss'], label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

plt.plot(epochs, hist.history['accuracy'],  label='Training acc')
plt.plot(epochs, hist.history['val_accuracy'], label="Validation acc")
plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
学習結果

validationの精度がmaxで約80%ぐらいになった

f:id:presan:20200508141014p:plain
Loss
f:id:presan:20200508141032p:plain
Accuracy

推論

モデルと重みを読み込み

次を実行

from keras.models import load_model
from google.colab import files
from PIL import Image
import cv2
import numpy as np
import time

RootPath = '/content/drive/My Drive/Colab Notebooks/SegNet/'

# Parameter
input_shape = (360, 480, 3)

# Start time
start_time = time.time()

# Load model and weights
print('Loading model and weights...')
model_segnet = load_model(RootPath + 'models/segnet_model.hdf5')
model_segnet.load_weights(RootPath + 'weights/segnet_weights.best.hdf5')
print('done')

# Calculate erapsed time
model_load_time = time.time()
print('Erapsed time: ', (model_load_time - start_time), '[s]')

検証用データを読み込み

次を実行

val_data = np.load(RootPath + '/data/val_data.npy').transpose((0, 2, 3, 1))  # NCHW to NHWC
val_label = np.load(RootPath + '/data/val_label.npy')
print('done')

評価

次を実行

batch_size = 12

# estimate accuracy on whole dataset using loaded weights
scores = model_segnet.evaluate(val_data, val_label, verbose=0, batch_size=batch_size)
print("%s: %.2f%%" % (model_segnet.metrics_names[1], scores[1]*100))
評価結果
accuracy: 87.45%

可視化

次を実行。
検証用データ全部にやっては大変なので、最大10個まで可視化するようにした。

import matplotlib.pyplot as plt

Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road_marking = [255,69,0]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]

label_colours = np.array([Sky, Building, Pole, Road, Pavement,
                          Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])

def visualize(temp, plot=True):
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for l in range(0,11):
        r[temp==l]=label_colours[l,0]
        g[temp==l]=label_colours[l,1]
        b[temp==l]=label_colours[l,2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:,:,0] = (r/255.0)#[:,:,0]
    rgb[:,:,1] = (g/255.0)#[:,:,1]
    rgb[:,:,2] = (b/255.0)#[:,:,2]
    if plot:
        plt.imshow(rgb)
    else:
        return rgb

img_size = (input_shape[0], input_shape[1])

# Start time
start_time = time.time()

# Predict
print('Predicting...')
output = model_segnet.predict(val_data)
print('done')

# Calculate erapsed time
end_time = time.time()
print('Erapsed time: ', (end_time - start_time), '[s]')
print('-----------------')

# Visualize
count = min([10, len(output)])
for i in range(count):
  pred_class = np.argmax(output[i], axis=1).reshape(img_size)
  img_ret = visualize(pred_class, False)
  plt.figure(i * 2)
  plt.imshow(val_data[i])
  plt.figure(i * 2 + 1)
  plt.imshow(img_ret)
plt.show()
可視化結果

それっぽいsegmentation結果になった

f:id:presan:20200508142426p:plain
元画像
f:id:presan:20200508142440p:plain
SegNetBasic結果

(おまけ) 自前の画像で可視化

  • SegNet/validフォルダに480x360のRGB画像をtest0~8.pngという名前で保存
  • 次を実行
    • さっきまでSegNet公開リポジトリの検証用データを読み込んでいた変数val_dataに自前の画像を読み込む
val_data = []
for i in range(9):
  # Load image
  img_path = RootPath + 'valid/test{}.png'.format(i)
  print('Loading ', img_path)
  img = cv2.imread(img_path)
  if img is None:
    print('Failed to load ', img_path)
    continue
  #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  print('done')
  val_data.append(img)

val_data = np.array(val_data)
可視化結果

うーん。。

f:id:presan:20200508143222p:plain
元画像
f:id:presan:20200508143234p:plain
SegNetBasic結果