Unimodal autoencoder

class scconfluence.unimodal.AutoEncoder(adata: AnnData, modality: str, rep_in: str | None = None, rep_out: str | None = None, batch_key: str | None = None, n_hidden: int = 64, n_latent: int = 16, type_loss: Literal['l2', 'zinb', 'nb', 'poisson', 'binary'] = 'l2', reconstruction_weight: float = 1.0, avg_feat: bool = True, n_layers_enc: int = 3, n_layers_dec: int = 2, use_batch_norm_enc: Literal[None, 'ds', 'standard'] | None = None, use_batch_norm_dec: Literal[None, 'ds', 'standard'] | None = None, dropout_rate: float = 0.0, var_eps: float = 0.0001, deeply_inject_covariates_enc: bool = True, deeply_inject_covariates_dec: bool = True, positive_out: bool = False)

Bases: BaseModule

Autoencoder model for unimodal single-cell data. This is a generic model designed to handle different types of data (e.g. RNA, ATAC) and different types of losses (e.g. L2, ZINB). For high-dimensional modalities such as RNA and ATAC, the data fed to the encoder can differ from the data against which the reconstruction is scored. For example, the input can be the PCA projection of the normalized counts while the reconstruction target is the raw counts.

Parameters:
  • adata – AnnData object containing the data.

  • modality – string indicating the modality of the data.

  • rep_in – name of the AnnData entry holding the input data, i.e. the data fed to the encoder. If not None, the input data is read from the obsm field of the AnnData object. If None, the X field of the AnnData object is used.

  • rep_out – name of the AnnData entry holding the output data, i.e. the data compared with the output of the decoder. If not None, the output data is read from the layers field of the AnnData object. If None, the X field of the AnnData object is used.

  • batch_key – if the data does not come from multiple experimental batches, this should be set to None. Otherwise, this string names the entry in the obs field of the AnnData object that holds the batch information.

  • n_hidden – number of hidden units in the encoder and decoder.

  • n_latent – number of latent dimensions.

  • type_loss – string indicating the type of loss to use. It can be “l2”, “zinb”, “nb”, “poisson” or “binary”.

  • reconstruction_weight – weight of the reconstruction loss.

  • avg_feat – if True, the reconstruction loss is averaged over the features, otherwise it is summed.

  • n_layers_enc – number of layers in the encoder.

  • n_layers_dec – number of layers in the decoder.

  • use_batch_norm_enc – type of batch normalization to use in the encoder (‘ds’ or ‘standard’). If None, no batch normalization type is selected.

  • use_batch_norm_dec – type of batch normalization to use in the decoder (‘ds’ or ‘standard’). If None, no batch normalization type is selected.

  • dropout_rate – dropout rate to use in the encoder and decoder.

  • var_eps – small positive value to add to the variance of the posterior distribution.

  • deeply_inject_covariates_enc – if True, the batch_index is deeply injected in the encoder.

  • deeply_inject_covariates_dec – if True, the batch_index is deeply injected in the decoder.

  • positive_out – if True, the output of the decoder is forced to be positive.
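The lookup rules for rep_in and rep_out can be illustrated without the library. The snippet below is a minimal sketch, not scconfluence's implementation: ToyAnnData and select_input_output are hypothetical stand-ins that mimic only the X, obsm and layers fields of an AnnData object, and the key names "X_pca" and "counts" are assumptions chosen for the example.

```python
from dataclasses import dataclass, field


@dataclass
class ToyAnnData:
    """Hypothetical stand-in exposing only the fields the rules mention."""
    X: list
    obsm: dict = field(default_factory=dict)
    layers: dict = field(default_factory=dict)


def select_input_output(adata, rep_in=None, rep_out=None):
    """Apply the documented lookup rules (illustrative helper, not library code)."""
    # Encoder input: obsm[rep_in] when rep_in is given, otherwise X.
    enc_input = adata.obsm[rep_in] if rep_in is not None else adata.X
    # Reconstruction target: layers[rep_out] when rep_out is given, otherwise X.
    target = adata.layers[rep_out] if rep_out is not None else adata.X
    return enc_input, target


# Two cells, three genes: normalized values in X, raw counts in a layer,
# and a 2-dimensional PCA projection in obsm (all values are made up).
toy = ToyAnnData(
    X=[[0.1, 0.0, 1.2], [0.7, 0.3, 0.0]],
    obsm={"X_pca": [[0.5, -0.2], [-0.4, 0.9]]},
    layers={"counts": [[1, 0, 7], [4, 2, 0]]},
)

# PCA projection goes into the encoder, raw counts are the reconstruction target.
enc_input, target = select_input_output(toy, rep_in="X_pca", rep_out="counts")

# With both arguments left at None, X plays both roles.
default_input, default_target = select_input_output(toy)
```

With real data, the same choice would presumably be expressed by passing rep_in and rep_out to the AutoEncoder constructor, typically together with a count-based type_loss such as 'nb' or 'zinb' when the target is raw counts.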

generative(inference_output: dict[str, torch.Tensor]) → dict[str, torch.Tensor]

Generative step of the model which consists in a forward pass through the model’s decoder.

Parameters:

inference_output – output of the encoder (inference step).

Returns:

the output of the decoder.

inference(x: dict[str, torch.Tensor]) → dict[str, torch.Tensor]

Encoding step of the model which consists in a forward pass through the model’s encoder.

Parameters:

x – input mini-batch of data.

Returns:

the latent embeddings (parameters of the posterior distribution), the library size and the batch indices.

latent_batch(x: dict[str, torch.Tensor]) → tuple[numpy.ndarray, numpy.ndarray]

Get latent embeddings for a mini-batch of data. Used for prediction after the training of the model.

Parameters:

x – mini-batch of input data.

Returns:

latent embeddings and their corresponding observation names.

log_latent_norms(inference_dic: dict[str, torch.Tensor], return_log: bool = False)

Log the norm of the latent embeddings.

Parameters:
  • inference_dic – output of the encoder (inference step).

  • return_log – whether to return the log or not.

Returns:

if return_log is True, the log message to be printed.

loss(x: dict[str, torch.Tensor], inference_output: dict[str, torch.Tensor], generative_output: dict[str, torch.Tensor], reduce: bool = True) → dict[str, torch.Tensor]

Compute all loss terms of the model. For an AutoEncoder model this is currently only the reconstruction loss; other loss terms may be added in future developments.

Parameters:
  • x – mini-batch input data.

  • inference_output – output of the encoder (inference step).

  • generative_output – output of the decoder (generative step).

  • reduce – whether to reduce the cell reconstruction losses to a single scalar or not.

Returns:

a dictionary containing all loss terms used to train the model.
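The signatures of inference, generative and loss suggest that the three steps are chained, each one consuming the dictionary produced by the previous one. The sketch below only illustrates that data flow. ToyAutoEncoder is a hypothetical stand-in working on plain Python lists, and the dictionary keys ("data", "z", "x_hat", "reconstruction") are invented for the example; the keys used by scconfluence are not documented here.

```python
class ToyAutoEncoder:
    """Hypothetical stand-in that mirrors only the three method names."""

    def inference(self, x):
        # "Encoder": summarize each cell by the mean of its features.
        return {"z": [sum(row) / len(row) for row in x["data"]]}

    def generative(self, inference_output):
        # "Decoder": expand each latent value back to two features.
        return {"x_hat": [[z, z] for z in inference_output["z"]]}

    def loss(self, x, inference_output, generative_output, reduce=True):
        # Squared error per cell, summed over features.
        per_cell = [
            sum((a - b) ** 2 for a, b in zip(row, rec))
            for row, rec in zip(x["data"], generative_output["x_hat"])
        ]
        # Assumption: reduce=True averages the per-cell losses.
        reconstruction = sum(per_cell) / len(per_cell) if reduce else per_cell
        return {"reconstruction": reconstruction}


model = ToyAutoEncoder()
batch = {"data": [[1.0, 3.0], [2.0, 2.0]]}

# Each step receives the output dictionary of the previous one.
inference_output = model.inference(batch)
generative_output = model.generative(inference_output)
losses = model.loss(batch, inference_output, generative_output)
per_cell_losses = model.loss(batch, inference_output, generative_output, reduce=False)
```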

predict_step(batch, batch_idx)

Predict the latent embeddings for a mini-batch of data.

Returns:

latent embeddings and their corresponding observation names.
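Since each call yields one (embeddings, observation names) pair per mini-batch, the pairs have to be stitched together to obtain embeddings for a whole dataset. The following sketch shows one way to do that and to realign the rows with a desired cell order. It uses plain lists in place of the NumPy arrays the method returns, and merge_batches and reorder are hypothetical helpers, not part of scconfluence.

```python
def merge_batches(batch_outputs):
    """Concatenate per-batch (embeddings, obs_names) pairs in order."""
    embeddings, names = [], []
    for batch_emb, batch_names in batch_outputs:
        if len(batch_emb) != len(batch_names):
            raise ValueError("each batch needs one name per embedding row")
        embeddings.extend(batch_emb)
        names.extend(batch_names)
    return embeddings, names


def reorder(embeddings, names, wanted_order):
    """Reorder embedding rows to follow a given list of observation names."""
    row_of = dict(zip(names, embeddings))
    return [row_of[name] for name in wanted_order]


# Made-up outputs for two mini-batches with a 2-dimensional latent space.
outputs = [
    ([[0.1, 0.2], [0.3, 0.4]], ["cell_b", "cell_a"]),
    ([[0.5, 0.6]], ["cell_c"]),
]

all_emb, all_names = merge_batches(outputs)
aligned = reorder(all_emb, all_names, ["cell_a", "cell_b", "cell_c"])
```

Returning the observation names alongside the embeddings is what makes this realignment possible when the mini-batch order does not match the order of cells in the AnnData object.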

reconstr_loss(x: dict[str, torch.Tensor], inference_output: dict[str, torch.Tensor], generative_output: dict[str, torch.Tensor], reduce: bool = True) → Tensor

Compute the reconstruction loss.

Parameters:
  • x – input mini-batch of data.

  • inference_output – output of the encoder (inference step).

  • generative_output – output of the decoder (generative step).

  • reduce – whether to reduce the cell reconstruction losses to a single scalar or not.

Returns:

the reconstruction loss.
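To make the effect of the avg_feat and reduce flags concrete, here is a minimal sketch for the 'l2' case on plain Python lists. It is not the library's code: l2_reconstruction_loss is a hypothetical helper, and two details are assumptions, namely that reduce=True averages over cells and that reconstruction_weight multiplies the per-cell loss.

```python
def l2_reconstruction_loss(target, reconstruction, avg_feat=True,
                           reduce=True, reconstruction_weight=1.0):
    """Squared-error reconstruction loss (illustrative, not library code)."""
    per_cell = []
    for t_row, r_row in zip(target, reconstruction):
        squared = [(t - r) ** 2 for t, r in zip(t_row, r_row)]
        total = sum(squared)
        # avg_feat=True averages over features, otherwise the sum is kept.
        per_cell.append(total / len(squared) if avg_feat else total)
    # Assumption: the weight scales each cell's loss.
    per_cell = [reconstruction_weight * value for value in per_cell]
    if reduce:
        # Assumption: the reduction is a mean over the cells of the mini-batch.
        return sum(per_cell) / len(per_cell)
    return per_cell


# Two cells, two features (made-up values).
target = [[1.0, 2.0], [0.0, 4.0]]
recon = [[1.0, 0.0], [2.0, 4.0]]

averaged = l2_reconstruction_loss(target, recon)                   # mean over features
summed = l2_reconstruction_loss(target, recon, avg_feat=False)     # sum over features
per_cell = l2_reconstruction_loss(target, recon, reduce=False)     # one value per cell
weighted = l2_reconstruction_loss(target, recon, reconstruction_weight=0.5)
```

Summing instead of averaging makes the loss scale with the number of features, which matters when this term is balanced against other terms or compared across modalities of different dimensionality.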

reset_log_metrics()

Reset the list of metrics to log.