Skip to content

Quick Start

Architecture

The disent directory structure:

  • disent/dataset: dataset wrappers, datasets & sampling strategies
    • disent/dataset/data: raw datasets
    • disent/dataset/sampling: sampling strategies for DisentDataset
  • disent/frameworks: frameworks, including Auto-Encoders and VAEs
  • disent/metrics: metrics for evaluating disentanglement using ground truth datasets
  • disent/model: common encoder and decoder models used for VAE research
  • disent/nn: torch components for building models including layers, transforms, losses and general maths
  • disent/schedule: annealing schedules that can be registered to a framework
  • disent/util: helper functions for the rest of the framework

Please Note The API Is Still Unstable ⚠️

Disent is still under active development. Features and APIs are not considered stable and should be expected to change! A limited set of tests currently exists, and it will be expanded over time.

Examples

dataset/data

Common and custom data for vision-based AE, VAE and disentanglement research.

  • Most data is generated from ground-truth factors, which is necessary for evaluation using disentanglement metrics. Each image generated from ground-truth data has its ground-truth factor values available.
Example
from disent.dataset.data import XYObjectData

# a small 4x4 grid of coloured squares, generated from known ground-truth factors
data = XYObjectData(
    grid_size=4,
    min_square_size=1,
    max_square_size=2,
    square_size_spacing=1,
    palette='rgb',
)

# summarise the dataset and its ground-truth factors
print(f'Number of observations: {len(data)} == {data.size}')
print(f'Observation shape: {data.observation_shape}')
print(f'Num Factors: {data.num_factors}')
print(f'Factor Names: {data.factor_names}')
print(f'Factor Sizes: {data.factor_sizes}')

# show every observation next to the factor values it was generated from
for idx, observation in enumerate(data):
    index_col = f'i={idx}'
    factor_col = f'pos: ({", ".join(data.factor_names)}) = {tuple(data.idx_to_pos(idx))}'
    obs_col = f'obs={observation.tolist()}'
    print(index_col, factor_col, obs_col, sep=' | ')

dataset

Ground-truth variables of the data can be used to generate pairs or ordered sets for each observation in the dataset, using sampling strategies.

Examples
from disent.dataset.data import XYObjectData
from disent.dataset import DisentDataset

# prepare the data
# - DisentDataset is a generic wrapper around torch Datasets that prepares
#   the data for the various frameworks according to some sampling strategy.
#   By default this sampling strategy just returns the data at the given idx.
# - transform=None and augment=None here, so observations are presumably
#   returned as-is, and no augmented 'x' key is produced.
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb')
dataset = DisentDataset(data, transform=None, augment=None)

# iterate over single epoch
for obs in dataset:
    # singles are contained in tuples of size 1 for compatibility with pairs with size 2
    (x0,) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.nn.transform import ToStandardisedTensor


# prepare the data
# - GroundTruthPairOrigSampler turns each observation into a pair, so
#   'x_targ' below holds two items instead of one
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb')
dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor())

# iterate over single epoch
for obs in dataset:
    # transform(data[i]) gives 'x_targ', then augment(x_targ) gives 'x'
    # - no augment is passed here, so only 'x_targ' exists, holding the pair
    (x0, x1) = obs['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairSampler
from disent.nn.transform import ToStandardisedTensor, FftBoxBlur


# prepare pairs of observations, standardised and then blurred as an augmentation
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb')
pair_sampler = GroundTruthPairSampler()
to_tensor = ToStandardisedTensor()
blur = FftBoxBlur(radius=1, p=1.0)
dataset = DisentDataset(data, sampler=pair_sampler, transform=to_tensor, augment=blur)

# one pass over the dataset
for obs in dataset:
    # 'x' holds the augmented pair and 'x_targ' the un-augmented pair
    # - if augment is not specified, then the augmented 'x' key does not exist!
    augmented, targets = obs['x'], obs['x_targ']
    x0, x1 = augmented
    x0_targ, x1_targ = targets
    print(x0.dtype, x0.min(), x0.max(), x0.shape)
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.nn.transform import ToStandardisedTensor

# wrap the pair dataset in a regular torch DataLoader to get shuffled batches
data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb')
dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor())
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
)

