Skip to content

cca_zoo.deep

Deep CCA variants. Requires pip install cca-zoo[deep].


Base class

BaseDeep

BaseDeep(latent_dimensions: int, encoders: list[Module], lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: LightningModule

Base class for deep multiview CCA models using PyTorch Lightning.

Subclasses override :meth:loss to implement the specific objective function. Training is handled by a :class:lightning.Trainer.

The sklearn-compatible interface (fit, transform, score) is provided for convenience, wrapping the Lightning training loop.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
lr float

Learning rate for the Adam optimiser. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Small constant for numerical stability. Default is 1e-6.

1e-06
Source code in cca_zoo/deep/_base.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Store hyperparameters and register the per-view encoders.

    Args:
        latent_dimensions: Dimensionality of the latent space.
        encoders: One :class:`torch.nn.Module` per view.
        lr: Learning rate for the Adam optimiser. Default 1e-3.
        max_epochs: Maximum number of training epochs. Default 100.
        eps: Small constant for numerical stability. Default 1e-6.
    """
    super().__init__()
    # Wrap in a ModuleList so encoder parameters are registered with
    # the LightningModule and picked up by the optimiser.
    self.encoders = nn.ModuleList(encoders)
    self.eps = eps
    self.max_epochs = max_epochs
    self.lr = lr
    self.latent_dimensions = latent_dimensions

forward

forward(views: list[Tensor]) -> list[torch.Tensor]

Encode all views into latent representations.

Parameters:

Name Type Description Default
views list[Tensor]

List of tensors, each (batch_size, n_features_i).

required

Returns:

Type Description
list[Tensor]

List of tensors, each (batch_size, latent_dimensions).

Source code in cca_zoo/deep/_base.py
def forward(self, views: list[torch.Tensor]) -> list[torch.Tensor]:
    """Encode all views into latent representations.

    Args:
        views: List of tensors, each (batch_size, n_features_i).

    Returns:
        List of tensors, each (batch_size, latent_dimensions).
    """
    latents = []
    # Pair each view with its dedicated encoder, in order.
    for encoder, view in zip(self.encoders, views):
        latents.append(encoder(view))
    return latents

transform

transform(loader: DataLoader) -> list[np.ndarray]

Project all samples in a DataLoader into the latent space.

Parameters:

Name Type Description Default
loader DataLoader

DataLoader yielding batches with a "views" key.

required

Returns:

Type Description
list[ndarray]

List of numpy arrays, each (n_samples, latent_dimensions).

Source code in cca_zoo/deep/_base.py
@torch.no_grad()
def transform(self, loader: torch.utils.data.DataLoader) -> list[np.ndarray]:
    """Project all samples in a DataLoader into the latent space.

    Args:
        loader: DataLoader yielding batches with a ``"views"`` key.

    Returns:
        List of numpy arrays, each (n_samples, latent_dimensions).
    """
    self.eval()
    per_batch: list[list[torch.Tensor]] = []
    for batch in loader:
        on_device = [view.to(self.device) for view in batch["views"]]
        encoded = self(on_device)
        # Move results back to CPU so they can be concatenated / numpied.
        per_batch.append([rep.cpu() for rep in encoded])
    # Stitch the per-batch chunks back together, one array per view.
    n_views = len(per_batch[0])
    outputs: list[np.ndarray] = []
    for view_idx in range(n_views):
        chunks = [reps[view_idx] for reps in per_batch]
        outputs.append(torch.cat(chunks, dim=0).numpy())
    return outputs

score

score(loader: DataLoader) -> np.ndarray

Return average pairwise canonical correlations after linear CCA.

Parameters:

Name Type Description Default
loader DataLoader

DataLoader with a "views" key.

required

Returns:

Type Description
ndarray

Array of shape (latent_dimensions,).

Source code in cca_zoo/deep/_base.py
def score(self, loader: torch.utils.data.DataLoader) -> np.ndarray:
    """Return average pairwise canonical correlations after linear CCA.

    Args:
        loader: DataLoader with a ``"views"`` key.

    Returns:
        Array of shape ``(latent_dimensions,)``.
    """
    # Fit a linear MCCA on top of the learned representations and
    # report its correlations on the same data.
    reps = self.transform(loader)
    linear_cca = MCCA(latent_dimensions=self.latent_dimensions)
    linear_cca.fit(reps)
    return linear_cca.score(reps)

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Compute the training loss for one mini-batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dictionary with key "views" (list of tensors) and optionally "independent_views".

required
batch_idx int

Batch index (unused).

required

Returns:

Type Description
Tensor

Scalar loss tensor.

Source code in cca_zoo/deep/_base.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute and log the training loss for one mini-batch.

    Args:
        batch: Dictionary with key ``"views"`` (list of tensors) and
            optionally ``"independent_views"``.
        batch_idx: Batch index (unused).

    Returns:
        Scalar loss tensor.
    """
    views = batch["views"]
    representations = self(views)
    # A second, independent batch of encodings is only computed when the
    # dataloader supplies one (used e.g. by the EY objective).
    independent = batch.get("independent_views")
    independent_representations = self(independent) if independent is not None else None
    losses = self.loss(representations, independent_representations)
    n_samples = views[0].shape[0]
    for name, value in losses.items():
        # Aggregate per-epoch rather than per-step for smoother curves.
        self.log(
            f"train/{name}",
            value,
            on_step=False,
            on_epoch=True,
            batch_size=n_samples,
        )
    return losses["objective"]

validation_step

validation_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Compute the validation loss for one mini-batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dictionary with "views" key.

required
batch_idx int

Batch index (unused).

required

Returns:

Type Description
Tensor

Scalar loss tensor.

Source code in cca_zoo/deep/_base.py
def validation_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute and log the validation loss for one mini-batch.

    Args:
        batch: Dictionary with a ``"views"`` key.
        batch_idx: Batch index (unused).

    Returns:
        Scalar loss tensor.
    """
    views = batch["views"]
    losses = self.loss(self(views))
    n_samples = views[0].shape[0]
    for name, value in losses.items():
        self.log(
            f"val/{name}",
            value,
            on_step=False,
            on_epoch=True,
            batch_size=n_samples,
        )
    return losses["objective"]

configure_optimizers

configure_optimizers() -> torch.optim.Optimizer

Create the Adam optimiser.

Returns:

Type Description
Optimizer

Adam optimiser with the configured learning rate.

Source code in cca_zoo/deep/_base.py
def configure_optimizers(self) -> torch.optim.Optimizer:
    """Create the Adam optimiser for all model parameters.

    Returns:
        Adam optimiser with the configured learning rate.
    """
    optimiser = torch.optim.Adam(self.parameters(), lr=self.lr)
    return optimiser

Models

DCCA

DCCA(latent_dimensions: int, encoders: list[Module], objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: BaseDeep

Deep Canonical Correlation Analysis with a pluggable objective.

Trains two (or more) neural network encoders to maximise canonical correlation between their outputs. The objective function is controlled by the objective parameter, which defaults to the Andrew 2013 CCALoss.

The model is a :class:lightning.pytorch.LightningModule and is trained via a :class:lightning.Trainer.

Reference

Andrew, G., et al. "Deep canonical correlation analysis." ICML 2013.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects mapping each view to the latent space.

required
objective Module | None

Differentiable loss module operating on a list of latent tensors. If None, defaults to :class:~cca_zoo.deep.objectives.CCALoss.

None
lr float

Learning rate for the Adam optimiser. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Regularisation parameter passed to the default CCALoss when objective is None. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = DCCA(latent_dimensions=4, encoders=[enc1, enc2])

Source code in cca_zoo/deep/_dcca.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DCCA model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        objective: Differentiable loss over a list of latent tensors;
            defaults to :class:`~cca_zoo.deep.objectives.CCALoss`.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Regularisation passed to the default CCALoss. Default 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    # Default to the Andrew et al. (2013) CCA loss when none is supplied.
    if objective is None:
        objective = CCALoss(eps=eps)
    self.objective: nn.Module = objective

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the DCCA training objective.

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded views from the current batch, each of shape (batch_size, latent_dimensions).

required
independent_representations list[Tensor] | None

Optional second set of encodings (unused in the base DCCA formulation).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with key "objective" containing the scalar

dict[str, Tensor]

loss tensor to minimise.

Source code in cca_zoo/deep/_dcca.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the DCCA training objective.

    Args:
        representations: Encoded views from the current batch, each
            of shape (batch_size, latent_dimensions).
        independent_representations: Optional second set of encodings
            (unused in the base DCCA formulation).

    Returns:
        Dictionary with key ``"objective"`` containing the scalar
        loss tensor to minimise.
    """
    return {"objective": self.objective(representations)}

DCCA_EY

DCCA_EY(latent_dimensions: int, encoders: list[Module], objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

DCCA using the EigenGame / Eckart-Young (EY) objective.

The EY objective for CCA can be written as::

L = -tr(2 * C) + tr(V @ V)

where C is the averaged cross-covariance and V is the averaged auto-covariance of the latent representations. When independent_representations are provided, the penalty term becomes tr(V @ V_ind) to decouple estimation of the two quantities (as in the EigenGame formulation).

Reference

Chapman, J., Aguila, A. L., & Wells, L. "A Generalised EigenGame with Extensions to Multiview Representation Learning." arXiv:2211.11323 (2022).

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
objective Module | None

Ignored; the EY objective is fixed for this class. Accepted for API compatibility but overridden internally.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Regularisation for numerical stability. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = DCCA_EY(latent_dimensions=4, encoders=[enc1, enc2])

Source code in cca_zoo/deep/_dcca.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DCCA_EY model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        objective: Accepted for API compatibility with :class:`DCCA`;
            the EY loss implemented in :meth:`loss` does not read it.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Regularisation for numerical stability. Default 1e-6.
    """
    # Forward ``objective`` to DCCA.__init__, which already applies the
    # ``CCALoss(eps)`` default.  The previous implementation repeated the
    # same fallback here, constructing a second, redundant CCALoss.
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=objective,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the EigenGame / Eckart-Young CCA loss.

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded views from the current batch.

required
independent_representations list[Tensor] | None

Optional second set of encodings for the EigenGame penalty term. When provided the penalty is tr(V @ V_ind) instead of tr(V @ V).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with keys "objective", "rewards", and

dict[str, Tensor]

"penalties".

Source code in cca_zoo/deep/_dcca_ey.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the EigenGame / Eckart-Young CCA loss.

    Args:
        representations: Encoded views from the current batch.
        independent_representations: Optional second set of encodings;
            when given, the penalty becomes ``tr(V @ V_ind)`` instead of
            ``tr(V @ V)`` (the EigenGame decoupling trick).

    Returns:
        Dictionary with keys ``"objective"``, ``"rewards"``, and
        ``"penalties"``.
    """
    cross_cov, auto_cov = _cca_cv(representations)
    reward = torch.trace(2.0 * cross_cov)
    # Use an independent batch for the second covariance factor when one
    # is available, decoupling the two estimates of V.
    if independent_representations is not None:
        _, independent_auto_cov = _cca_cv(independent_representations)
        penalty = torch.trace(auto_cov @ independent_auto_cov)
    else:
        penalty = torch.trace(auto_cov @ auto_cov)
    return {
        "objective": -reward + penalty,
        "rewards": reward,
        "penalties": penalty,
    }

DCCA_NOI

DCCA_NOI(latent_dimensions: int, encoders: list[Module], rho: float = 0.1, objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

Deep CCA via Non-linear Orthogonal Iterations.

Uses batch whitening to approximate the CCA whitening step stochastically. The loss enforces each view's representations to match the whitened version of the other view's representations (with a stop-gradient on the whitened targets).

Reference

Wang, W., et al. "Stochastic optimization for deep CCA via nonlinear orthogonal iterations." Allerton 2015. IEEE.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
rho float

Exponential moving average momentum for the batch whitening layers. Must be in [0, 1]. Default is 0.1.

0.1
objective Module | None

Ignored; the NOI loss is fixed. Accepted for API compatibility.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Regularisation for the whitening layers. Default is 1e-6.

1e-06

Raises:

Type Description
ValueError

If rho is not in [0, 1].

Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = DCCA_NOI(latent_dimensions=4, encoders=[enc1, enc2], rho=0.1)

Source code in cca_zoo/deep/_dcca_noi.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    rho: float = 0.1,
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DCCA_NOI model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        rho: EMA momentum for the batch-whitening layers; must lie in
            ``[0, 1]``. Default 0.1.
        objective: Accepted for API compatibility; the NOI loss is fixed.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Regularisation for the whitening layers. Default 1e-6.

    Raises:
        ValueError: If ``rho`` is outside ``[0, 1]``.
    """
    if rho < 0.0 or rho > 1.0:
        raise ValueError(f"rho must be in [0, 1], got {rho}.")
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=objective,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    self.rho = rho
    # Summed (not averaged) MSE between each view and the whitened
    # target of the other view.
    self.mse = nn.MSELoss(reduction="sum")
    # One batch-whitening layer per view; rho is the EMA momentum.
    whiteners = [
        _BatchWhiten(latent_dimensions, momentum=rho, eps=eps) for _ in encoders
    ]
    self.bws = nn.ModuleList(whiteners)

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the NOI loss.

Each view's representations are pushed towards the whitened representation of the other view (stop-gradient on the target).

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded views from the current batch, each of shape (batch_size, latent_dimensions).

required
independent_representations list[Tensor] | None

Unused; present for API compatibility.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with key "objective".

Source code in cca_zoo/deep/_dcca_noi.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the NOI loss.

    Each view's representations are pushed towards the whitened
    representation of the other view (stop-gradient on the target).

    Args:
        representations: Encoded views from the current batch, each
            of shape (batch_size, latent_dimensions).
        independent_representations: Unused; present for API
            compatibility.

    Returns:
        Dictionary with key ``"objective"``.
    """
    whitened = [bw(r) for r, bw in zip(representations, self.bws)]
    total = torch.tensor(0.0, device=representations[0].device)
    n_views = len(representations)
    for i in range(n_views):
        for j in range(n_views):
            if i != j:
                total = total + self.mse(representations[i], whitened[j].detach())
    return {"objective": total}

DCCA_SDL

DCCA_SDL(latent_dimensions: int, encoders: list[Module], lam: float = 0.5, objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

Deep CCA via Stochastic Decorrelation Loss.

Combines an MSE alignment loss between views with a within-view decorrelation penalty. Batch normalisation is applied to each encoder output before the loss is computed.

The total loss is::

L = MSE(z1, z2) + lam * (SDL(z1) + SDL(z2))

where SDL(z) = mean|off-diag(Cov(z))|.

Reference

Chang, X., Xiang, T., & Hospedales, T. M. "Scalable and effective deep CCA via soft decorrelation." CVPR 2018.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
lam float

Weight of the SDL decorrelation penalty. Default is 0.5.

0.5
objective Module | None

Ignored; the SDL loss is fixed. Accepted for API compatibility.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Regularisation for numerical stability. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = DCCA_SDL(latent_dimensions=4, encoders=[enc1, enc2], lam=0.5)

Source code in cca_zoo/deep/_dcca_sdl.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    lam: float = 0.5,
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DCCA_SDL model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        lam: Weight of the SDL decorrelation penalty. Default 0.5.
        objective: Accepted for API compatibility; the SDL loss is fixed.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Regularisation for numerical stability. Default 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=objective,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    self.lam = lam
    # Non-affine batch norm: normalises each encoder output without
    # learnable scale/shift before the loss is computed.
    norms = [nn.BatchNorm1d(latent_dimensions, affine=False) for _ in encoders]
    self.bns = nn.ModuleList(norms)

forward

forward(views: list[Tensor]) -> list[torch.Tensor]

Encode views and apply batch normalisation.

Parameters:

Name Type Description Default
views list[Tensor]

List of input tensors, one per view.

required

Returns:

Type Description
list[Tensor]

List of batch-normalised latent tensors.

Source code in cca_zoo/deep/_dcca_sdl.py
def forward(self, views: list[torch.Tensor]) -> list[torch.Tensor]:
    """Encode views and apply batch normalisation.

    Args:
        views: List of input tensors, one per view.

    Returns:
        List of batch-normalised latent tensors.
    """
    outputs = []
    for encoder, batch_norm, view in zip(self.encoders, self.bns, views):
        outputs.append(batch_norm(encoder(view)))
    return outputs

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the SDL loss.

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded and batch-normalised views from the current batch, each of shape (batch_size, latent_dimensions).

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with keys "objective", "l2", and "sdl".

Source code in cca_zoo/deep/_dcca_sdl.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the SDL loss: MSE alignment plus decorrelation penalty.

    NOTE(review): the alignment term uses only the first two views,
    while the SDL penalty sums over all views — presumably a two-view
    model is assumed; confirm before using with more views.

    Args:
        representations: Encoded and batch-normalised views from the
            current batch, each of shape (batch_size, latent_dimensions).
        independent_representations: Unused.

    Returns:
        Dictionary with keys ``"objective"``, ``"l2"``, and ``"sdl"``.
    """
    alignment = F.mse_loss(representations[0], representations[1])
    decorrelation = torch.stack([_sdl_loss(rep) for rep in representations]).sum()
    objective = alignment + self.lam * decorrelation
    return {
        "objective": objective,
        "l2": alignment,
        "sdl": decorrelation,
    }

DCCAE

DCCAE(latent_dimensions: int, encoders: list[Module], decoders: list[Module], lam: float = 0.5, objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: BaseDeep

Deep CCA with Autoencoders.

Extends DCCA by adding per-view reconstruction losses. The total objective is a convex combination of the CCA loss and the summed MSE reconstruction losses::

L = lam * sum_i MSE(x_i, decoder_i(encoder_i(x_i)))
    + (1 - lam) * CCALoss(z_1, ..., z_V)
Reference

Wang, W., et al. "On deep multi-view representation learning." ICML 2015.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects mapping each view to the latent space.

required
decoders list[Module]

List of :class:torch.nn.Module objects mapping the latent space back to each view's input space.

required
lam float

Weight for the reconstruction term. Must be in [0, 1]. When 0 the model reduces to DCCA; when 1 it is a pure autoencoder. Default is 0.5.

0.5
objective Module | None

Differentiable CCA loss operating on a list of latent tensors. Defaults to :class:~cca_zoo.deep.objectives.CCALoss.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Ridge regularisation for the CCA loss. Default is 1e-6.

1e-06

Raises:

Type Description
ValueError

If lam is not in [0, 1].

Example

>>> import torch
>>> import torch.nn as nn
>>> enc1, enc2 = nn.Linear(10, 4), nn.Linear(8, 4)
>>> dec1, dec2 = nn.Linear(4, 10), nn.Linear(4, 8)
>>> model = DCCAE(
...     latent_dimensions=4,
...     encoders=[enc1, enc2],
...     decoders=[dec1, dec2],
... )

Source code in cca_zoo/deep/_dccae.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    decoders: list[nn.Module],
    lam: float = 0.5,
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DCCAE model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: Modules mapping each view to the latent space.
        decoders: Modules mapping the latent space back to each view.
        lam: Weight of the reconstruction term, in ``[0, 1]``; 0 reduces
            to DCCA, 1 is a pure autoencoder. Default 0.5.
        objective: CCA loss over a list of latent tensors; defaults to
            :class:`~cca_zoo.deep.objectives.CCALoss`.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Ridge regularisation for the CCA loss. Default 1e-6.

    Raises:
        ValueError: If ``lam`` is outside ``[0, 1]``.
    """
    if lam < 0.0 or lam > 1.0:
        raise ValueError(f"lam must be in [0, 1], got {lam}.")
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    self.lam = lam
    self.decoders = nn.ModuleList(decoders)
    # Fall back to the standard CCA loss when no objective is supplied.
    if objective is None:
        objective = CCALoss(eps=eps)
    self.objective: nn.Module = objective

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the DCCAE objective (CCA + reconstruction).

This method does not have access to the original views for reconstruction. Override training_step or call :meth:_full_loss if reconstruction targets are needed.

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded views from the current batch.

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with key "objective" containing the CCA loss

dict[str, Tensor]

(reconstruction is not computed here without raw views).

Source code in cca_zoo/deep/_dccae.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the DCCAE objective (CCA + reconstruction).

    This method does not have access to the original views for
    reconstruction.  Override ``training_step`` or call
    :meth:`_full_loss` if reconstruction targets are needed.

    Args:
        representations: Encoded views from the current batch.
        independent_representations: Unused.

    Returns:
        Dictionary with key ``"objective"`` containing the CCA loss
        (reconstruction is not computed here without raw views).
    """
    return {"objective": self.objective(representations)}

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Training step that includes reconstruction loss.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dictionary with key "views" (list of tensors).

required
batch_idx int

Batch index (unused).

required

Returns:

Type Description
Tensor

Scalar loss tensor.

Source code in cca_zoo/deep/_dccae.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Training step combining CCA and reconstruction losses.

    Args:
        batch: Dictionary with key ``"views"`` (list of tensors).
        batch_idx: Batch index (unused).

    Returns:
        Scalar loss tensor.
    """
    views = batch["views"]
    representations = self(views)
    reconstructions = self._decode(representations)

    cca_loss = self.objective(representations)
    per_view_mse = [
        F.mse_loss(original, recon)
        for original, recon in zip(views, reconstructions)
    ]
    recon_loss = torch.stack(per_view_mse).sum()
    # Convex combination: lam weights reconstruction, (1 - lam) the CCA term.
    objective = (1.0 - self.lam) * cca_loss + self.lam * recon_loss

    n_samples = views[0].shape[0]
    for name, value in {
        "objective": objective,
        "cca": cca_loss,
        "reconstruction": recon_loss,
    }.items():
        self.log(
            f"train/{name}",
            value,
            on_step=False,
            on_epoch=True,
            batch_size=n_samples,
        )
    return objective

DVCCA

DVCCA(latent_dimensions: int, encoders: list[Module], decoders: list[Module], lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: BaseDeep

Deep Variational Canonical Correlation Analysis.

A variational autoencoder framework for multiview data. Each encoder maps a view to a 2 * latent_dimensions output, which is split into a posterior mean mu and log-variance log_var. A shared latent code z is sampled via the reparameterisation trick and then decoded to reconstruct all views.

The training objective is the negative ELBO::

L = sum_i MSE(x_i, decoder_i(z)) + KL(q(z|X) || p(z))

where q(z|X) = N(sum_i mu_i, diag(exp(sum_i log_var_i))) and p(z) = N(0, I).

Reference

Wang, W., et al. "Deep variational canonical correlation analysis." arXiv:1610.03454 (2016).

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects each mapping a view to a vector of size 2 * latent_dimensions (first half is mu, second half is log_var).

required
decoders list[Module]

List of :class:torch.nn.Module objects mapping the latent vector back to each view's input space.

required
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Regularisation for numerical stability. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> # Encoders output 2 * latent_dimensions
>>> enc1 = nn.Linear(10, 8)
>>> enc2 = nn.Linear(10, 8)
>>> dec1 = nn.Linear(4, 10)
>>> dec2 = nn.Linear(4, 10)
>>> model = DVCCA(
...     latent_dimensions=4,
...     encoders=[enc1, enc2],
...     decoders=[dec1, dec2],
... )

Source code in cca_zoo/deep/_dvcca.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    decoders: list[nn.Module],
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DVCCA model.

    Args:
        latent_dimensions: Dimensionality of the latent space.
        encoders: Modules mapping each view to ``2 * latent_dimensions``
            outputs (posterior mean concatenated with log-variance).
        decoders: Modules mapping the latent vector back to each view.
        lr: Learning rate. Default 1e-3.
        max_epochs: Maximum training epochs. Default 100.
        eps: Regularisation for numerical stability. Default 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    # Decoders reconstruct every view from the shared latent code.
    self.decoders = nn.ModuleList(decoders)

forward

forward(views: list[Tensor]) -> list[torch.Tensor]

Encode views and return a list of latent representations (mu).

At inference time this returns the posterior mean as the point estimate of the latent code for each view independently.

Parameters:

Name Type Description Default
views list[Tensor]

List of input tensors, one per view.

required

Returns:

Type Description
list[Tensor]

List with one tensor of shape (batch_size, latent_dimensions)

list[Tensor]

representing the shared posterior mean.

Source code in cca_zoo/deep/_dvcca.py
def forward(self, views: list[torch.Tensor]) -> list[torch.Tensor]:
    """Return the shared posterior mean as the latent point estimate.

    Args:
        views: List of input tensors, one per view.

    Returns:
        List with one tensor of shape (batch_size, latent_dimensions):
        the shared posterior mean.
    """
    posterior_mean, _log_var = self._encode(views)
    # Single shared latent code, hence a one-element list.
    return [posterior_mean]

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the ELBO loss (used during validation via BaseDeep).

Note: Reconstruction requires access to original views; for training the full ELBO is computed in :meth:training_step.

Parameters:

Name Type Description Default
representations list[Tensor]

Unused here; the method returns zero so that the validation step in BaseDeep does not crash.

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with "objective" set to 0.

Source code in cca_zoo/deep/_dvcca.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the ELBO loss (used during validation via BaseDeep).

    Note: Reconstruction requires access to original views; for
    training the full ELBO is computed in :meth:`training_step`.

    Args:
        representations: Unused here; the method returns zero so
            that the validation step in BaseDeep does not crash.
        independent_representations: Unused.

    Returns:
        Dictionary with ``"objective"`` set to 0.
    """
    return {"objective": torch.tensor(0.0)}

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Training step computing the full negative ELBO.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dictionary with key "views" (list of tensors).

required
batch_idx int

Batch index (unused).

required

Returns:

Type Description
Tensor

Scalar loss tensor (negative ELBO).

Source code in cca_zoo/deep/_dvcca.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Run one optimisation step on the full negative ELBO.

    Args:
        batch: Dictionary with key ``"views"`` (list of tensors).
        batch_idx: Batch index (unused).

    Returns:
        Scalar loss tensor (negative ELBO).
    """
    views = batch["views"]
    batch_size = views[0].shape[0]

    posterior_mean, posterior_log_var = self._encode(views)
    latent = self._reparameterise(posterior_mean, posterior_log_var)
    decoded = self._decode(latent)

    per_view_mse = [F.mse_loss(view, recon) for view, recon in zip(views, decoded)]
    reconstruction = torch.stack(per_view_mse).sum()

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)),
    # averaged over the batch.
    kl_divergence = (
        -0.5
        * torch.sum(
            1.0
            + posterior_log_var
            - posterior_mean.pow(2)
            - posterior_log_var.exp()
        )
        / batch_size
    )

    objective = reconstruction + kl_divergence
    for name, value in {
        "objective": objective,
        "reconstruction": reconstruction,
        "kl": kl_divergence,
    }.items():
        self.log(
            f"train/{name}",
            value,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
    return objective

transform

transform(loader: DataLoader) -> list[np.ndarray]

Project all samples to the shared latent space via posterior mean.

Parameters:

Name Type Description Default
loader DataLoader

DataLoader yielding batches with a "views" key.

required

Returns:

Type Description
list[ndarray]

List with one numpy array of shape (n_samples, latent_dimensions).

Source code in cca_zoo/deep/_dvcca.py
@torch.no_grad()
def transform(self, loader: torch.utils.data.DataLoader) -> list[np.ndarray]:
    """Project every sample onto the shared latent space (posterior mean).

    Args:
        loader: DataLoader yielding batches with a ``"views"`` key.

    Returns:
        List with one numpy array of shape (n_samples, latent_dimensions).
    """
    self.eval()
    batch_means: list[torch.Tensor] = []
    for batch in loader:
        on_device = [view.to(self.device) for view in batch["views"]]
        batch_mu, _ = self._encode(on_device)
        batch_means.append(batch_mu.cpu())
    return [torch.cat(batch_means, dim=0).numpy()]

DTCCA

DTCCA(latent_dimensions: int, encoders: list[Module], objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

Deep Tensor CCA.

Applies the tensor CCA loss to neural representations. The cross-moment tensor is formed from whitened latent codes, and the objective is the negative Frobenius norm of that tensor (serving as a differentiable proxy for the TCCA criterion).

Reference

Wong, H. S., et al. "Deep Tensor CCA for Multi-view Learning." IEEE Transactions on Big Data (2021).

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
objective Module | None

Ignored; the TCCA loss is always used. Accepted for API compatibility.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Ridge regularisation for whitening. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> enc3 = nn.Linear(6, 4)
>>> model = DTCCA(latent_dimensions=4, encoders=[enc1, enc2, enc3])

Source code in cca_zoo/deep/_dtcca.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the DTCCA model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        objective: Ignored; the TCCA loss is always used. Accepted
            only for API compatibility with :class:`DCCA`.
        lr: Learning rate. Default is 1e-3.
        max_epochs: Maximum training epochs. Default is 100.
        eps: Ridge regularisation for whitening. Default is 1e-6.
    """
    # DCCA installs a CCALoss when objective=None; that placeholder is
    # immediately replaced by the tensor CCA loss below.
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=None,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    self.objective = TCCALoss(eps=eps)

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the DTCCA loss via the tensor cross-moment.

Parameters:

Name Type Description Default
representations list[Tensor]

Encoded views from the current batch, each of shape (batch_size, latent_dimensions).

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with key "objective".

Source code in cca_zoo/deep/_dtcca.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the DTCCA loss via the tensor cross-moment.

    Args:
        representations: Encoded views from the current batch, each
            of shape (batch_size, latent_dimensions).
        independent_representations: Unused.

    Returns:
        Dictionary with key ``"objective"``.
    """
    return {"objective": self.objective(representations)}

SplitAE

SplitAE(latent_dimensions: int, encoders: list[Module], decoders: list[Module], lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: BaseDeep

Split Autoencoder baseline for multiview learning.

All views are encoded individually into a shared latent space. The concatenated representations are used to reconstruct each view via dedicated decoders. The loss is the sum of MSE reconstruction losses across all views::

L = sum_i MSE(x_i, decoder_i(cat(z_1, ..., z_V)))

This model serves as a reconstruction-based baseline that does not explicitly maximise correlation.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of each encoder's output. Decoders receive the concatenation of all encoder outputs, so their input size is n_views * latent_dimensions.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
decoders list[Module]

List of :class:torch.nn.Module objects. Each decoder's input dimension should be n_views * latent_dimensions.

required
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Unused; present for API consistency. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> # Decoders receive 2 * 4 = 8 dimensional input
>>> dec1 = nn.Linear(8, 10)
>>> dec2 = nn.Linear(8, 8)
>>> model = SplitAE(
...     latent_dimensions=4,
...     encoders=[enc1, enc2],
...     decoders=[dec1, dec2],
... )

Source code in cca_zoo/deep/_splitae.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    decoders: list[nn.Module],
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the split autoencoder.

    Args:
        latent_dimensions: Output dimensionality of each encoder.
        encoders: One :class:`torch.nn.Module` per view.
        decoders: One :class:`torch.nn.Module` per view; each should
            accept ``n_views * latent_dimensions`` input features.
        lr: Learning rate. Default is 1e-3.
        max_epochs: Maximum training epochs. Default is 100.
        eps: Unused; kept for API consistency. Default is 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    # Register decoders as a ModuleList so their parameters reach the
    # optimiser.
    self.decoders = nn.ModuleList(decoders)

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Return a zero loss placeholder.

Reconstruction requires the original views; use the full training step for proper loss computation.

Parameters:

Name Type Description Default
representations list[Tensor]

Unused.

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with "objective" set to 0.

Source code in cca_zoo/deep/_splitae.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Return a zero loss placeholder.

    Reconstruction requires the original views; use the full
    training step for proper loss computation.

    Args:
        representations: Unused.
        independent_representations: Unused.

    Returns:
        Dictionary with ``"objective"`` set to 0.
    """
    return {"objective": torch.tensor(0.0)}

training_step

training_step(batch: dict[str, Any], batch_idx: int) -> torch.Tensor

Training step computing the reconstruction loss.

Parameters:

Name Type Description Default
batch dict[str, Any]

Dictionary with key "views" (list of tensors).

required
batch_idx int

Batch index (unused).

required

Returns:

Type Description
Tensor

Scalar loss tensor.

Source code in cca_zoo/deep/_splitae.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Run one optimisation step on the summed reconstruction MSE.

    Args:
        batch: Dictionary with key ``"views"`` (list of tensors).
        batch_idx: Batch index (unused).

    Returns:
        Scalar reconstruction loss tensor.
    """
    views = batch["views"]
    latents = self(views)
    decoded = self._decode(latents)

    # Sum of per-view mean-squared reconstruction errors.
    per_view = [F.mse_loss(view, recon) for view, recon in zip(views, decoded)]
    objective = torch.stack(per_view).sum()

    self.log(
        "train/objective",
        objective,
        on_step=False,
        on_epoch=True,
        batch_size=views[0].shape[0],
    )
    return objective

BarlowTwins

BarlowTwins(latent_dimensions: int, encoders: list[Module], lam: float = 0.005, objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

Barlow Twins self-supervised learning model.

Learns representations by encouraging the cross-correlation matrix between two views to be close to the identity. The loss has two components::

L = sum_i (1 - C_ii)^2 + lam * sum_{i != j} C_ij^2

where C is the cross-correlation matrix between batch-normalised representations of the two views. Batch normalisation is applied per-view before computing C.

Reference

Zbontar, J., et al. "Barlow twins: Self-supervised learning via redundancy reduction." ICML 2021.

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
lam float

Weight for the off-diagonal redundancy penalty. Default is 5e-3.

0.005
objective Module | None

Ignored; the Barlow Twins loss is fixed. Accepted for API compatibility.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Unused. Present for API compatibility. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = BarlowTwins(latent_dimensions=4, encoders=[enc1, enc2], lam=5e-3)

Source code in cca_zoo/deep/_barlowtwins.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    lam: float = 5e-3,
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the Barlow Twins model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        lam: Weight of the off-diagonal redundancy penalty.
        objective: Ignored; the Barlow Twins loss is fixed.
        lr: Learning rate. Default is 1e-3.
        max_epochs: Maximum training epochs. Default is 100.
        eps: Unused; kept for API compatibility. Default is 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=objective,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    self.lam = lam
    # One non-affine BatchNorm1d per view, applied before the
    # cross-correlation matrix is formed.
    normalisers = [
        nn.BatchNorm1d(latent_dimensions, affine=False) for _ in encoders
    ]
    self.bns = nn.ModuleList(normalisers)

forward

forward(views: list[Tensor]) -> list[torch.Tensor]

Encode views and apply batch normalisation.

Parameters:

Name Type Description Default
views list[Tensor]

List of input tensors, one per view.

required

Returns:

Type Description
list[Tensor]

List of batch-normalised latent tensors.

Source code in cca_zoo/deep/_barlowtwins.py
def forward(self, views: list[torch.Tensor]) -> list[torch.Tensor]:
    """Encode each view and batch-normalise the result.

    Args:
        views: List of input tensors, one per view.

    Returns:
        List of batch-normalised latent tensors, one per view.
    """
    outputs = []
    for encoder, bn, view in zip(self.encoders, self.bns, views):
        outputs.append(bn(encoder(view)))
    return outputs

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the Barlow Twins loss for two batch-normalised views.

Parameters:

Name Type Description Default
representations list[Tensor]

List containing exactly two batch-normalised tensors, each of shape (batch_size, latent_dimensions).

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with keys "objective", "invariance", and

dict[str, Tensor]

"redundancy".

Source code in cca_zoo/deep/_barlowtwins.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the Barlow Twins loss for two batch-normalised views.

    Args:
        representations: List containing exactly two batch-normalised
            tensors, each of shape (batch_size, latent_dimensions).
        independent_representations: Unused.

    Returns:
        Dictionary with keys ``"objective"``, ``"invariance"``, and
        ``"redundancy"``.
    """
    z1, z2 = representations[0], representations[1]
    n = z1.shape[0]
    cross_cov = z1.T @ z2 / n

    invariance = torch.sum(torch.pow(1.0 - torch.diag(cross_cov), 2))
    # Off-diagonal entries
    mask = ~torch.eye(cross_cov.shape[0], dtype=torch.bool, device=cross_cov.device)
    redundancy = torch.sum(torch.pow(cross_cov[mask], 2))
    objective = invariance + self.lam * redundancy
    return {
        "objective": objective,
        "invariance": invariance,
        "redundancy": redundancy,
    }

VICReg

VICReg(latent_dimensions: int, encoders: list[Module], sim_coeff: float = 25.0, std_coeff: float = 25.0, cov_coeff: float = 1.0, objective: Module | None = None, lr: float = 0.001, max_epochs: int = 100, eps: float = 1e-06)

Bases: DCCA

Variance-Invariance-Covariance Regularization.

Three-term self-supervised objective that jointly encourages::

- Invariance: MSE similarity between the two views' representations.
- Variance: Standard deviation of each feature dimension >= 1.
- Covariance: Off-diagonal covariance close to zero.

The total loss is::

L = sim_coeff * MSE(z1, z2)
    + std_coeff * (hinge_var(z1) + hinge_var(z2))
    + cov_coeff * (off_diag_cov(z1) + off_diag_cov(z2))
Reference

Bardes, A., Ponce, J., & LeCun, Y. "VICReg: Variance-Invariance- Covariance Regularization for Self-Supervised Learning." arXiv:2105.04906 (2022).

Parameters:

Name Type Description Default
latent_dimensions int

Dimensionality of the shared latent space.

required
encoders list[Module]

List of :class:torch.nn.Module objects, one per view.

required
sim_coeff float

Weight for the invariance (MSE) term. Default is 25.0.

25.0
std_coeff float

Weight for the variance hinge term. Default is 25.0.

25.0
cov_coeff float

Weight for the covariance penalty term. Default is 1.0.

1.0
objective Module | None

Ignored; the VICReg loss is fixed. Accepted for API compatibility.

None
lr float

Learning rate. Default is 1e-3.

0.001
max_epochs int

Maximum training epochs. Default is 100.

100
eps float

Unused. Present for API compatibility. Default is 1e-6.

1e-06
Example

>>> import torch
>>> import torch.nn as nn
>>> enc1 = nn.Linear(10, 4)
>>> enc2 = nn.Linear(8, 4)
>>> model = VICReg(latent_dimensions=4, encoders=[enc1, enc2])

Source code in cca_zoo/deep/_vicreg.py
def __init__(
    self,
    latent_dimensions: int,
    encoders: list[nn.Module],
    sim_coeff: float = 25.0,
    std_coeff: float = 25.0,
    cov_coeff: float = 1.0,
    objective: nn.Module | None = None,
    lr: float = 1e-3,
    max_epochs: int = 100,
    eps: float = 1e-6,
) -> None:
    """Initialise the VICReg model.

    Args:
        latent_dimensions: Dimensionality of the shared latent space.
        encoders: One :class:`torch.nn.Module` per view.
        sim_coeff: Weight of the invariance (MSE) term.
        std_coeff: Weight of the variance hinge term.
        cov_coeff: Weight of the covariance penalty term.
        objective: Ignored; the VICReg loss is fixed.
        lr: Learning rate. Default is 1e-3.
        max_epochs: Maximum training epochs. Default is 100.
        eps: Unused; kept for API compatibility. Default is 1e-6.
    """
    super().__init__(
        latent_dimensions=latent_dimensions,
        encoders=encoders,
        objective=objective,
        lr=lr,
        max_epochs=max_epochs,
        eps=eps,
    )
    # Weights for the three terms of the VICReg objective.
    self.sim_coeff = sim_coeff
    self.std_coeff = std_coeff
    self.cov_coeff = cov_coeff

loss

loss(representations: list[Tensor], independent_representations: list[Tensor] | None = None) -> dict[str, torch.Tensor]

Compute the VICReg three-term loss.

Parameters:

Name Type Description Default
representations list[Tensor]

List of tensors, each of shape (batch_size, latent_dimensions). Currently only the first two views are used.

required
independent_representations list[Tensor] | None

Unused.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary with keys "objective", "sim_loss",

dict[str, Tensor]

"var_loss", and "cov_loss".

Source code in cca_zoo/deep/_vicreg.py
def loss(
    self,
    representations: list[torch.Tensor],
    independent_representations: list[torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the three-term VICReg objective.

    Args:
        representations: List of tensors, each of shape
            (batch_size, latent_dimensions). Only the first two views
            are used.
        independent_representations: Unused.

    Returns:
        Dictionary with keys ``"objective"``, ``"sim_loss"``,
        ``"var_loss"`` and ``"cov_loss"``.
    """
    first, second = representations[0], representations[1]
    sim_loss = _invariance_loss(first, second)
    var_loss = _variance_loss(first, second)
    cov_loss = _covariance_loss(first, second)
    total = (
        self.sim_coeff * sim_loss
        + self.std_coeff * var_loss
        + self.cov_coeff * cov_loss
    )
    return {
        "objective": total,
        "sim_loss": sim_loss,
        "var_loss": var_loss,
        "cov_loss": cov_loss,
    }

Objectives

objectives

Differentiable CCA loss functions for use with deep models.

CCALoss

CCALoss(eps: float = 1e-05)

Bases: Module

Andrew 2013 deep CCA correlation loss for two views.

Computes the negative sum of squared singular values of S11^{-1/2} S12 S22^{-1/2}, where S11, S12, S22 are empirical (co)variances with ridge regularisation. Minimising this loss maximises the total canonical correlation.

Reference

Andrew, G., et al. "Deep canonical correlation analysis." ICML 2013.

Parameters:

Name Type Description Default
eps float

Ridge regularisation added to within-view covariance matrices for numerical stability. Default is 1e-5.

1e-05
Example

>>> import torch
>>> loss_fn = CCALoss(eps=1e-4)
>>> z1 = torch.randn(32, 4)
>>> z2 = torch.randn(32, 4)
>>> loss = loss_fn([z1, z2])

Source code in cca_zoo/deep/objectives.py
def __init__(self, eps: float = 1e-5) -> None:
    """Store the ridge regularisation constant.

    Args:
        eps: Ridge added to within-view covariance matrices for
            numerical stability. Default is 1e-5.
    """
    super().__init__()
    self.eps = eps

forward

forward(representations: list[Tensor]) -> torch.Tensor

Compute the CCA loss for a list containing exactly two views.

Parameters:

Name Type Description Default
representations list[Tensor]

List of two tensors, each of shape (batch_size, latent_dimensions).

required

Returns:

Type Description
Tensor

Scalar tensor: negative sum of squared canonical correlations.

Raises:

Type Description
ValueError

If the number of representations is not exactly 2.

Source code in cca_zoo/deep/objectives.py
def forward(self, representations: list[torch.Tensor]) -> torch.Tensor:
    """Compute the CCA loss for a list containing exactly two views.

    Args:
        representations: List of two tensors, each of shape
            (batch_size, latent_dimensions).

    Returns:
        Scalar tensor: negative sum of squared canonical correlations.

    Raises:
        ValueError: If the number of representations is not exactly 2.
    """
    if len(representations) != 2:
        raise ValueError(
            "CCALoss expects exactly 2 representations, "
            f"got {len(representations)}."
        )
    z1, z2 = representations
    n = z1.shape[0]
    denom = n - 1

    # Centre both views.
    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)

    # Ridge-regularised within-view covariances and the cross-covariance.
    eye1 = torch.eye(z1.shape[1], device=z1.device, dtype=z1.dtype)
    eye2 = torch.eye(z2.shape[1], device=z2.device, dtype=z2.dtype)
    s11 = (z1.T @ z1) / denom + self.eps * eye1
    s22 = (z2.T @ z2) / denom + self.eps * eye2
    s12 = (z1.T @ z2) / denom

    # Whitened cross-covariance: T = S11^{-1/2} S12 S22^{-1/2}.
    t = _inv_sqrtm(s11, self.eps) @ s12 @ _inv_sqrtm(s22, self.eps)

    # Squared singular values of T are the eigenvalues of T^T T;
    # clamp guards against tiny negative values from round-off.
    squared_corrs = torch.clamp(torch.linalg.eigvalsh(t.T @ t), min=0.0)
    return -squared_corrs.sum()

MCCALoss

MCCALoss(eps: float = 1e-05)

Bases: Module

Multiview extension of CCALoss that sums pairwise CCA losses.

For each unordered pair of views (i, j) with i < j, the pairwise CCALoss is computed and the results are summed. This encourages all views to be mutually correlated in the latent space.

Parameters:

Name Type Description Default
eps float

Ridge regularisation passed to each pairwise CCALoss. Default is 1e-5.

1e-05
Example

>>> import torch
>>> loss_fn = MCCALoss(eps=1e-4)
>>> views = [torch.randn(32, 4) for _ in range(3)]
>>> loss = loss_fn(views)

Source code in cca_zoo/deep/objectives.py
def __init__(self, eps: float = 1e-5) -> None:
    """Store regularisation and build the shared pairwise CCA loss.

    Args:
        eps: Ridge regularisation passed to each pairwise
            :class:`CCALoss`. Default is 1e-5.
    """
    super().__init__()
    self.eps = eps
    self._cca_loss = CCALoss(eps=eps)

forward

forward(representations: list[Tensor]) -> torch.Tensor

Compute the sum of pairwise CCA losses across all view pairs.

Parameters:

Name Type Description Default
representations list[Tensor]

List of tensors, each of shape (batch_size, latent_dimensions).

required

Returns:

Type Description
Tensor

Scalar tensor: sum of pairwise negative canonical correlations.

Source code in cca_zoo/deep/objectives.py
def forward(self, representations: list[torch.Tensor]) -> torch.Tensor:
    """Compute the sum of pairwise CCA losses across all view pairs.

    Args:
        representations: List of tensors, each of shape
            (batch_size, latent_dimensions).

    Returns:
        Scalar tensor: sum of pairwise negative canonical correlations.
    """
    # Accumulate on the same device *and* dtype as the inputs: a plain
    # torch.tensor(0.0, device=...) is always float32, which silently
    # promotes the sum when representations are float64 (or upcasts
    # half-precision inputs). new_zeros preserves both attributes.
    total = representations[0].new_zeros(())
    n_views = len(representations)
    for i in range(n_views):
        for j in range(i + 1, n_views):
            total = total + self._cca_loss([representations[i], representations[j]])
    return total

GCCALoss

GCCALoss(eps: float = 1e-05)

Bases: Module

Generalised CCA loss for multiple views (GCCA objective).

Maximises the sum of squared correlations between each whitened view and a shared latent target T. In practice we minimise::

-tr( sum_i H_i^T T T^T H_i )

where H_i = X_i (X_i^T X_i + eps*I)^{-1/2} is the whitened representation of view i. T is obtained as the top-k eigenvectors of sum_i H_i H_i^T.

Parameters:

Name Type Description Default
eps float

Ridge regularisation for within-view covariance inversion. Default is 1e-5.

1e-05
Example

>>> import torch
>>> loss_fn = GCCALoss(eps=1e-4)
>>> views = [torch.randn(32, 4) for _ in range(3)]
>>> loss = loss_fn(views)

Source code in cca_zoo/deep/objectives.py
def __init__(self, eps: float = 1e-5) -> None:
    """Store the ridge regularisation constant.

    Args:
        eps: Ridge regularisation for within-view covariance
            inversion. Default is 1e-5.
    """
    super().__init__()
    self.eps = eps

forward

forward(representations: list[Tensor]) -> torch.Tensor

Compute the GCCA loss.

Parameters:

Name Type Description Default
representations list[Tensor]

List of tensors, each of shape (batch_size, latent_dimensions).

required

Returns:

Type Description
Tensor

Scalar tensor: negative total GCCA objective.

Source code in cca_zoo/deep/objectives.py
def forward(self, representations: list[torch.Tensor]) -> torch.Tensor:
    """Compute the GCCA loss.

    Args:
        representations: List of tensors, each of shape
            (batch_size, latent_dimensions).

    Returns:
        Scalar tensor: negative total GCCA objective.
    """
    n = representations[0].shape[0]

    def _whiten(z: torch.Tensor) -> torch.Tensor:
        # Centre, then multiply by the inverse square root of the
        # ridge-regularised covariance.
        centred = z - z.mean(dim=0)
        cov = (centred.T @ centred) / (n - 1) + self.eps * torch.eye(
            centred.shape[1], device=centred.device, dtype=centred.dtype
        )
        return centred @ _inv_sqrtm(cov, self.eps)

    whitened = [_whiten(z) for z in representations]

    # M = sum_i H_i H_i^T, shape (n, n).
    m = torch.zeros(n, n, device=representations[0].device)
    for h in whitened:
        m = m + h @ h.T

    # The objective is the sum of the top-k eigenvalues of M.
    k = representations[0].shape[1]
    spectrum = torch.linalg.eigvalsh(m)
    return -spectrum[-k:].sum()

TCCALoss

TCCALoss(eps: float = 1e-05)

Bases: Module

Tensor CCA loss (proxy via Frobenius norm of cross-moment tensor).

Forms the cross-moment tensor M where M[d1, d2, ..., dV] = (1/n) sum_s prod_i H_i[s, d_i] for whitened representations H_i, then returns -||M||_F as a differentiable proxy for the tensor CCA objective.

Parameters:

Name Type Description Default
eps float

Ridge regularisation for whitening. Default is 1e-5.

1e-05
Example

>>> import torch
>>> loss_fn = TCCALoss(eps=1e-4)
>>> views = [torch.randn(32, 4) for _ in range(3)]
>>> loss = loss_fn(views)

Source code in cca_zoo/deep/objectives.py
def __init__(self, eps: float = 1e-5) -> None:
    """Store the ridge regularisation constant.

    Args:
        eps: Ridge regularisation for whitening. Default is 1e-5.
    """
    super().__init__()
    self.eps = eps

forward

forward(representations: list[Tensor]) -> torch.Tensor

Compute the tensor CCA loss.

Parameters:

Name Type Description Default
representations list[Tensor]

List of tensors, each of shape (batch_size, latent_dimensions).

required

Returns:

Type Description
Tensor

Scalar tensor: negative Frobenius norm of the cross-moment tensor.

Source code in cca_zoo/deep/objectives.py
def forward(self, representations: list[torch.Tensor]) -> torch.Tensor:
    """Compute the tensor CCA loss.

    Args:
        representations: List of tensors, each of shape
            (batch_size, latent_dimensions).

    Returns:
        Scalar tensor: negative Frobenius norm of the cross-moment tensor.
    """
    n = representations[0].shape[0]

    def _whiten(z: torch.Tensor) -> torch.Tensor:
        # Centre, then multiply by the inverse square root of the
        # ridge-regularised covariance.
        centred = z - z.mean(dim=0)
        cov = (centred.T @ centred) / (n - 1) + self.eps * torch.eye(
            centred.shape[1], device=centred.device, dtype=centred.dtype
        )
        return centred @ _inv_sqrtm(cov, self.eps)

    whitened = [_whiten(z) for z in representations]

    # Accumulate the per-sample outer product across views; after the
    # loop the shape is (n, d, d, ..., d) with one d-axis per view.
    outer = whitened[0]
    for view_codes in whitened[1:]:
        expanded = view_codes
        for _ in range(outer.dim() - 1):
            expanded = expanded.unsqueeze(1)
        outer = outer.unsqueeze(-1) * expanded

    # Averaging over samples yields the cross-moment tensor M.
    return -torch.linalg.norm(outer.mean(dim=0))