Abstract model class

class scconfluence.base_module.BaseModule

Bases: LightningModule

Base class for all models. It is a subclass of pytorch_lightning.LightningModule.

configure_optimizers() → Optimizer

Configure the optimizer used for training.

Returns:

the optimizer object.

fit(save_path: str | Path, use_cuda: bool = True, ratio_val: float = 0.2, batch_size: int = 512, pin_memory: bool = True, num_workers: int = 0, max_epochs: int = 1000, lr: float = 0.003, early_stopping: bool = True, es_metric: str = 'val_full_loss', patience: int = 40, es_mode: Literal['min', 'max'] = 'min', test_mode: bool = False, save_models: bool = False, **trainer_kwargs)

Train the model on the dataset with which it was initialized.

Parameters:
  • save_path – path in which logs and model checkpoints will be saved.

  • use_cuda – whether to use GPU acceleration if cuda is available.

  • ratio_val – ratio of the dataset to be used for the validation split.

  • batch_size – size of the mini-batches used for training. Not to be confused with the experimental batches.

  • pin_memory – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.

  • num_workers – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.

  • max_epochs – max number of epochs used for training. Convergence is often reached before this number.

  • lr – learning rate used for training.

  • early_stopping – whether to use early stopping.

  • es_metric – which logged metric to use for early stopping.

  • patience – Number of epochs to wait before stopping training if the monitored early stopping metric does not improve.

  • es_mode – “min” or “max”, depending on whether the early stopping metric should be minimized or maximized.

  • test_mode – testing mode: if True, the model is trained for only 5 epochs with only 10 mini-batches per epoch.

  • save_models – whether to save the checkpoint of the best model according to the early stopping metric.

  • trainer_kwargs – additional arguments to be passed to the pytorch_lightning.Trainer.
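
How patience and es_mode interact with the monitored metric can be illustrated with a small standalone sketch. This is not the library's implementation, only a picture of the semantics described above; the helper name epochs_until_stop and the sample loss values are hypothetical.

```python
from typing import Literal, Sequence


def epochs_until_stop(
    metric_history: Sequence[float],
    patience: int = 40,
    es_mode: Literal["min", "max"] = "min",
) -> int:
    """Return how many epochs run before patience-based stopping triggers.

    Hypothetical helper: training stops once the monitored metric has
    failed to improve for `patience` consecutive epochs.
    """
    if es_mode not in ("min", "max"):
        raise ValueError("es_mode must be 'min' or 'max'")
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(metric_history, start=1):
        improved = (
            best is None
            or (es_mode == "min" and value < best)
            or (es_mode == "max" and value > best)
        )
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    # The metric history ended before the patience ran out.
    return len(metric_history)


# A validation loss that stops improving after the third epoch.
val_full_loss = [1.0, 0.8, 0.7, 0.7, 0.75, 0.72, 0.71]
stopped_at = epochs_until_stop(val_full_loss, patience=3, es_mode="min")
```

In this sketch the best value is reached at epoch 3, and three epochs without improvement follow, so the loop ends at epoch 6 rather than using all seven.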

forward(x, return_loss=False)

Forward pass of the model. It returns the reconstruction and the latent embeddings of the input data and optionally the full loss.

Parameters:
  • x – input mini-batch

  • return_loss – whether to return the full loss or not.
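
The abstract methods of this class (inference, generative, loss) suggest how a forward pass may be composed. The dependency-free toy below illustrates that pattern under the assumption that forward chains the three steps in this order. The ToyModule and HalvingModel classes, the dictionary keys, the simplified loss signature and the layout of the returned tuple are all hypothetical and are not taken from scconfluence.

```python
from abc import ABC, abstractmethod


class ToyModule(ABC):
    """Toy stand-in for the base class; not the scconfluence implementation."""

    @abstractmethod
    def inference(self, x):
        """Encode the input into latent embeddings."""

    @abstractmethod
    def generative(self, inference_output):
        """Decode latent embeddings into a reconstruction."""

    @abstractmethod
    def loss(self, x, inference_output, generative_output):
        """Compare the input with its reconstruction."""

    def forward(self, x, return_loss=False):
        # Assumed ordering: encode, decode, then optionally score.
        inference_output = self.inference(x)
        generative_output = self.generative(inference_output)
        result = (generative_output["reconstruction"], inference_output["z"])
        if return_loss:
            result += (self.loss(x, inference_output, generative_output),)
        return result


class HalvingModel(ToyModule):
    """Hypothetical subclass: the latent code is the input halved."""

    def inference(self, x):
        return {"z": [v / 2 for v in x]}

    def generative(self, inference_output):
        return {"reconstruction": [2 * z for z in inference_output["z"]]}

    def loss(self, x, inference_output, generative_output):
        # Mean squared error between the input and its reconstruction.
        recon = generative_output["reconstruction"]
        return sum((a - b) ** 2 for a, b in zip(x, recon)) / len(x)


model = HalvingModel()
reconstruction, latent = model.forward([2.0, 4.0])
_, _, full_loss = model.forward([2.0, 4.0], return_loss=True)
```

Because the three methods are abstract, the toy base class cannot be instantiated directly; only a subclass that implements all of them can.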

abstract generative(inference_output)

get_latent(use_cuda: bool = True, batch_size: int = 512, pin_memory: bool = True, num_workers: int = 0) → DataFrame

Get the latent embeddings of all cells of the dataset from the trained model.

Parameters:
  • use_cuda – whether to use GPU acceleration if available.

  • batch_size – size of the mini-batches used to compute the embeddings. Not to be confused with the experimental batches.

  • pin_memory – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.

  • num_workers – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.

Returns:

dataframe with the latent embeddings of all cells of the dataset, indexed by their observation names.
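
The batch-wise procedure behind this method can be pictured with a dependency-free sketch. It assumes the embeddings are computed one mini-batch at a time and then gathered under the observation names. The names embed_in_batches and toy_encoder are hypothetical, and a plain dict stands in for the returned DataFrame.

```python
from typing import Callable, Dict, List, Sequence


def embed_in_batches(
    obs_names: Sequence[str],
    cells: Sequence[Sequence[float]],
    encoder: Callable[[Sequence[Sequence[float]]], List[List[float]]],
    batch_size: int = 512,
) -> Dict[str, List[float]]:
    """Embed all cells mini-batch by mini-batch, keyed by observation name."""
    if len(obs_names) != len(cells):
        raise ValueError("obs_names and cells must have the same length")
    latents: Dict[str, List[float]] = {}
    for start in range(0, len(cells), batch_size):
        stop = start + batch_size
        batch_latents = encoder(cells[start:stop])
        for name, z in zip(obs_names[start:stop], batch_latents):
            latents[name] = z
    return latents


def toy_encoder(batch):
    # Hypothetical encoder: a 1-D latent equal to the mean of each cell.
    return [[sum(cell) / len(cell)] for cell in batch]


names = ["cell_a", "cell_b", "cell_c"]
data = [[1.0, 3.0], [2.0, 6.0], [0.0, 2.0]]
latent = embed_in_batches(names, data, toy_encoder, batch_size=2)
```

In this sketch, batch_size changes only how the work is chunked, not the resulting embeddings.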

abstract inference(x)

abstract latent_batch(batch)

log_all_metrics(split: Literal['train', 'val'])

Log all metrics in self.metrics_to_log for the given split.

Parameters:

split – whether the metrics are being logged for a training or validation mini-batch.

abstract log_latent_norms(inference_dic, return_log=False)

abstract loss(x, inference_output, generative_output, reduce=True)

optimizer_zero_grad(epoch, batch_idx, optimizer)

Resets the gradients of the optimized tensors. The gradients are set to None rather than zeroed, which improves performance.

abstract reset_log_metrics()

train_val_step(batch, split)

Wrapper to perform a training or validation step. It returns the loss of the mini-batch.

Parameters:
  • batch – input data mini-batch

  • split – whether the mini-batch is part of the training or validation set.

Returns:

loss of the model on the mini-batch.
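
The relationship between train_val_step, training_step and validation_step can be sketched without any dependencies. The ToySteps class below is hypothetical; it only assumes, as documented in this section, that both step methods delegate to one shared wrapper and return its loss. The per-split bookkeeping and the toy loss are illustrative stand-ins, not the library's metric logging.

```python
from typing import Dict, List, Literal, Sequence


class ToySteps:
    """Hypothetical illustration of one shared step wrapper for both splits."""

    def __init__(self) -> None:
        # Stand-in for metric logging: losses recorded per split.
        self.logged: Dict[str, List[float]] = {"train": [], "val": []}

    def compute_loss(self, batch: Sequence[float]) -> float:
        # Toy loss: mean squared value of the mini-batch.
        return sum(v * v for v in batch) / len(batch)

    def train_val_step(self, batch, split: Literal["train", "val"]) -> float:
        loss = self.compute_loss(batch)
        self.logged[split].append(loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.train_val_step(batch, split="train")

    def validation_step(self, batch, batch_idx):
        return self.train_val_step(batch, split="val")


steps = ToySteps()
train_loss = steps.training_step([1.0, 3.0], batch_idx=0)
val_loss = steps.validation_step([2.0, 2.0], batch_idx=0)
```

Keeping the shared logic in a single wrapper means the two splits differ only in the split label passed along, which keeps training and validation losses computed the same way.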

training_step(batch, batch_idx)

Use the train_val_step method to perform a training step.

Returns:

loss of the model on the mini-batch.

validation_step(batch, batch_idx)

Use the train_val_step method to perform a validation step.

Returns:

loss of the model on the mini-batch.