# one pass over the batches
for batch in dataloader:
    # each half of the pair arrives as its own batch of observations
    x0, x1 = batch['x_targ']
    print(x0.dtype, x0.min(), x0.max(), x0.shape)

framework

PyTorch Lightning modules that contain various AE or VAE implementations.

Examples
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.frameworks.ae import Ae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run  # you can ignore and remove this


# standardised XYObject observations, shuffled into batches of 4
data = XYObjectData()
dataset = DisentDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# a plain (non-variational) auto-encoder with a 6-dimensional latent space
z_size = 6
auto_encoder = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=z_size),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=z_size),
)
ae_cfg = Ae.cfg(
    optimizer='adam',
    optimizer_kwargs={'lr': 1e-3},
    loss_reduction='mean_sum',
)
module: pl.LightningModule = Ae(model=auto_encoder, cfg=ae_cfg)

# train the model
# - is_test_run() only shortens the run when the project's own tests execute this example
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run  # you can ignore and remove this


# standardised XYObject observations, shuffled into batches of 4
data = XYObjectData()
dataset = DisentDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# beta-VAE with 6 latent units
# - z_multiplier=2: presumably the encoder outputs two values per latent unit
#   (the posterior parameters), while the decoder consumes z_size latents
encoder = EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2)
decoder = DecoderConv64(x_shape=data.x_shape, z_size=6)
vae_model = AutoEncoder(encoder=encoder, decoder=decoder)
beta_vae_cfg = BetaVae.cfg(
    optimizer='adam',
    optimizer_kwargs={'lr': 1e-3},
    loss_reduction='mean_sum',
    beta=4,
)
module: pl.LightningModule = BetaVae(model=vae_model, cfg=beta_vae_cfg)

# train the model
# - is_test_run() only shortens the run when the project's own tests execute this example
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.frameworks.vae import AdaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.util import is_test_run  # you can ignore and remove this


# AdaVae is trained on pairs of observations, so a pair sampler is used here
data = XYObjectData()
pair_sampler = GroundTruthPairOrigSampler()
dataset = DisentDataset(data, pair_sampler, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# adaptive VAE with 6 latent units
z_size = 6
ada_model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=z_size, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=z_size),
)
ada_cfg = AdaVae.cfg(
    optimizer='adam',
    optimizer_kwargs={'lr': 1e-3},
    loss_reduction='mean_sum',
    beta=4,
    ada_average_mode='gvae',
    ada_thresh_mode='kl',
)
module: pl.LightningModule = AdaVae(model=ada_model, cfg=ada_cfg)

# train the model
# - is_test_run() only shortens the run when the project's own tests execute this example
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

metrics

Various metrics used to evaluate representations learnt by AEs and VAEs.

Example
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.metrics import metric_dci, metric_mig
from disent.util import is_test_run

data = XYObjectData()
dataset = DisentDataset(data, transform=ToStandardisedTensor(), augment=None)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)


def make_vae(beta):
    """Create a BetaVae with 6 latent units and the given KL weight ``beta``."""
    return BetaVae(
        model=AutoEncoder(
            encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
            decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
        ),
        cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), beta=beta)
    )


def train(module):
    """Train ``module`` on ``dataloader`` and return its merged DCI and MIG scores.

    Returns:
        dict: the union of the dictionaries returned by ``metric_dci`` and ``metric_mig``.
    """
    # evaluate once: shrinks training and metric sample sizes under the project's tests
    test_run = is_test_run()

    trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=test_run)
    trainer.fit(module, dataloader)

    # the trainer may have moved the module to another device (e.g. a GPU),
    # so move each input batch onto the module's device before encoding
    def get_repr(x):
        return module.encode(x.to(module.device))

    # evaluate
    return {
        **metric_dci(dataset, get_repr, num_train=10 if test_run else 1000, num_test=5 if test_run else 500, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20 if test_run else 2000),
    }


a_results = train(make_vae(beta=4))
b_results = train(make_vae(beta=0.01))

print('beta=4:   ', a_results)
print('beta=0.01:', b_results)

schedules

