okpy

Pythonエンジニア兼テックリーダーが、多くのプロジェクトとチーム運営から得た実践的な知識を共有するブログです。

転移学習: PythonとKerasを活用した画像分類モデル構築

🧠 Python + Kerasで画像分類モデルを構築する【VGG16 転移学習】

今回は、KerasVGG16の事前学習済みモデルを使って、独自の画像データを分類するためのニューラルネットワークを構築する方法を紹介します。データの読み込みからモデル保存まで、完全なワークフローを体験できます。


🎯 このスクリプトでできること

  • list.txtに指定したディレクトリの各画像を読み込み
  • VGG16ベースのモデルに転移学習を適用
  • 学習済みモデルを .h5 形式で保存

🧱 使用ライブラリと環境構築

import numpy as np
import tensorflow as tf
import random as rn
import os
from tensorflow.compat.v1.keras import backend as K
  • 乱数シードを固定し、再現性のある学習結果を得られるように設定しています。

📄 入力ファイル構成とデータ形式

  • list.txt:画像フォルダのパスが1行ごとに記述されているファイルです。

    例: data/cat data/dog data/bird

  • 各フォルダには分類対象の画像が入っています。


🧠 モデル構築ステップ(転移学習)

1. 画像の前処理とラベル付け

from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
...
x = []
y = []
for count in range(len(lines)):
    for f in os.listdir(lines[count]):
        x.append(image.img_to_array(image.load_img(lines[count]+'/'+f, target_size=input_shape[:2])))
        y.append(count)
  • 各フォルダの画像を読み込み、(224, 224)にリサイズして配列に変換。
  • フォルダの順番に応じてラベル(0, 1, 2...)を付与。

2. データの正規化と分割

x = preprocess_input(x)
y = keras.utils.to_categorical(y, num_classes)
x_train, x_test, y_train, y_test = train_test_split(...)
  • VGG16に適した前処理を適用(平均値除去など)。
  • one-hotエンコーディングでラベル変換。
  • 訓練データとテストデータに分割。

3. モデル定義(VGG16 + カスタム出力層)

from keras.applications.vgg16 import VGG16
...
base_model = VGG16(weights='imagenet', include_top=False)
...
model = Model(inputs=base_model.input, outputs=predictions)
  • VGG16のトップレイヤー(分類部分)を除去し、新たに以下を追加:

    • GlobalAveragePooling2D
    • Dense(1024)
    • Dense(num_classes, softmax)
  • 転移学習として、VGG16の層はすべて固定(学習させない)。


4. コンパイルと学習

model.compile(loss=keras.losses.categorical_crossentropy, optimizer="rmsprop", metrics=['accuracy'])
history = model.fit(...)
  • 損失関数:categorical_crossentropy
  • オプティマイザ:rmsprop
  • バッチサイズ:128
  • エポック数:12

5. モデル保存

model.save('keras_model.h5')
  • 学習済みモデルを .h5 形式で保存し、後でロードして再利用可能にします。

💡 コード全文(再掲)

import numpy as np
import tensorflow as tf
import random as rn
import os
from tensorflow.compat.v1.keras import backend as K
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(0)
rn.seed(0)
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
tf.random.set_seed(0)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
K.set_session(sess)

with open('list.txt') as f:
    lines = f.read().splitlines()

from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
from sklearn.model_selection import train_test_split
import keras
import numpy as np
import os
input_shape = (224, 224, 3)
batch_size = 128
epochs = 12
num_classes = len(lines)
x = []
y = []
for count in range(len(lines)):
    for f in os.listdir(lines[count]):
        x.append(image.img_to_array(image.load_img(lines[count]+'/'+f, target_size=input_shape[:2])))
        y.append(count)
x = np.asarray(x)
x = preprocess_input(x)
y = np.asarray(y)
y = keras.utils.to_categorical(y, num_classes)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state= 3)

from keras.models import Sequential, Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.vgg16 import VGG16
base_model = VGG16(weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
for layer in base_model.layers:
    layer.trainable = False
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer="rmsprop",
              metrics=['accuracy'])

history = model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

model.save('keras_model.h5')

🎉 まとめ

このコードは、独自データセットを使った画像分類モデルの構築と学習のベースラインとして非常に有効です。以下のような拡張も可能です:

  • Data Augmentation(データ水増し)を加えて精度向上
  • fine-tuningでVGG16の一部レイヤーを学習可能にする
  • 精度・損失のグラフ可視化(matplotlib利用)