Overview

A PyTorch Lightning disentanglement library implementing various VAEs that can easily be run with Hydra config.

Various other optional features exist, including data augmentations, as well as what may be the first unofficial implementation of the TensorFlow-based Ada-GVAE.

Goals

  • Easy configuration and running of custom experiments.
  • Flexible and composable components for vision-based AE research.
  • Challenge common opinions and assumptions in AE research through data.

Disent Structure

The disent framework can be decomposed into various parts:

data

Common and custom data for vision-based AE, VAE, and disentanglement research.

  • Most data is generated from ground-truth factors, which is necessary for evaluation using disentanglement metrics. Each image generated from ground-truth data has its ground-truth variables available.
Example
from disent.data.groundtruth import XYSquaresData

data = XYSquaresData(square_size=1, image_size=2, num_squares=2)

print(f'Number of observations: {len(data)} == {data.size}')
print(f'Observation shape: {data.observation_shape}')
print(f'Num Factors: {data.num_factors}')
print(f'Factor Names: {data.factor_names}')
print(f'Factor Sizes: {data.factor_sizes}')

for i, obs in enumerate(data):
    print(
        f'i={i}',
        f'pos: ({", ".join(data.factor_names)}) = {tuple(data.idx_to_pos(i))}',
        f'obs={obs.tolist()}',
        sep=' | ',
    )
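
Here idx_to_pos converts a flat observation index into its corresponding ground-truth factor coordinates; this mapping is what enables the factor-aware sampling used by the dataset wrappers below.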

dataset

Wrappers for the aforementioned data. The ground-truth variables of the data can be used to generate pairs or ordered sets for each observation in the datasets (a sketch of this sampling idea follows the examples below).

Examples

Singles:
from torch.utils.data import Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDataset


data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2)
dataset: Dataset = GroundTruthDataset(data, transform=None, augment=None)

for obs in dataset:
    # transform is applied to data to get x_targ, then augment to get x
    # if augment is None then 'x' doesn't exist in the obs
    (x0,) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)

Pairs:

from torch.utils.data import Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDatasetPairs
from disent.nn.transform import ToStandardisedTensor


data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2)
dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None)

for obs in dataset:
    # observations are returned as tuples: singles are tuples of size 1 for compatibility with pairs of size 2
    (x0, x1) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)

Pairs with augmentation:

from torch.utils.data import Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDatasetPairs
from disent.nn.transform import FftBoxBlur, ToStandardisedTensor


data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2)
dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=FftBoxBlur(radius=1, p=1.0))

for obs in dataset:
    # since augment is not None, the augmented 'x' also exists in the observation
    (x0, x1), (x0_targ, x1_targ) = obs['x'], obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)

Batches via a DataLoader:

from torch.utils.data import DataLoader, Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDatasetPairs
from disent.nn.transform import ToStandardisedTensor


data: GroundTruthData = XYSquaresData(square_size=1, image_size=2, num_squares=2)
dataset: Dataset = GroundTruthDatasetPairs(data, transform=ToStandardisedTensor(), augment=None)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

for batch in dataloader:
    (x0, x1) = batch['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
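
For intuition, pair sampling can be understood in terms of ground-truth factor coordinates: keep some of an anchor's factors, resample the rest, and map the resulting coordinates back to a dataset index. The hypothetical helper below sketches this idea; it assumes a pos_to_idx method as the inverse of the idx_to_pos method shown earlier, and is not how the built-in wrappers are actually implemented.

import numpy as np
from disent.data.groundtruth import XYSquaresData

data = XYSquaresData(square_size=1, image_size=2, num_squares=2)

def sample_pair_indices(anchor_idx: int, num_shared: int = 1):
    # coordinates of the anchor in ground-truth factor space
    anchor_pos = np.array(data.idx_to_pos(anchor_idx))
    # pick which factors the partner will share with the anchor
    shared = np.random.choice(data.num_factors, size=num_shared, replace=False)
    # resample all factors, then overwrite the shared ones
    partner_pos = np.array([np.random.randint(n) for n in data.factor_sizes])
    partner_pos[shared] = anchor_pos[shared]
    # map the factor coordinates back to a flat dataset index
    # (pos_to_idx is assumed here as the inverse of idx_to_pos)
    return anchor_idx, data.pos_to_idx(partner_pos)

print(sample_pair_indices(anchor_idx=0, num_shared=1))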

frameworks

PyTorch Lightning modules that contain the various AE and VAE implementations.

Examples

AE:
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.ae import Ae
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run


data: GroundTruthData = XYSquaresData()
dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

module: pl.LightningModule = Ae(
    make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=Ae.cfg()
)

trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

BetaVAE:

import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.vae import BetaVae
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run


data: GroundTruthData = XYSquaresData()
dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

module: pl.LightningModule = BetaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=BetaVae.cfg(beta=4)
)

trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

AdaVAE:

import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDatasetOrigWeakPairs
from disent.frameworks.vae import AdaVae
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run


data: GroundTruthData = XYSquaresData()
dataset: Dataset = GroundTruthDatasetOrigWeakPairs(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

module: pl.LightningModule = AdaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=AdaVae.cfg(beta=4, ada_average_mode='gvae', ada_thresh_mode='kl')
)

trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
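
Note the z_multiplier=2 passed to the VAE encoders above: unlike the plain AE, a VAE encoder outputs two values per latent unit (the parameters of the latent distribution, a mean and a log-variance), so the encoder's output width is doubled while the decoder still consumes z_size latents.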

metrics

Various metrics used to evaluate representations learnt by AEs and VAEs.

Example
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.data.groundtruth import XYObjectData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run


data = XYObjectData()
dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

def make_vae(beta):
    return BetaVae(
        make_optimizer_fn=lambda params: Adam(params, lr=5e-3),
        make_model_fn=lambda: AutoEncoder(
            encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
            decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
        ),
        cfg=BetaVae.cfg(beta=beta)
    )

def train(module):
    trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=is_test_run())
    trainer.fit(module, dataloader)

    # we cannot guarantee which device the representation is on
    get_repr = lambda x: module.encode(x.to(module.device))

    # evaluate
    return {
        **metric_dci(dataset, get_repr, num_train=10 if is_test_run() else 1000, num_test=5 if is_test_run() else 500, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20 if is_test_run() else 2000),
    }

a_results = train(make_vae(beta=4))
b_results = train(make_vae(beta=0.01))

print('beta=4:   ', a_results)
print('beta=0.01:', b_results)
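
Each metric returns a dictionary of named scores, which is why the results of metric_dci and metric_mig can be merged into a single results dictionary above.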

schedules

Various hyper-parameter schedules can be applied if models reference their config values directly, such as beta (cfg.beta) in all the BetaVae-derived classes.

A warning will be printed if the hyper-parameter does not exist in the config, instead of crashing.

Example
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from disent.data.groundtruth import GroundTruthData, XYSquaresData
from disent.dataset.groundtruth import GroundTruthDataset
from disent.frameworks.vae import BetaVae
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.schedule import CyclicSchedule
from disent.util import is_test_run


data: GroundTruthData = XYSquaresData()
dataset: Dataset = GroundTruthDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

module: pl.LightningModule = BetaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
    ),
    cfg=BetaVae.cfg(beta=4)
)

# https://arxiv.org/abs/1903.10145
module.register_schedule('beta', CyclicSchedule(
    period=1024,  # repeat every: trainer.global_step % period
))

trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
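
Conceptually, such a schedule rescales the target hyper-parameter using a ratio derived from the trainer's global step. A rough sketch of a linear cyclic ramp (an illustration only, not disent's actual CyclicSchedule implementation):

def cyclic_ratio(step: int, period: int = 1024) -> float:
    # ramp linearly from 0 to 1 within each period, then restart
    return (step % period) / period

# a scheduled beta=4 would then take values like these over training:
for step in (0, 256, 512, 1023, 1024):
    print(f'step={step}: beta={4 * cyclic_ratio(step, period=1024):.3f}')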

Datasets Without Ground-Truth Factors

Using datasets that do not have ground-truth factors requires custom wrappers with custom sampling procedures; however, disentanglement metrics cannot be computed over them.

We can implement an MNIST example using the built-in random sampler.

Example
import os
from collections.abc import Sequence

import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from disent.dataset.random import RandomDataset
from disent.frameworks.vae import AdaVae
from disent.model.ae import DecoderConv64Alt, EncoderConv64Alt
from disent.model import AutoEncoder
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run


# modify the mnist dataset to only return observations, not labels
class MNIST(datasets.MNIST, Sequence):
    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img


# make mnist dataset -- adjust num_samples here to match framework. TODO: add tests that can fail with a warning -- dataset downloading is not always reliable
data_folder   = os.path.abspath(os.path.join(__file__, '../data/dataset'))
dataset_train = RandomDataset(MNIST(data_folder, train=True,  download=True, transform=ToStandardisedTensor(size=64)), num_samples=2)
dataset_test  =               MNIST(data_folder, train=False, download=True, transform=ToStandardisedTensor(size=64))

# create the dataloaders
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)
dataloader_test  = DataLoader(dataset=dataset_test,  batch_size=64, shuffle=True)

# create the model
module = AdaVae(
    make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
    make_model_fn=lambda: AutoEncoder(
        encoder=EncoderConv64Alt(x_shape=(1, 64, 64), z_size=9, z_multiplier=2),
        decoder=DecoderConv64Alt(x_shape=(1, 64, 64), z_size=9),
    ),
    cfg=AdaVae.cfg(beta=4, recon_loss='mse', loss_reduction='mean_sum')  # "mean_sum" is the traditional reduction, rather than "mean"
)

# train model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=65535, fast_dev_run=is_test_run())  # callbacks=[VaeLatentCycleLoggingCallback(every_n_steps=250, plt_show=True)]
trainer.fit(module, dataloader_train)

# move observations to the model's device & manually encode them
for xs in tqdm(dataloader_test, desc='Custom Evaluation'):
    zs = module.encode(xs.to(module.device))
    if is_test_run(): break
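
Note that num_samples=2 is used in the RandomDataset above because AdaVae consumes pairs of observations; as the comment in the example suggests, this value needs to be adjusted to match the number of observations the chosen framework samples per step.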