Hyper-parameter schedules can be applied if models reference their config values, such as beta (cfg.beta) in all BetaVAE-derived classes.

If the hyper-parameter does not exist in the config, a warning is printed instead of crashing.

Example
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.nn.transform import ToStandardisedTensor
from disent.schedule import CyclicSchedule
from disent.util import is_test_run  # you can ignore and remove this

# standardised XYObject observations, shuffled into batches of 4
data = XYObjectData()
dataset = DisentDataset(data, transform=ToStandardisedTensor())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# beta-VAE whose cfg.beta will be driven by a schedule below
vae_model = AutoEncoder(
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=6),
)
vae_cfg = BetaVae.cfg(
    optimizer='adam',
    optimizer_kwargs={'lr': 1e-3},
    loss_reduction='mean_sum',
    beta=4,
)
module: pl.LightningModule = BetaVae(model=vae_model, cfg=vae_cfg)

# attach a cyclic annealing schedule to the 'beta' config value
# - cyclic scheduler from: https://arxiv.org/abs/1903.10145
beta_schedule = CyclicSchedule(period=1024)  # repeat every: trainer.global_step % period
module.register_schedule('beta', beta_schedule)

# train the model
# - is_test_run() only shortens the run when the project's own tests execute this example
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader)

Datasets Without Ground-Truth Factors

You can use datasets that do not have ground-truth factors by changing the sampling strategy of DisentDataset; however, metrics cannot be computed.

The following MNIST example uses the builtin RandomSampler.

Example
import os
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
from disent.dataset import DisentDataset
from disent.dataset.sampling import RandomSampler
from disent.frameworks.vae import AdaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderFC, EncoderFC
from disent.nn.transform import ToStandardisedTensor


# torchvision's MNIST yields (image, label) tuples; this variant keeps only the image
class MNIST(datasets.MNIST):
    """MNIST dataset whose items are images only, with the class label discarded."""

    def __getitem__(self, index):
        image, _label = super().__getitem__(index)
        return image


# make mnist dataset -- adjust num_samples here to match framework.
# TODO: add tests that can fail with a warning -- dataset downloading is not always reliable
data_folder   = os.path.abspath(os.path.join(__file__, '../data/dataset'))
dataset_train = DisentDataset(MNIST(data_folder, train=True,  download=True, transform=ToStandardisedTensor()), sampler=RandomSampler(num_samples=2))
dataset_test  =               MNIST(data_folder, train=False, download=True, transform=ToStandardisedTensor())

# create the dataloaders
# - os.cpu_count() returns None when the CPU count cannot be determined, which
#   DataLoader rejects; fall back to 0 (load data in the main process) in that case
num_workers = os.cpu_count() or 0
dataloader_train = DataLoader(dataset=dataset_train, batch_size=128, shuffle=True, num_workers=num_workers)
dataloader_test  = DataLoader(dataset=dataset_test,  batch_size=128, shuffle=True, num_workers=num_workers)

# create the model
module = AdaVae(
    model=AutoEncoder(
        encoder=EncoderFC(x_shape=(1, 28, 28), z_size=9, z_multiplier=2),
        decoder=DecoderFC(x_shape=(1, 28, 28), z_size=9),
    ),
    cfg=AdaVae.cfg(
        optimizer='adam', optimizer_kwargs=dict(lr=1e-3),
        beta=4, recon_loss='mse', loss_reduction='mean_sum',  # "mean_sum" is the traditional loss reduction mode, rather than "mean"
    )
)

# train the model
# - to visualise latent traversals during training, a callback can be passed, e.g.
#   callbacks=[VaeLatentCycleLoggingCallback(every_n_steps=250, plt_show=True)]
#   (that callback is not imported in this example)
trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=2048)
trainer.fit(module, dataloader_train)

# manually encode some observations with the trained model
# - this is pure inference: switch to eval mode and disable autograd so no
#   computation graphs are built for the encoded batches
import torch

module.eval()
with torch.no_grad():
    for xs in tqdm(dataloader_test, desc='Custom Evaluation'):
        # move each batch to whichever device the module's parameters are on
        zs = module.encode(xs.to(module.device))

Last update: July 1, 2021