Datasets

prepare_dataloader(dataset_path, image_size, batch_size)

Prepares a PyTorch DataLoader from a given path, with a specific image size and batch size. If `dataset_path` names a supported dataset type instead of a path, the path is not read and the corresponding torchvision dataset is loaded instead.

Parameters:

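| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_path | str | The location of the dataset, or the name of a supported dataset type. | required |
| image_size | int | The height and width, in pixels, that every image in the DataLoader is resized to. | required |
| batch_size | int | The batch size to be set for the DataLoader. | required |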

Returns:

| Type | Description |
| --- | --- |
| tuple | A `(dataloader, classes)` pair: a torch.utils.data.DataLoader with the specified image size and batch size, and the list of classes found in the dataset (`dataset.classes`). |

Source code in pytorch_gan_trainer/datasets.py
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# _check_path and _get_torch_dataset are helpers defined elsewhere in this module.


def prepare_dataloader(dataset_path, image_size, batch_size):
    """Prepares a PyTorch DataLoader from a given path,
    with a specific image size and batch size.

    If a supported dataset type is given instead of a path,
    the path will not be read and a torchvision dataset
    will be loaded instead.

    Arguments:
        dataset_path (str): The location of the dataset.
        image_size (int): The height and width every image is resized to.
        batch_size (int): The batch size to be set for the DataLoader.

    Returns:
        tuple: A torch.utils.data.DataLoader with the specified
            image size and batch size, and the list of classes
            found in the dataset.
    """

    transform = transforms.Compose(
        [
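            # Resize every image to a square of image_size x image_size pixels.
            transforms.Resize((image_size, image_size)),
            # Convert to a float tensor with values in [0, 1].
            transforms.ToTensor(),
            # Map [0, 1] to [-1, 1]; note that (0.5) is a plain float, while
            # the one-element tuple (0.5,) broadcasts over all channels.
            transforms.Normalize(mean=(0.5,), std=(0.5,)),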
        ]
    )

    if _check_path(dataset_path):
        dataset = torchvision.datasets.ImageFolder(dataset_path, transform=transform)
    else:
        dataset = _get_torch_dataset(dataset_path, transform)

    dataloader = DataLoader(
        dataset, batch_size, shuffle=True, pin_memory=True, num_workers=4
    )

    return dataloader, dataset.classes
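
The transform pipeline in the listing above fixes the value range of every batch: `ToTensor` scales 8-bit pixel values to [0, 1], and `Normalize` with a mean and standard deviation of 0.5 then maps that range to [-1, 1]. The sketch below reproduces the arithmetic in plain Python for single pixel values. The helper names `to_unit_range` and `normalize` are illustrative only and are not part of the library.

```python
def to_unit_range(pixel: int) -> float:
    """Scale an 8-bit pixel value (0-255) to [0, 1], as ToTensor does for 8-bit images."""
    return pixel / 255.0


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Apply (value - mean) / std, the per-channel formula used by Normalize."""
    return (value - mean) / std


# Black, mid-grey, and white pixels after both steps.
for pixel in (0, 128, 255):
    print(pixel, round(normalize(to_unit_range(pixel)), 4))
```

Black pixels end up at -1 and white pixels at 1, so any code that displays or saves images from this DataLoader would need to undo the mapping with `value * 0.5 + 0.5`.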

Datasets

The supported datasets are:

  1. CIFAR

  2. MNIST

  3. Fashion-MNIST
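
How a `dataset_path` argument could be routed between a folder on disk and one of the datasets listed above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the bodies of `_check_path` and `_get_torch_dataset` are not shown on this page, so the helper `resolve_dataset_source`, the accepted name strings, and the choice of `CIFAR10` for "CIFAR" are all hypothetical.

```python
import os
from typing import Tuple

# Hypothetical mapping from a dataset name to a torchvision.datasets class name.
# The accepted strings and the CIFAR10 choice are assumptions for illustration.
NAMED_DATASETS = {
    "cifar": "CIFAR10",
    "mnist": "MNIST",
    "fashion-mnist": "FashionMNIST",
}


def resolve_dataset_source(dataset_path: str) -> Tuple[str, str]:
    """Return ("folder", path) for an existing directory,
    otherwise ("torchvision", class name) for a known dataset name."""
    if os.path.isdir(dataset_path):
        return ("folder", dataset_path)
    key = dataset_path.strip().lower()
    if key in NAMED_DATASETS:
        return ("torchvision", NAMED_DATASETS[key])
    raise ValueError(
        f"Neither a directory nor a supported dataset: {dataset_path!r}"
    )


print(resolve_dataset_source("Fashion-MNIST"))
```

Checking for a directory first mirrors the order of the `if`/`else` in `prepare_dataloader`, where a valid path takes precedence over a dataset name.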