ScConfluence model
- class scconfluence.model.ScConfluence(mdata: MuData, unimodal_aes: dict[str, scconfluence.unimodal.AutoEncoder], mass: float = 0.5, reach: float = 0.3, blur: float = 0.01, iot_loss_weight: float = 0.01, sinkhorn_loss_weight: float = 0.1)
Bases: BaseModule
The ScConfluence model aims to learn a common latent space for multiple data modalities from unpaired measurements.
- Parameters:
mdata – the input data
unimodal_aes – dictionary of AutoEncoder objects for each modality
mass – mass parameter for the IOT loss, corresponding to the proportion of cells (between 0 and 1) that are matched across modalities in each mini-batch during training. Decreasing this parameter makes the model more robust to modality-specific cell populations, which should not be aligned with cells from other modalities. Even for datasets where roughly the same cell populations are expected in all modalities, 0.5 is a good default, since transporting too little mass is less problematic than transporting too much.
reach – reach parameter controlling the unbalancedness of the Sinkhorn regularization. Lower values lead to a more unbalanced Sinkhorn divergence. When trying to enforce a more complete mixing of cells from different modalities in the latent space, a higher reach value can be used. Only values between 0.1 and 5 are recommended.
blur – blur parameter for the Sinkhorn regularization, which controls the strength of its entropic term. Increasing this parameter makes the Sinkhorn term faster to compute but less accurate.
iot_loss_weight – weight of the IOT loss in the final loss. It can be set higher than its default value of 0.01 when the cost matrix between modalities is assumed to be of very high quality (e.g. NOT when comparing scRNA-seq expression with scATAC-derived gene activities). Only values between 0.005 and 0.1 are recommended.
sinkhorn_loss_weight – weight of the Sinkhorn regularization term in the final loss. Setting it higher than its default value of 0.1 can be useful when trying to enforce a more complete mixing of cells from different modalities. Conversely, when integrating more than two modalities, it can be useful to set it lower than the default, since less regularization is required to align the modalities when several pairwise terms are enforced (three pairwise terms with three modalities). Only values between 0.01 and 0.5 are recommended.
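The recommended ranges above can be checked before building the model. The following is a minimal sketch; `check_hyperparameters` and `RECOMMENDED_RANGES` are hypothetical names, not part of scconfluence, and the bounds are copied from the parameter descriptions (blur has no documented range and is left out).

```python
# Hypothetical helper: bounds copied from the ScConfluence parameter descriptions.
RECOMMENDED_RANGES = {
    "mass": (0.0, 1.0),
    "reach": (0.1, 5.0),
    "iot_loss_weight": (0.005, 0.1),
    "sinkhorn_loss_weight": (0.01, 0.5),
}


def check_hyperparameters(**params: float) -> list[str]:
    """Return one message per value outside its documented recommended range."""
    messages = []
    for name, value in params.items():
        if name not in RECOMMENDED_RANGES:
            raise KeyError(f"no documented range for {name!r}")
        low, high = RECOMMENDED_RANGES[name]
        if not low <= value <= high:
            messages.append(f"{name}={value} is outside the recommended range [{low}, {high}]")
    return messages


# The documented defaults pass; an oversized reach is flagged.
print(check_hyperparameters(mass=0.5, reach=0.3, iot_loss_weight=0.01, sinkhorn_loss_weight=0.1))  # → []
print(check_hyperparameters(reach=10.0))
```

The helper only reports problems instead of raising, since the documented ranges are recommendations rather than hard limits.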
- generative(inference_output: dict[str, dict[str, torch.Tensor]]) → dict[str, dict[str, torch.Tensor]]
Perform the generative step on the inference results. The embeddings of each modality are decoded by the decoder of its corresponding AutoEncoder, and the results are returned in a dictionary with the same structure as inference_output.
- Parameters:
inference_output – inference results on a mini-batch. A dictionary of dictionaries, where the first level corresponds to the modalities and the second level to the data for each modality.
- Returns:
Generative results for each modality, containing the reconstructed data (decodings of the embeddings).
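To illustrate the nested-dictionary layout, here is a minimal sketch that maps each modality's inference output through that modality's decoder. The keys `"z"` and `"x_rec"`, the function `generative_sketch`, and the linear decoders are assumptions for illustration; the real AutoEncoder decoders and output keys may differ.

```python
from typing import Callable

import numpy as np


def generative_sketch(
    inference_output: dict[str, dict[str, np.ndarray]],
    decoders: dict[str, Callable[[np.ndarray], np.ndarray]],
) -> dict[str, dict[str, np.ndarray]]:
    """Decode each modality's embeddings (hypothetical key "z") with its own decoder."""
    return {
        modality: {"x_rec": decoders[modality](outputs["z"])}
        for modality, outputs in inference_output.items()
    }


rng = np.random.default_rng(0)
# First level: modalities; second level: per-modality data (hypothetical keys).
inference_output = {
    "rna": {"z": rng.normal(size=(8, 4))},
    "atac": {"z": rng.normal(size=(6, 4))},
}
w_rna = rng.normal(size=(4, 100))  # hypothetical linear decoder weights
w_atac = rng.normal(size=(4, 50))
decoders = {"rna": lambda z: z @ w_rna, "atac": lambda z: z @ w_atac}
out = generative_sketch(inference_output, decoders)
print({mod: res["x_rec"].shape for mod, res in out.items()})
```

Keeping one decoder per modality is what lets modalities with different feature spaces share a single latent space.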
- get_imputation(impute_from: str, impute_to: str, to_batch: str | None = None, use_cuda: bool = True, batch_size: int = 512, **dl_kwargs) → DataFrame
Perform imputation from one modality to another on the whole dataset after the end of the training.
- Parameters:
impute_from – the modality of cells for which we want to perform the imputation.
impute_to – the modality of the features we want to impute.
to_batch – the batch of the modality impute_to to use for the imputation. If None, use the first batch.
use_cuda – whether to use GPU acceleration if cuda is available.
batch_size – size of the mini-batches used for the imputation.
dl_kwargs – additional keyword arguments for the DataLoader.
- Returns:
The imputations as a DataFrame indexed by the observation names from the input MuData.
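Because imputation runs mini-batch by mini-batch, the per-batch results have to be stitched back together by observation name. The following is a minimal sketch of that step, assuming each mini-batch yields an `(imputations, obs_names)` pair as described for imputation_batch and that column names come from the features of the impute_to modality; `assemble_imputations` is a hypothetical helper, not the library's code.

```python
import numpy as np
import pandas as pd


def assemble_imputations(
    batch_results: list[tuple[np.ndarray, np.ndarray]],
    feature_names: list[str],
) -> pd.DataFrame:
    """Stack per-mini-batch (imputations, obs_names) pairs into one indexed DataFrame."""
    values = np.concatenate([imputed for imputed, _ in batch_results], axis=0)
    obs_names = np.concatenate([names for _, names in batch_results])
    # Indexing by observation name keeps rows aligned with the input MuData,
    # whatever order the mini-batches were processed in.
    return pd.DataFrame(values, index=obs_names, columns=feature_names)


# Placeholder results from two mini-batches (hypothetical cells and genes).
batch_results = [
    (np.ones((2, 3)), np.array(["c1", "c2"])),
    (np.zeros((1, 3)), np.array(["c3"])),
]
df = assemble_imputations(batch_results, ["g1", "g2", "g3"])
print(df)
```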
- get_iot_loss(z_1: Tensor, z_2: Tensor, c_cross: Tensor) → Tensor
Compute the IOT loss between two sets of latent embeddings.
- Parameters:
z_1 – latent embeddings from the first modality.
z_2 – latent embeddings from the second modality.
c_cross – cost matrix between the two modalities. Rows correspond to z_1 and columns to z_2.
- Returns:
The IOT loss.
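The exact form of the IOT loss is not spelled out here. The following is a simplified sketch under stated assumptions: a transport plan is computed from c_cross alone and then used to weight squared distances between the latent embeddings, so that matched cells are pulled together. It swaps in balanced entropic optimal transport (log-domain Sinkhorn) for the partial transport controlled by the `mass` parameter, and `entropic_plan`, `iot_loss_sketch` and `eps` are hypothetical names; ScConfluence's implementation may differ.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def entropic_plan(c_cross: np.ndarray, eps: float = 0.05, n_iter: int = 200) -> np.ndarray:
    """Balanced entropic OT plan for c_cross (rows: z_1 cells, columns: z_2 cells)."""
    n, m = c_cross.shape
    log_a = np.full(n, -np.log(n))  # uniform weights on the first modality
    log_b = np.full(m, -np.log(m))  # uniform weights on the second modality
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates of the dual potentials.
        f = -eps * logsumexp(log_b[None, :] + (g[None, :] - c_cross) / eps, axis=1)
        g = -eps * logsumexp(log_a[:, None] + (f[:, None] - c_cross) / eps, axis=0)
    return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - c_cross) / eps)


def iot_loss_sketch(z_1: np.ndarray, z_2: np.ndarray, c_cross: np.ndarray, eps: float = 0.05) -> float:
    """Hypothetical: squared latent distances weighted by a plan computed from c_cross."""
    plan = entropic_plan(c_cross, eps)  # depends only on c_cross, not on the embeddings
    sq_dists = np.sum((z_1[:, None, :] - z_2[None, :, :]) ** 2, axis=-1)
    return float(np.sum(plan * sq_dists))


grid = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
c_cross = np.sum((grid[:, None, :] - grid[None, :, :]) ** 2, axis=-1)  # pairs cell i with cell i
aligned = iot_loss_sketch(grid, grid, c_cross)
shifted = iot_loss_sketch(grid, grid + 1.0, c_cross)
print(aligned, shifted)
```

With a cost matrix that pairs each cell with its counterpart, aligned embeddings give a loss close to 0, while shifting every second-modality embedding by (1, 1) gives a loss close to 2, the squared length of the shift.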
- get_sinkhorn_reg(z_1: Tensor, z_2: Tensor) → Tensor
Compute the Sinkhorn regularization term between two sets of latent embeddings.
- Parameters:
z_1 – latent embeddings from the first modality.
z_2 – latent embeddings from the second modality.
- Returns:
The unbalanced Sinkhorn regularization term.
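The effect of reach and blur can be illustrated with a small NumPy sketch of entropic unbalanced optimal transport. It assumes the GeomLoss convention, where the entropic strength is blur**p and the marginal-relaxation strength is reach**p, and that marginal constraints are relaxed with KL penalties; whether ScConfluence follows exactly this convention is an assumption. The sketch only computes the transport plan, not the full debiased Sinkhorn divergence, and `unbalanced_sinkhorn_plan` is a hypothetical helper.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def unbalanced_sinkhorn_plan(z_1, z_2, blur=0.01, reach=0.3, p=2, n_iter=1000):
    """Hypothetical helper: entropic unbalanced OT plan between two point clouds."""
    eps = blur ** p   # entropic regularization strength (assumed GeomLoss convention)
    rho = reach ** p  # KL penalty on marginal deviations (assumed GeomLoss convention)
    damping = rho / (rho + eps)  # < 1, so the marginal constraints are only softly enforced
    n, m = len(z_1), len(z_2)
    log_a = np.full(n, -np.log(n))  # uniform weights on cells of the first modality
    log_b = np.full(m, -np.log(m))  # uniform weights on cells of the second modality
    diff = z_1[:, None, :] - z_2[None, :, :]
    cost = 0.5 * np.sum(diff ** 2, axis=-1)  # |x - y|^2 / 2 ground cost
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        # Damped log-domain Sinkhorn updates of the dual potentials.
        f = -damping * eps * logsumexp(log_b[None, :] + (g[None, :] - cost) / eps, axis=1)
        g = -damping * eps * logsumexp(log_a[:, None] + (f[:, None] - cost) / eps, axis=0)
    log_plan = log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / eps
    return np.exp(log_plan)


rng = np.random.default_rng(0)
z_1 = rng.normal(size=(30, 2))
z_2 = rng.normal(size=(40, 2)) + 3.0  # second modality shifted away from the first
for reach in (0.1, 5.0):
    plan = unbalanced_sinkhorn_plan(z_1, z_2, blur=0.5, reach=reach)
    print(f"reach={reach}: transported mass = {plan.sum():.3f}")
```

When the two point clouds are far apart, the plan computed with reach=0.1 carries far less total mass than the one computed with reach=5: with a small reach, dropping mass is cheaper than transporting it, which is what a more unbalanced regularization means in practice.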
- imputation_batch(x, imp_from: str, imp_to: str, to_batch: str | None) → tuple[numpy.ndarray, numpy.ndarray]
Perform imputation from one modality to another on a mini-batch after the end of the training. The imputation is done by first encoding cells from the modality imp_from into the latent space, then decoding these embeddings into the modality imp_to.
- Parameters:
x – input data mini-batch.
imp_from – the modality of cells for which we want to perform the imputation.
imp_to – the modality of the features we want to impute.
to_batch – the batch of the modality imp_to to use for the imputation. If None, use the first batch.
- Returns:
The imputations and their corresponding observation names to ensure proper indexing of the results.
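The encode-then-decode procedure described above can be sketched with stand-in linear maps. This is a minimal illustration, assuming the encoder of imp_from and the decoder of imp_to behave as plain functions; `imputation_batch_sketch` and the weight matrices are hypothetical, and batch conditioning through to_batch is omitted.

```python
import numpy as np


def imputation_batch_sketch(x_from, obs_names, encoder, decoder):
    """Hypothetical: encode cells of imp_from, then decode them into imp_to features."""
    z = encoder(x_from)   # embeddings in the shared latent space
    imputed = decoder(z)  # reconstruction in the feature space of imp_to
    # Returning the observation names lets the caller index the results correctly.
    return imputed, np.asarray(obs_names)


rng = np.random.default_rng(0)
w_enc = rng.normal(size=(200, 16))  # hypothetical encoder weights for imp_from
w_dec = rng.normal(size=(16, 50))   # hypothetical decoder weights for imp_to
x_rna = rng.normal(size=(32, 200))  # one mini-batch of 32 cells
imputed, names = imputation_batch_sketch(
    x_rna,
    [f"cell_{i}" for i in range(32)],
    encoder=lambda x: x @ w_enc,
    decoder=lambda z: z @ w_dec,
)
print(imputed.shape, names[:3])
```

Only the latent space is shared, so the cells never need to have been measured in imp_to for this to produce an output.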
- inference(x: dict[str, dict[str, torch.Tensor]]) → dict[str, dict[str, torch.Tensor]]
Get encodings of the input data. The cells of each modality are embedded by the encoder of their corresponding AutoEncoder, and the results are returned in a dictionary with the same structure as the input data.
- Parameters:
x – input data mini-batch. A dictionary of dictionaries, where the first level corresponds to the modalities and the second level to the data for each modality.
- Returns:
Inference results (latent embeddings, …) for each modality.
- latent_batch(x: dict[str, dict[str, torch.Tensor]]) → tuple[numpy.ndarray, numpy.ndarray]
Get the latent embeddings for the input data. Used for prediction after the model has been trained.
- Parameters:
x – input data mini-batch. A dictionary of dictionaries, where the first level corresponds to the modalities and the second level to the data for each modality.
- Returns:
The latent embeddings and their corresponding observation names to ensure proper indexing of the results.
- log_latent_norms(inference_dic: dict[str, dict[str, torch.Tensor]], return_log: bool = False)
Log the norms of the latent embeddings for each modality.
- Parameters:
inference_dic – The results of the encoding of input data for each modality in a dictionary with the same structure as the input data. Contains the latent embeddings.
return_log – unused parameter
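A minimal sketch of the kind of quantity being logged, assuming the norm is the Euclidean norm of each cell's embedding averaged over the cells of each modality, and that embeddings are stored under a key `"z"`. Both the averaging and the key, as well as the helper `latent_norms`, are hypothetical.

```python
import numpy as np


def latent_norms(inference_dic: dict[str, dict[str, np.ndarray]], key: str = "z") -> dict[str, float]:
    """Mean L2 norm of the latent embeddings of each modality (hypothetical key)."""
    return {
        modality: float(np.linalg.norm(outputs[key], axis=1).mean())
        for modality, outputs in inference_dic.items()
    }


inference_dic = {
    "rna": {"z": np.array([[3.0, 4.0], [0.0, 0.0]])},  # per-cell norms 5 and 0
    "atac": {"z": np.array([[1.0, 0.0]])},             # per-cell norm 1
}
print(latent_norms(inference_dic))  # → {'rna': 2.5, 'atac': 1.0}
```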
- loss(x: dict[str, dict[str, torch.Tensor]], inference_output: dict[str, dict[str, torch.Tensor]], generative_output: dict[str, dict[str, torch.Tensor]], reduce: bool = True) → dict[str, torch.Tensor]
Compute all the losses for the model.
- Parameters:
x – input data mini-batch. A dictionary of dictionaries, where the first level corresponds to the modalities and the second level to the data for each modality.
inference_output – the results of encoding the input data of each modality, in a dictionary with the same structure as the input data. Contains the latent embeddings.
generative_output – the results of decoding the inferred embeddings of each modality, in a dictionary with the same structure as the input data. Contains the reconstructions of the data.
reduce – whether or not to reduce the per-cell reconstruction losses to a single scalar.
- Returns:
The weighted sum of all the loss terms that constitute the final loss.
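The parameter descriptions state that iot_loss_weight and sinkhorn_loss_weight weight the IOT and Sinkhorn terms in the final loss, and that with three modalities three pairwise Sinkhorn terms are enforced. A minimal sketch of that combination follows, assuming reconstruction losses enter with weight 1, that IOT terms are also computed per pair of modalities (as the two-argument signature of get_iot_loss suggests), and that every term is a plain scalar. `pairwise_terms`, `total_loss_sketch` and the placeholder loss values are hypothetical.

```python
from itertools import combinations


def pairwise_terms(modalities: list[str]) -> list[tuple[str, str]]:
    """All unordered pairs of modalities: 3 pairs for 3 modalities."""
    return list(combinations(modalities, 2))


def total_loss_sketch(
    recon_losses: dict[str, float],
    iot_losses: dict[tuple[str, str], float],
    sinkhorn_losses: dict[tuple[str, str], float],
    iot_loss_weight: float = 0.01,
    sinkhorn_loss_weight: float = 0.1,
) -> float:
    """Hypothetical weighted sum; reconstruction terms are assumed to have weight 1."""
    total = sum(recon_losses.values())
    total += iot_loss_weight * sum(iot_losses.values())
    total += sinkhorn_loss_weight * sum(sinkhorn_losses.values())
    return total


modalities = ["rna", "atac", "protein"]
pairs = pairwise_terms(modalities)
recon = {mod: 1.0 for mod in modalities}  # placeholder values, not real losses
iot = {pair: 1.0 for pair in pairs}
sinkhorn = {pair: 1.0 for pair in pairs}
total = total_loss_sketch(recon, iot, sinkhorn)
print(len(pairs), total)
```

Because the number of pairwise terms grows with the number of modalities, the same sinkhorn_loss_weight applies more total regularization when more modalities are integrated.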
- predict_step(batch: dict[str, dict[str, torch.Tensor]], batch_idx) → dict[str, numpy.ndarray]
Perform a prediction step on a mini-batch. The prediction step can be either imputation or latent embedding.
- Parameters:
batch – the input data mini-batch. Not to be confused with cell experimental batches.
batch_idx – index of the mini-batch in the dataset. Not to be confused with cell experimental batches. Not used in this function but required by PyTorch Lightning.
- Returns:
Dictionary with the results of the predictions and their corresponding observation names to ensure proper indexing.
- reset_log_metrics()
Reset the metrics to log at the beginning of each new optimization step on a mini-batch.
- scconfluence.model.check_aes(mdata: MuData, unimodal_aes: dict[str, scconfluence.unimodal.AutoEncoder])
Check that the AutoEncoders and the mdata object are consistent; raises a ValueError if not.
- Parameters:
mdata – the input data
unimodal_aes – dictionary of AutoEncoder objects for each modality
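One consistency condition such a check can enforce is that every modality in the data has exactly one AutoEncoder and vice versa. The following is a minimal sketch of that single condition; with a real MuData object the modality names would come from `mdata.mod`, the mapping from modality names to AnnData objects. The actual check_aes may verify more than names, and `find_modality_mismatches` and `check_aes_sketch` are hypothetical helpers.

```python
def find_modality_mismatches(modality_names, unimodal_aes) -> tuple[list[str], list[str]]:
    """Modalities without an AutoEncoder, and AutoEncoders without a modality."""
    missing = sorted(set(modality_names) - set(unimodal_aes))
    extra = sorted(set(unimodal_aes) - set(modality_names))
    return missing, extra


def check_aes_sketch(modality_names, unimodal_aes) -> None:
    """Raise a ValueError when the modality names and the AutoEncoder keys disagree."""
    missing, extra = find_modality_mismatches(modality_names, unimodal_aes)
    if missing or extra:
        raise ValueError(
            f"modalities without an AutoEncoder: {missing}; "
            f"AutoEncoders without a modality: {extra}"
        )


# Placeholder objects stand in for real AutoEncoder instances.
check_aes_sketch(["rna", "atac"], {"rna": object(), "atac": object()})  # consistent: no error
try:
    check_aes_sketch(["rna"], {"atac": object()})
except ValueError as err:
    print(err)
```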