Abstract model class
- class scconfluence.base_module.BaseModule
Bases: LightningModule
Base class for all models. It is a subclass of pytorch_lightning.LightningModule.
- configure_optimizers() → Optimizer
Configure the optimizer used for training.
- Returns:
the optimizer object.
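The sketch below shows what a `configure_optimizers` override can look like. It assumes Adam over all trainable parameters, with the learning rate passed to `fit` stored on a hypothetical `self.lr` attribute; this reference does not state which optimizer scconfluence actually uses. `TinyModel` is a hypothetical plain `torch.nn.Module` stand-in, so the example runs without Lightning.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a BaseModule subclass (no Lightning dependency)."""

    def __init__(self, lr: float = 0.003):
        super().__init__()
        self.lr = lr  # assumption: fit() stores its lr argument on the model
        self.encoder = nn.Linear(10, 2)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Assumption: Adam over all trainable parameters; the real choice may differ.
        return torch.optim.Adam(self.parameters(), lr=self.lr)


optimizer = TinyModel(lr=0.003).configure_optimizers()
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])  # → Adam 0.003
```

Lightning accepts a bare optimizer as the return value of `configure_optimizers`. A learning-rate scheduler would require returning a dictionary or a list instead.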
- fit(save_path: str | Path, use_cuda: bool = True, ratio_val: float = 0.2, batch_size: int = 512, pin_memory: bool = True, num_workers: int = 0, max_epochs: int = 1000, lr: float = 0.003, early_stopping: bool = True, es_metric: str = 'val_full_loss', patience: int = 40, es_mode: Literal['min', 'max'] = 'min', test_mode: bool = False, save_models: bool = False, **trainer_kwargs)
Train the model on the dataset with which it was initialized.
- Parameters:
save_path – path in which logs and model checkpoints will be saved.
use_cuda – whether to use GPU acceleration if CUDA is available.
ratio_val – ratio of the dataset to be used for the validation split.
batch_size – size of the mini-batches used for training. Not to be confused with the experimental batches.
pin_memory – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_workers – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
max_epochs – maximum number of training epochs. Convergence is often reached before this number.
lr – learning rate used for training.
early_stopping – whether to use early stopping.
es_metric – which logged metric to use for early stopping.
patience – Number of epochs to wait before stopping training if the monitored early stopping metric does not improve.
es_mode – “min” or “max”, depending on whether the early stopping metric should be minimized or maximized.
test_mode – if True, the model is trained for only 5 epochs with 10 mini-batches per epoch. Intended for quick tests.
save_models – whether to save the checkpoint of the best model according to the early stopping metric.
trainer_kwargs – additional arguments to be passed to the pytorch_lightning.Trainer.
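To show how `es_metric`, `patience` and `es_mode` interact, here is a minimal pure-Python model of patience-based early stopping. Training stops once the monitored metric has failed to improve for `patience` consecutive epochs. `fit` presumably delegates this to Lightning's `EarlyStopping` callback, but this reference does not say so. The function name `early_stop_epoch` is hypothetical, and the sketch treats only strict improvements as improvements (no `min_delta`).

```python
from typing import Literal, Optional, Sequence


def early_stop_epoch(
    history: Sequence[float],
    patience: int = 40,
    es_mode: Literal["min", "max"] = "min",
) -> Optional[int]:
    """Hypothetical helper: index of the epoch at which training stops, or None."""
    best = None
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(history):
        improved = (
            best is None
            or (es_mode == "min" and value < best)
            or (es_mode == "max" and value > best)
        )
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The validation loss stops improving after epoch 2, so with patience=3
# training stops at epoch 5 (epochs 3, 4 and 5 show no improvement).
val_full_loss = [1.0, 0.8, 0.7, 0.71, 0.72, 0.75, 0.6]
print(early_stop_epoch(val_full_loss, patience=3, es_mode="min"))  # → 5
```

With the default `patience=40`, a model that plateaus keeps training for 40 more epochs before stopping. This is why `max_epochs=1000` is rarely reached.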
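The sketch below shows one way `use_cuda`, `max_epochs`, `test_mode` and `trainer_kwargs` could be translated into `pytorch_lightning.Trainer` arguments. `accelerator`, `max_epochs` and `limit_train_batches` are real `Trainer` arguments, but the helper `sketch_trainer_kwargs` and its exact mapping are hypothetical and are not taken from scconfluence. CUDA availability is passed in as a flag so that the example runs without PyTorch.

```python
from typing import Any, Dict


def sketch_trainer_kwargs(
    cuda_available: bool,
    use_cuda: bool = True,
    max_epochs: int = 1000,
    test_mode: bool = False,
    **trainer_kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: map fit()-style options onto pytorch_lightning.Trainer arguments."""
    kwargs: Dict[str, Any] = {
        # Use the GPU only when it is requested *and* CUDA is actually present.
        "accelerator": "gpu" if (use_cuda and cuda_available) else "cpu",
        "max_epochs": max_epochs,
    }
    if test_mode:
        # As documented: 5 epochs with 10 mini-batches per epoch.
        kwargs["max_epochs"] = 5
        kwargs["limit_train_batches"] = 10
    # Assumption: extra keyword arguments take precedence over the values above.
    kwargs.update(trainer_kwargs)
    return kwargs


print(sketch_trainer_kwargs(cuda_available=False, test_mode=True, log_every_n_steps=1))
# → {'accelerator': 'cpu', 'max_epochs': 5, 'limit_train_batches': 10, 'log_every_n_steps': 1}
```

Whether the real `fit` lets `trainer_kwargs` override its own settings, and whether `test_mode` also limits validation batches, is not specified in this reference.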
- forward(x, return_loss=False)
Forward pass of the model. It returns the reconstruction and the latent embeddings of the input data and optionally the full loss.
- Parameters:
x – input mini-batch
return_loss – whether to return the full loss or not.
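The abstract methods listed on this page (`inference`, `generative`, `loss`) suggest a template pattern in which `forward` chains the hooks that each subclass implements. The framework-free outline below is a sketch under that assumption. The dictionary keys `"z"` and `"x_hat"`, the output order, and the toy `Doubler` subclass are all hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchModule(ABC):
    """Hypothetical outline of how forward() can chain the abstract hooks."""

    @abstractmethod
    def inference(self, x) -> Dict[str, Any]: ...

    @abstractmethod
    def generative(self, inference_output) -> Dict[str, Any]: ...

    @abstractmethod
    def loss(self, x, inference_output, generative_output, reduce=True): ...

    def forward(self, x, return_loss=False):
        inference_output = self.inference(x)                   # encode: x -> latent
        generative_output = self.generative(inference_output)  # decode: latent -> reconstruction
        outputs = (generative_output["x_hat"], inference_output["z"])
        if return_loss:
            outputs += (self.loss(x, inference_output, generative_output),)
        return outputs


class Doubler(SketchModule):
    """Toy subclass: latent is 2*x, reconstruction is latent/2, loss is squared error."""

    def inference(self, x):
        return {"z": [2 * v for v in x]}

    def generative(self, inference_output):
        return {"x_hat": [v / 2 for v in inference_output["z"]]}

    def loss(self, x, inference_output, generative_output, reduce=True):
        errors = [(a - b) ** 2 for a, b in zip(x, generative_output["x_hat"])]
        return sum(errors) if reduce else errors


x_hat, z, loss = Doubler().forward([1.0, 3.0], return_loss=True)
print(x_hat, z, loss)  # → [1.0, 3.0] [2.0, 6.0] 0.0
```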
- abstract generative(inference_output)
- get_latent(use_cuda: bool = True, batch_size: int = 512, pin_memory: bool = True, num_workers: int = 0) → DataFrame
Get the latent embeddings of all cells of the dataset from the trained model.
- Parameters:
use_cuda – whether to use GPU acceleration if available.
batch_size – size of the mini-batches used to compute the embeddings. Not to be confused with the experimental batches.
pin_memory – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.
num_workers – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
- Returns:
dataframe with the latent embeddings of all cells of the dataset, indexed by their observation names.
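The final step of `get_latent` can be pictured as stacking the per-mini-batch embeddings and indexing the rows by cell name. The helper `latent_to_dataframe` below is hypothetical. It assumes that the mini-batches are iterated in dataset order, without shuffling, so that rows line up with the observation names (for example, an AnnData object's `obs_names`).

```python
import numpy as np
import pandas as pd


def latent_to_dataframe(batch_latents, obs_names):
    """Hypothetical helper: stack per-mini-batch latent arrays into one DataFrame.

    batch_latents: list of (n_cells_in_batch, latent_dim) arrays, in dataset order.
    obs_names: cell identifiers in the same order as the rows.
    """
    latent = np.concatenate(batch_latents, axis=0)
    if latent.shape[0] != len(obs_names):
        raise ValueError("number of embeddings does not match number of cells")
    return pd.DataFrame(latent, index=pd.Index(obs_names))


# Two mini-batches (batch_size=2) covering 3 cells, with a 2-dimensional latent space.
batches = [np.array([[0.1, 0.2], [0.3, 0.4]]), np.array([[0.5, 0.6]])]
df = latent_to_dataframe(batches, ["cell_a", "cell_b", "cell_c"])
print(df.shape)  # → (3, 2)
```

Indexing by cell name rather than by position lets the embeddings be joined back to cell metadata even if the rows are later reordered.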
- abstract inference(x)
- abstract latent_batch(batch)
- log_all_metrics(split: Literal['train', 'val'])
Log all metrics in self.metrics_to_log for the given split.
- Parameters:
split – whether the metrics are being logged for a training or validation mini-batch.
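The default `es_metric='val_full_loss'` suggests that logged metric names combine the split with the metric name. The sketch below builds such names under that assumption. Both the helper `prefixed_metrics` and the `recon_loss` metric are hypothetical. Inside a LightningModule, each resulting name and value pair would then be passed to `self.log`.

```python
from typing import Dict, Literal


def prefixed_metrics(metrics_to_log: Dict[str, float], split: Literal["train", "val"]) -> Dict[str, float]:
    """Hypothetical helper: names under which each metric would be logged.

    Assumes the "<split>_<metric>" convention implied by the default 'val_full_loss'.
    """
    if split not in ("train", "val"):
        raise ValueError(f"unknown split: {split!r}")
    return {f"{split}_{name}": value for name, value in metrics_to_log.items()}


logged = prefixed_metrics({"full_loss": 1.25, "recon_loss": 0.75}, split="val")
print(logged)  # → {'val_full_loss': 1.25, 'val_recon_loss': 0.75}
```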
- abstract log_latent_norms(inference_dic, return_log=False)
- abstract loss(x, inference_output, generative_output, reduce=True)
- optimizer_zero_grad(epoch, batch_idx, optimizer)
Reset the gradients of the optimized tensors. The gradients are set to None instead of being filled with zeros, which improves performance.
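The following plain PyTorch example shows the difference between zeroing gradients in place and setting them to None. The override presumably calls `optimizer.zero_grad(set_to_none=True)`, but that is an assumption. Setting gradients to None avoids a memory write per parameter, and the next `backward()` assigns a fresh gradient instead of accumulating into a zero-filled buffer.

```python
import torch

param = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([param], lr=0.1)

(param * 2).sum().backward()
print(param.grad)  # → tensor([2., 2., 2.])

# Zeroing in place keeps an allocated tensor of zeros.
optimizer.zero_grad(set_to_none=False)
print(param.grad)  # → tensor([0., 0., 0.])

# set_to_none=True releases it instead.
optimizer.zero_grad(set_to_none=True)
print(param.grad)  # → None
```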
- abstract reset_log_metrics()
- train_val_step(batch, split)
Wrapper to perform a training or validation step. It returns the loss of the mini-batch.
- Parameters:
batch – input data mini-batch
split – whether the mini-batch is part of the training or validation set.
- Returns:
loss of the model on the mini-batch.
- training_step(batch, batch_idx)
Use the train_val_step method to perform a training step.
- Returns:
loss of the model on the mini-batch.
- validation_step(batch, batch_idx)
Use the train_val_step method to perform a validation step.
- Returns:
loss of the model on the mini-batch.
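The delegation pattern described above can be sketched as follows: `training_step` and `validation_step` differ only in the `split` they pass to a shared `train_val_step`, which runs the forward pass, logs the loss, and returns it. The class `StepSketch`, its toy forward pass, and the `logged` list, which stands in for Lightning's `self.log`, are hypothetical. The metric name follows the `"<split>_full_loss"` pattern suggested by the default `es_metric`.

```python
from typing import List, Literal, Tuple


class StepSketch:
    """Hypothetical outline of train_val_step / training_step / validation_step."""

    def __init__(self):
        self.logged: List[Tuple[str, float]] = []  # stands in for Lightning's self.log

    def forward(self, x, return_loss=False):
        # Toy forward: reconstruction = x, latent = x, loss = mean of squares.
        loss = sum(v * v for v in x) / len(x)
        return (x, x, loss) if return_loss else (x, x)

    def train_val_step(self, batch, split: Literal["train", "val"]):
        *_, loss = self.forward(batch, return_loss=True)
        self.logged.append((f"{split}_full_loss", loss))  # assumed metric name
        return loss

    def training_step(self, batch, batch_idx):
        return self.train_val_step(batch, split="train")

    def validation_step(self, batch, batch_idx):
        return self.train_val_step(batch, split="val")


model = StepSketch()
print(model.training_step([1.0, 3.0], batch_idx=0))    # → 5.0
print(model.validation_step([2.0, 2.0], batch_idx=0))  # → 4.0
print(model.logged)  # → [('train_full_loss', 5.0), ('val_full_loss', 4.0)]
```

In Lightning, the loss returned by `training_step` is the value that gets backpropagated, whereas the value returned by `validation_step` is only used for monitoring.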