画像認識の精度向上に欠かせない!PyTorchにおける変換(Transforms)の基礎知識
PyTorchにおける変換(Transforms)とは?
画像の読み込みと前処理
- 画像ファイルを読み込み、PyTorchテンソルに変換する
- 画像のサイズを変更する
- 画像を回転させる
- 画像の色空間を変換する
- 画像にノイズを追加する
データの標準化
- データの平均と標準偏差を計算し、データの各要素を標準化
- データのスケーリング
データ拡張
- 画像を回転させたり、反転させたりして、データセットを人工的に増やす
- 画像の一部を切り取ったり、ランダムに色を変えたりして、データセットを多様化する
バッチ化
- テキストデータの前処理
- 音声データの前処理
PyTorchで変換を使用する例
PyTorchでは、torchvision.transforms
モジュールに、画像処理やデータの前処理によく用いられる変換が多数用意されています。
import torchvision.transforms as transforms
# 画像を読み込み、テンソルに変換
img = Image.open("image.jpg")
tensor_img = transforms.ToTensor()(img)
# 画像をリサイズ
resized_img = transforms.Resize((224, 224))(tensor_img)
# 画像を回転
rotated_img = transforms.RandomRotation(10)(resized_img)
# 画像を標準化
normalized_img = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(rotated_img)
import torchvision.transforms as transforms
# 画像を読み込み、テンソルに変換
img = Image.open("image.jpg")
tensor_img = transforms.ToTensor()(img)
# 画像をリサイズ
resized_img = transforms.Resize((224, 224))(tensor_img)
# 画像を回転
rotated_img = transforms.RandomRotation(10)(resized_img)
# 画像を標準化
normalized_img = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(rotated_img)
# 画像を表示
import matplotlib.pyplot as plt
plt.imshow(normalized_img.permute(1, 2, 0))
plt.show()
import numpy as np
# データの平均と標準偏差を計算
data = np.array([1, 2, 3, 4, 5])
mean = np.mean(data)
std = np.std(data)
# データを標準化
normalized_data = (data - mean) / std
# 標準化されたデータを表示
print(normalized_data)
データ拡張
import torchvision.transforms as transforms
# 画像を読み込み、テンソルに変換
img = Image.open("image.jpg")
tensor_img = transforms.ToTensor()(img)
# 画像をランダムに切り取る
random_crop_img = transforms.RandomCrop((224, 224))(tensor_img)
# 画像をランダムに回転させる
random_rotation_img = transforms.RandomRotation(10)(tensor_img)
# 画像をランダムに反転させる
random_flip_img = transforms.RandomHorizontalFlip()(tensor_img)
# 画像を表示
import matplotlib.pyplot as plt
plt.subplot(141)
plt.imshow(tensor_img.permute(1, 2, 0))
plt.title("Original")
plt.subplot(142)
plt.imshow(random_crop_img.permute(1, 2, 0))
plt.title("Random Crop")
plt.subplot(143)
plt.imshow(random_rotation_img.permute(1, 2, 0))
plt.title("Random Rotation")
plt.subplot(144)
plt.imshow(random_flip_img.permute(1, 2, 0))
plt.title("Random Flip")
plt.show()
自作の変換
torch.nn.Module
を継承したクラスを作成するforward
メソッドを実装する- メソッド内で、データの変換処理を行う
import torch
class MyTransform(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, data):
# データの変換処理
return transformed_data
# 自作の変換を使用する
transform = MyTransform()
transformed_data = transform(data)
カスタム変換ライブラリの利用
torchvision.transforms
モジュール以外にも、画像処理やデータの前処理に特化したカスタム変換ライブラリが多数公開されています。以下に、代表的なライブラリをいくつか紹介します。
これらのライブラリを使用することで、より高度な変換処理を行うことができます。
image input transformation