okpy

Pythonエンジニア兼テックリーダーが、多くのプロジェクトとチーム運営から得た実践的な知識を共有するブログです。

画像処理特化モジュール:PyTorch torchvisionの機能と使い方

Python torchvision ライブラリ完全ガイド

Pythontorchvision は、PyTorch 環境での画像データの取り扱いを簡単にするためのライブラリです。画像データの前処理、データセットの読み込み、モデルの保存などが行えます。

1. torchvision ライブラリの概要

  • PyTorch の画像処理特化モジュールグループ
  • CIFAR10、MNIST、ImageNet などのデータセットを使用可能
  • 変換、機械学習の前処理ステップを提供

インストール方法

pip install torchvision

2. 主な機能と使用例

(1) データセットの読み込み (CIFAR10)

from torchvision import datasets, transforms

dataset = datasets.CIFAR10(root="./data", download=True, transform=transforms.ToTensor())

(2) 変換の定義

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(28),
    transforms.ToTensor()
])

(3) DataLoader の使用

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=64, shuffle=True)

(4) データの表示

import matplotlib.pyplot as plt
import numpy as np

images, labels = next(iter(loader))
plt.imshow(np.transpose(images[0], (1, 2, 0)))
plt.title(f"Label: {labels[0]}")
plt.show()

(5) torchvision.models の利用 (ResNet18)

from torchvision import models

model = models.resnet18(pretrained=True)
print(model)

(6) transforms.RandomHorizontalFlip()

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

(7) torchvision.utils.make_grid()

from torchvision.utils import make_grid

grid_img = make_grid(images[:4])
plt.imshow(np.transpose(grid_img, (1, 2, 0)))
plt.show()

(8) torchvision.io 画像読み込み

from torchvision.io import read_image

img = read_image("sample.jpg")
print(img.shape)

(9) torchvision.transforms.Normalize

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

(10) torchvision.transforms.RandomRotation

transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

3. torchvision の主なモジュール

機能 説明
datasets CIFAR10, MNIST などのデータ読み込み
transforms 画像変換処理 (トリミング、抽出、Tensor 化)
models ResNet などの先端機械学習モデル
utils make_grid, save_image などの補助ツール
io 画像ファイルの読み込みや保存

まとめ

torchvision は PyTorch の画像処理を支える強力なライブラリです。 データ読み込み、変換、モデル読み込みなどが容易になり、実践的なデータ分類モデル開発を助けます 🚀