Skip to content

ACGAN

Auxillary Classifier Generative Adveraisal Network.

Parameters:

Name Type Description Default
target_size int

Target size for the image to be generated.

required
num_channels int

Number of channels in the dataset image.

required
num_classes int

Number of classes in the dataset.

required
latent_size int

Size of the noise. Default: 100.

required
generator_feature_size int

Number of features for the Generator.

required
discriminator_feature_size int

Number of features for the Discriminator.

required
g_lr float

Learning rate for the Generator.

required
g_betas tuple

Co-efficients for the Generator.

required
d_lr float

Learning rate for the Discriminator.

required
d_betas float

Co-efficients for the Discriminator.

required

generate(self, labels, inputs=None, output_type='tensor')

Generate images for given labels and inputs.

Parameters:

Name Type Description Default
labels torch.Tensor

A tensor of labels for which the model should generate.

required
inputs None or torch.Tensor

Either give a predefined set of inputs or generate randomly.

None
output_type str

Whether to return a tensor of outputs or a reshaped grid.

'tensor'

Returns:

Type Description
torch.Tensor

Depending on output_type, either the raw output tensors or a tensor grid will be returned.

Source code in pytorch_gan_trainer/models/acgan/acgan.py
def generate(self, labels, inputs=None, output_type="tensor"):
    """Generate images for given labels and inputs.

    Arguments:
        labels (torch.Tensor): A tensor of labels for which \
            the model should generate.
        inputs (None or torch.Tensor): Either give a predefined \
            set of inputs or generate randomly.
        output_type (str): Whether to return a tensor \
            of outputs or a reshaped grid.

    Returns:
        torch.Tensor: Depending on output_type, either the raw output tensors \
            or a tensor grid will be returned.
    """
    if inputs is None:
        inputs = torch.randn(size=(labels.size(0), self.latent_size)).to(
            self.device
        )

    self.generator.eval()
    with torch.no_grad():
        outputs = self.generator(inputs, labels)
    self.generator.train()

    if output_type == "tensor":
        return outputs
    if output_type == "image":
        return torchvision.utils.make_grid(outputs.cpu(), normalize=True)

    raise Exception("Invalid return type specified")

load_checkpoint(self, models_path)

Load a previously saved checkpoint.

Parameters:

Name Type Description Default
models_path str

Path to load the previous state.

required

Returns:

Type Description
int

Last processed epoch.

Source code in pytorch_gan_trainer/models/acgan/acgan.py
def load_checkpoint(self, models_path):
    """Load a previously saved checkpoint.

    Arguments:
        models_path (str): Path to load the previous state.

    Returns:
        int: Last processed epoch.
    """
    state = torch.load(models_path, map_location=self.device)

    self.generator.load_state_dict(state["generator"])
    self.discriminator.load_state_dict(state["discriminator"])
    self.g_optim.load_state_dict(state["g_optim"])
    self.d_optim.load_state_dict(state["d_optim"])

    return state["epoch"] + 1

save_checkpoint(self, epoch, models_path)

Creates a checkpoint of the models and optimizers.

Parameters:

Name Type Description Default
epoch int

Current epoch.

required
models_path str

Path to save current state.

required
Source code in pytorch_gan_trainer/models/acgan/acgan.py
def save_checkpoint(self, epoch, models_path):
    """Creates a checkpoint of the models and optimizers.

    Arguments:
        epoch (int): Current epoch.
        models_path (str): Path to save current state.
    """
    torch.save(
        {
            "epoch": epoch,
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "g_optim": self.g_optim.state_dict(),
            "d_optim": self.d_optim.state_dict(),
        },
        models_path,
    )

set_device(self, device)

Changes the device on which the models reside.

Parameters:

Name Type Description Default
device torch.device

Device to which the models should switch.

required
Source code in pytorch_gan_trainer/models/acgan/acgan.py
def set_device(self, device):
    """Changes the device on which the models reside.

    Arguments:
        device (torch.device): Device to which the models should switch.

    """
    self.device = device
    self.generator.to(device)
    self.discriminator.to(device)

train(self, epochs, dataloader, epoch_start=0, output_batch=64, output_epochs=1, output_path='./outputs', project=None, name=None, config={}, models_path=None)

Training loop for ACGAN.

Parameters:

Name Type Description Default
epochs str

Number of epochs for training.

required
dataloader torch.utils.data.DataLoader

PyTorch DataLoader containing the dataset.

required
epoch_start int

The epoch from which training should start.

0
output_batch int

The batch size for the outputs.

64
output_epochs int

The frequency for which outputs will be generated (per epoch).

1
output_path str

The location at which the outputs will be saved. If output_path is wandb, then Weights and Biases will be configured.

'./outputs'
project str

Project name (Weights and Biases only).

None
name str

Experiment name (Weights and Biases only).

None
config dict

Dictionary containing the configuration settings.

{}
models_path str

Path at which (if provided) the checkpoints will be saved.

