Datasets
prepare_dataloader(dataset_path, image_size, batch_size)
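Prepares a PyTorch DataLoader from a given path with a specific image size and batch size. Images are resized to `image_size` x `image_size` pixels, converted to tensors and normalized to the range [-1, 1]. If `dataset_path` is not a dataset directory but the name of a supported dataset type, no path is read and the corresponding torchvision dataset is loaded instead.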
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_path | str | The location of the dataset, or the name of a supported torchvision dataset. | required |
image_size | int | The side length, in pixels, to which every image is resized (images become image_size x image_size). | required |
batch_size | int | The batch size to be set for the DataLoader. | required |
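The DataLoader is created without a `drop_last` argument, so PyTorch's default `drop_last=False` applies: one epoch yields ceil(N / batch_size) batches, where N is the number of images, and the final batch may be smaller than `batch_size`. The plain-Python sketch below only illustrates that arithmetic; `batches_per_epoch` and `batch_sizes` are hypothetical helpers, not part of pytorch_gan_trainer.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    # Hypothetical helper: with drop_last=False the partial final batch is kept.
    return math.ceil(num_samples / batch_size)


def batch_sizes(num_samples, batch_size):
    # Sizes of the successive batches in one epoch. shuffle=True changes which
    # samples land in each batch, not how many.
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(batches_per_epoch(1000, 64))  # 16
print(batch_sizes(1000, 64)[-1])    # 40
```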
Returns:
Type | Description |
---|---|
tuple | A `(dataloader, classes)` pair: a `torch.utils.data.DataLoader` with the specified image size and batch size, and the list of class names found in the dataset. |
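When `dataset_path` is a directory, the returned `classes` list comes from `torchvision.datasets.ImageFolder`, which treats every subdirectory of the root as one class and sorts the names alphabetically; a class's position in that list is the integer label given to its images. The stdlib-only sketch below mimics that rule on a throwaway directory tree. `folder_classes` is a hypothetical helper, and unlike `ImageFolder` it does not check that the folders actually contain image files.

```python
import os
import tempfile


def folder_classes(root):
    # Mimic ImageFolder: every subdirectory of root is a class, sorted by name.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    # Integer label assigned to the images inside each class folder.
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: root/<class name>/<image files>
    for name in ("dogs", "cats", "birds"):
        os.makedirs(os.path.join(root, name))
    # A plain file at the top level is not a class.
    open(os.path.join(root, "README.txt"), "w").close()
    classes, class_to_idx = folder_classes(root)

print(classes)       # ['birds', 'cats', 'dogs']
print(class_to_idx)  # {'birds': 0, 'cats': 1, 'dogs': 2}
```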
Source code in pytorch_gan_trainer/datasets.py
```python
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


def prepare_dataloader(dataset_path, image_size, batch_size):
    """Prepares a PyTorch DataLoader from a given path \
    with a specific image and batch size. \
    If a specific dataset type is given, \
    the path will not be read and \
    a torchvision dataset will be loaded instead.

    Arguments:
        dataset_path (str): The location of the dataset.
        image_size (int): The image size of the images in the DataLoader.
        batch_size (int): The batch size to be set for the DataLoader.

    Returns:
        tuple: A DataLoader \
        with the specified image size and batch size, \
        and the list of class names found in the dataset.
    """
    # Resize to a square image, convert to a [0, 1] tensor, then map to [-1, 1].
    # (0.5) is a plain float, so the same mean and std apply to every channel.
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5), std=(0.5)),
        ]
    )
    # _check_path and _get_torch_dataset are helpers not shown in this excerpt.
    if _check_path(dataset_path):
        # A directory of images with one subfolder per class.
        dataset = torchvision.datasets.ImageFolder(dataset_path, transform=transform)
    else:
        # Otherwise dataset_path is treated as the name of a supported dataset.
        dataset = _get_torch_dataset(dataset_path, transform)
    dataloader = DataLoader(
        dataset, batch_size, shuffle=True, pin_memory=True, num_workers=4
    )
    return dataloader, dataset.classes
```
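The transform chain in `prepare_dataloader` fixes the value range of every batch: `ToTensor` scales 8-bit pixel values from 0 to 255 into [0, 1], and `Normalize` with mean 0.5 and std 0.5 computes (x - 0.5) / 0.5 = 2x - 1, which lands in [-1, 1]. The inverse is x = 0.5y + 0.5. The plain-Python sketch below traces that arithmetic for single pixel values; `to_model_range` and `to_pixel` are hypothetical names used only for illustration.

```python
def to_model_range(pixel):
    # ToTensor: 8-bit value in 0..255 -> float in [0, 1].
    x = pixel / 255
    # Normalize(mean=0.5, std=0.5): (x - mean) / std -> [-1, 1].
    return (x - 0.5) / 0.5


def to_pixel(y):
    # Undo Normalize, then return to the 0..255 integer scale.
    x = y * 0.5 + 0.5
    return round(x * 255)


print(to_model_range(0))              # -1.0
print(to_model_range(255))            # 1.0
print(to_pixel(to_model_range(128)))  # 128
```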
Datasets
The supported datasets are: