PyTorchでCNN用のカスタム画像データセットを読み込む:DatasetFolder、カスタムデータローダー、HDF5ファイルの活用
PyTorchでCNN用のカスタム画像ベースデータセットを読み込む方法
このチュートリアルでは、PyTorchで畳み込みニューラルネットワーク (CNN) を使用するために、カスタム画像ベースデータセットをロードする方法を説明します。 画像分類、オブジェクト検出、セマンティックセグメンテーションなどのタスクを実行するために、CNNモデルをトレーニングするには、大量の画像データが必要です。 多くの場合、このデータは手動でラベル付けする必要があります。
必要なもの
このチュートリアルを完了するには、次のものが必要です。
- Python 3.x
- PyTorch
- カスタム画像データセット
カスタムデータセットの準備
カスタム画像データセットは、画像と対応するラベルのセットで構成されます。 画像は、JPEG または PNG などの一般的な画像形式で保存できます。 ラベルは、数値またはカテゴリカルラベルにすることができます。
データセットを整理する 1 つの方法は、次のようなディレクトリ構造を使用することです。
data_root/
├── train/
│ ├── class1/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ ├── class2/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ └── ...
├── val/
│ ├── class1/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ ├── class2/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ └── ...
└── test/
├── class1/
│ ├── image1.jpg
│ ├── image2.jpg
│ └── ...
├── class2/
│ ├── image1.jpg
│ ├── image2.jpg
│ └── ...
└── ...
この構造では、data_root
ディレクトリはデータセットのルートディレクトリであり、train
、val
、test
はそれぞれトレーニング、検証、テストデータセットに対応するサブディレクトリです。 各クラスのサブディレクトリには、対応するクラスの画像が含まれます。
PyTorchでデータセットをロードする
PyTorchには、torch.utils.data
モジュールに、カスタムデータセットをロードするためのツールが含まれています。 次のコードは、上記のカスタムデータセットをロードする方法を示しています。
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, root_dir, train=True):
self.root_dir = root_dir
self.train = train
# データセットを読み込む
if self.train:
data_path = os.path.join(root_dir, 'train')
else:
data_path = os.path.join(root_dir, 'val')
# 画像とラベルのリストを作成する
self.images = []
self.labels = []
for class_dir in os.listdir(data_path):
class_path = os.path.join(data_path, class_dir)
for image_file in os.listdir(class_path):
image_path = os.path.join(class_path, image_file)
label = int(class_dir)
self.images.append(image_path)
self.labels.append(label)
# データ変換を定義する
self.transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
def __getitem__(self, index):
image_path = self.images[index]
image = Image.open(image_path).convert('RGB')
image = self.transforms(image)
label = self.labels[index]
return image, label
def __len__(self):
return len(self.images)
# データセットを作成する
dataset = CustomDataset(root_dir='data')
# データローダーを作成する
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
# データセットをループする
for images, labels in dataloader:
# 画像とラベルを処理する
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, root_dir, train=True):
self.root_dir = root_dir
self.train = train
# データセットを読み込む
if self.train:
data_path = os.path.join(root_dir, 'train')
else:
data_path = os.path.join(root_dir, 'val')
# 画像とラベルのリストを作成する
self.images = []
self.labels = []
for class_dir in os.listdir(data_path):
class_path = os.path.join(data_path, class_dir)
for image_file in os.listdir(class_path):
image_path = os.path.join(class_path, image_file)
label = int(class_dir)
self.images.append(image_path)
self.labels.append(label)
# データ変換を定義する
self.transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
def __getitem__(self, index):
image_path = self.images[index]
image = Image.open(image_path).convert('RGB')
image = self.transforms(image)
label = self.labels[index]
return image, label
def __len__(self):
return len(self.images)
# データセットを作成する
dataset = CustomDataset(root_dir='data')
# データローダーを作成する
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
# データセットをループする
for images, labels in dataloader:
# 画像とラベルを処理する
pass
このコードは次のとおりです。
CustomDataset
クラス:このクラスは、カスタムデータセットを表します。__init__
メソッドは、データセットのルートディレクトリと、トレーニングデータセットか検証データセットかを指定するtrain
引数を取ります。 このメソッドは、画像とラベルのリストを作成し、データ変換を定義します。__getitem__
メソッドは、指定されたインデックスの画像とラベルを返します。__len__
メソッドは、データセット内の画像の数を返します。- データセットの作成:この行は、
CustomDataset
クラスを使用してデータセットを作成します。 - データローダーの作成:この行は、
DataLoader
クラスを使用してデータローダーを作成します。 データローダーは、バッチサイズ、シャッフル、ワーカー数などの引数を取ります。 - データセットのループ処理:このループは、データローダー内の各バッチを反復します。 各バッチには、画像とラベルのテンソルが含まれます。
説明
このコードは、カスタム画像ベースデータセットをロードするための基本的なフレームワークを提供します。 実際のアプリケーションでは、独自の要件に合わせてコードをカスタマイズする必要があります。 たとえば、異なるデータ形式を使用している場合、独自のデータ変換を定義する必要がある場合があります。
- このコードは、Python 3.x と PyTorch 1.x を使用することを前提としています。
- このコードは、画像分類タスク向けに設計されています。 オブジェクト検出やセマンティックセグメンテーションなどの他のタスクには、コードを修正する必要がある場合があります。
torch.utils.data.DatasetFolder
クラスは、画像ディレクトリから画像データセットを自動的にロードするのに役立ちます。 このクラスを使用するには、以下の手順を実行する必要があります。
- 画像をクラスごとに整理されたディレクトリ構造に配置します。
DatasetFolder
クラスを使用してデータセットを作成します。- データローダーを使用してデータセットをロードします。
import torch
from torchvision import transforms
from torch.utils.data import DatasetFolder, DataLoader
# 画像をクラスごとに整理されたディレクトリ構造に配置する
# ...
# データセットを作成する
dataset = DatasetFolder(root='data', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]))
# データローダーを作成する
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
# データセットをループする
for images, labels in dataloader:
# 画像とラベルを処理する
pass
カスタムデータローダーを作成する
独自のデータローダーを作成することもできます。 これにより、データの読み込みと処理方法をより細かく制御できます。 カスタムデータローダーを作成するには、以下の手順を実行する必要があります。
torch.utils.data.Dataset
クラスを継承するクラスを作成します。__getitem__
メソッドと__len__
メソッドを実装します。
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, root_dir, train=True):
# ... (CustomDatasetクラスの初期化と同じ)
def __getitem__(self, index):
# ... (CustomDatasetクラスの__getitem__メソッドと同じ)
def __len__(self):
# ... (CustomDatasetクラスの__len__メソッドと同じ)
# データセットを作成する
dataset = CustomDataset(root_dir='data')
# データローダーを作成する
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
# データセットをループする
for images, labels in dataloader:
# 画像とラベルを処理する
pass
HDF5 ファイルを使用する
HDF5 ファイルは、データを効率的に保存するためのバイナリファイル形式です。 PyTorchには、HDF5 ファイルからデータをロードするためのツールが含まれています。 この方法を使用するには、以下の手順を実行する必要があります。
- 画像とラベルを HDF5 ファイルに保存します。
import torch
import h5py
from torch.utils.data import DatasetHDF5, DataLoader
# 画像とラベルを HDF5 ファイルに保存する
# ...
# データセットを作成する
dataset = DatasetHDF5('data.h5', 'images', 'labels')
# データローダーを作成する
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
# データセットをループする
for images, labels in dataloader:
# 画像とラベルを処理する
pass
python-3.x machine-learning conv-neural-network