None
Source code in pytorch_gan_trainer/models/acgan/acgan.py
def train(
    self,
    epochs,
    dataloader,
    epoch_start=0,
    output_batch=64,
    output_epochs=1,
    output_path="./outputs",
    project=None,
    name=None,
    config={},
    models_path=None,
):
    """Training loop for ACGAN.

    Arguments:
        epochs (str): Number of epochs for training.
        dataloader (torch.utils.data.DataLoader): \
            PyTorch DataLoader containing the dataset.
        epoch_start (int): The epoch from which training should start.
        output_batch (int): The batch size for the outputs.
        output_epochs (int): The frequency for which outputs \
            will be generated (per epoch).
        output_path (str): The location at which the outputs will be saved. \
            If output_path is wandb, then Weights and Biases will be configured.
        project (str): Project name (Weights and Biases only).
        name (str): Experiment name (Weights and Biases only).
        config (dict): Dictionary containing the configuration settings.
        models_path (str): Path at which (if provided) \
            the checkpoints will be saved.
    """

    if output_path == "wandb":
        if project is None:
            raise Exception("No project name specified")
        authorize_wandb(project, name, config)

    adversarial_loss = nn.BCELoss().to(self.device)
    auxillary_loss = nn.CrossEntropyLoss().to(self.device)

    # Fixed input noise
    fixed_noise = torch.randn(size=(self.num_classes, self.latent_size)).to(
        self.device
    )
    fixed_labels = torch.tensor([i for i in range(self.num_classes)]).to(
        self.device
    )

    # Set tdqm for epoch progress
    pbar = tqdm()

    epoch_end = epochs + epoch_start
    for epoch in range(epoch_start, epoch_end):
        print(f"Epoch: {epoch + 1} / {epoch_end}")
        pbar.reset(total=len(dataloader))

        # Setting up losses
        discriminator_total_losses = []
        generator_total_losses = []
        accuracy_history = []

        for real_images, real_labels in dataloader:

            # Current batch size
            current_batch_size = real_images.size()[0]

            # Convert to cuda
            real_images = real_images.to(self.device)
            real_labels = real_labels.to(self.device)

            # For real vs fake
            real_validity = torch.ones(current_batch_size, 1).to(self.device)
            fake_validity = torch.zeros(current_batch_size, 1).to(self.device)

            # Training Generator
            self.generator.zero_grad()

            # Generate fake images
            input_noise = torch.randn(
                size=(current_batch_size, self.latent_size)
            ).to(self.device)
            fake_labels = torch.randint(
                self.num_classes, size=(current_batch_size,)
            ).to(self.device)

            fake_images = self.generator(input_noise, fake_labels)

            # Calculate Generator loss
            (
                discriminator_fake_validity,
                discriminator_fake_labels,
            ) = self.discriminator(fake_images)

            generator_total_loss = (
                adversarial_loss(discriminator_fake_validity, real_validity)
                + auxillary_loss(discriminator_fake_labels, fake_labels)
            ) / 2
            generator_total_loss.backward()
            self.g_optim.step()
            generator_total_losses.append(generator_total_loss)

            # Training Discriminator
            self.discriminator.zero_grad()

            # Loss for real images
            (
                discriminator_real_validity,
                discriminator_real_labels,
            ) = self.discriminator(real_images)
            discriminator_real_loss = (
                adversarial_loss(discriminator_real_validity, real_validity)
                + auxillary_loss(discriminator_real_labels, real_labels)
            ) / 2

            # Loss for fake images
            (
                discriminator_fake_validity,
                discriminator_fake_labels,
            ) = self.discriminator(fake_images.detach())
            discriminator_fake_loss = (
                adversarial_loss(discriminator_fake_validity, fake_validity)
                + auxillary_loss(discriminator_fake_labels, fake_labels)
            ) / 2

            # Total loss
            discriminator_total_loss = (
                discriminator_real_loss + discriminator_fake_loss
            )
            discriminator_total_loss.backward()
            self.d_optim.step()
            discriminator_total_losses.append(discriminator_total_loss)

            # Calculate Discriminator Accuracy
            predictions = np.concatenate(
                [
                    discriminator_real_labels.data.cpu().numpy(),
                    discriminator_fake_labels.data.cpu().numpy(),
                ],
                axis=0,
            )
            true_values = np.concatenate(
                [real_labels.cpu().numpy(), fake_labels.cpu().numpy()], axis=0
            )
            discriminator_accuracy = np.mean(
                np.argmax(predictions, axis=1) == true_values
            )
            accuracy_history.append(discriminator_accuracy)

            # Update tqdm
            pbar.update()

        d_total_loss = torch.mean(torch.FloatTensor(discriminator_total_losses))
        accuracy = np.mean(accuracy_history)
        g_total_loss = torch.mean(torch.FloatTensor(generator_total_losses))
        print(
            "Discriminator Total Loss: {:.3f}, Discriminator Accuracy: {:.3f}, \
                Generator Total Loss: {:.3f}".format(
                d_total_loss, accuracy, g_total_loss
            )
        )

        if output_path == "wandb":
            log_wandb(
                {
                    "Discriminator Total Loss": d_total_loss,
                    "Discriminator Accuracy": accuracy,
                    "Generator Total Loss": g_total_loss,
                },
                epoch + 1,
            )

        if (epoch + 1) % output_epochs == 0:
            save_output(
                epoch + 1, output_path, fixed_noise, self.generator, fixed_labels
            )
            if models_path:
                self.save_checkpoint(epoch, models_path)

        pbar.refresh()