scButterfly.butterfly.Butterfly.train_model

Butterfly.train_model(R_encoder_lr=0.001, A_encoder_lr=0.001, R_decoder_lr=0.001, A_decoder_lr=0.001, R_translator_lr=0.001, A_translator_lr=0.001, translator_lr=0.001, discriminator_lr=0.005, R2R_pretrain_epoch=100, A2A_pretrain_epoch=100, lock_encoder_and_decoder=False, translator_epoch=200, patience=50, batch_size=64, r_loss=MSELoss(), a_loss=BCELoss(), d_loss=BCELoss(), loss_weight=[1, 2, 1], output_path=None, seed=19193, kl_mean=True, R_pretrain_kl_warmup=50, A_pretrain_kl_warmup=50, translation_kl_warmup=50, load_model=None, logging_path=None)

Train the model: pretrain the RNA and ATAC autoencoders, then train the translator.
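As a rough illustration of how a call might look, the sketch below wraps train_model in a small helper. It assumes `butterfly` is an already constructed and prepared Butterfly object (data loading and model construction are not shown). The helper name `train_butterfly` and the `short_run` settings are hypothetical and not part of scButterfly; the keyword names and default values are taken from the signature above.

```python
def train_butterfly(butterfly, output_path=None, short_run=False):
    """Call train_model on an already prepared Butterfly-like object.

    `butterfly` is assumed to expose train_model(**kwargs) with the
    keyword names listed in the signature above.
    """
    # Documented defaults, spelled out so they are easy to override.
    kwargs = dict(
        R2R_pretrain_epoch=100,
        A2A_pretrain_epoch=100,
        translator_epoch=200,
        patience=50,
        batch_size=64,
        loss_weight=[1, 2, 1],
        seed=19193,
        output_path=output_path,
        logging_path=None,
    )
    if short_run:
        # Hypothetical smoke-test values, not recommended settings.
        kwargs.update(
            R2R_pretrain_epoch=2,
            A2A_pretrain_epoch=2,
            translator_epoch=2,
            patience=1,
        )
    # Pass through whatever train_model returns (possibly None).
    return butterfly.train_model(**kwargs)
```

Any argument left out of `kwargs` keeps the default shown in the signature.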

Parameters:
  • R_encoder_lr (float) – learning rate of the RNA encoder, default 0.001.

  • A_encoder_lr (float) – learning rate of the ATAC encoder, default 0.001.

  • R_decoder_lr (float) – learning rate of the RNA decoder, default 0.001.

  • A_decoder_lr (float) – learning rate of the ATAC decoder, default 0.001.

  • R_translator_lr (float) – learning rate of the RNA pretraining translator, default 0.001.

  • A_translator_lr (float) – learning rate of the ATAC pretraining translator, default 0.001.

  • translator_lr (float) – learning rate of the translator, default 0.001.

  • discriminator_lr (float) – learning rate of the discriminator, default 0.005.

  • R2R_pretrain_epoch (int) – maximum number of epochs for pretraining the RNA autoencoder, default 100.

  • A2A_pretrain_epoch (int) – maximum number of epochs for pretraining the ATAC autoencoder, default 100.

  • lock_encoder_and_decoder (bool) – whether to lock the pretrained encoders and decoders while training the translator, default False.

  • translator_epoch (int) – maximum number of epochs for training the translator, default 200.

  • patience (int) – patience (in epochs) for the validation loss, default 50.

  • batch_size (int) – batch size for training and validation, default 64.

  • r_loss – loss function for RNA reconstruction, default nn.MSELoss(size_average=True).

  • a_loss – loss function for ATAC reconstruction, default nn.BCELoss(size_average=True).

  • d_loss – loss function for discriminator, default nn.BCELoss(size_average=True).

  • loss_weight (list) – loss weights, in the order [r_loss, a_loss, d_loss], default [1, 2, 1].

  • output_path (str) – file path for model output, default None.

  • seed (int) – random seed, default 19193.

  • kl_mean (bool) – whether to size-average the KL divergence, default True.

  • R_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during RNA pretraining, default 50.

  • A_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during ATAC pretraining, default 50.

  • translation_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during translator training, default 50.

  • load_model (str) – path from which to load a saved model; set to None to skip loading, default None.

  • logging_path (str) – path for the training log output; set to None to skip saving the log, default None.
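Several of the parameters above describe training-schedule ideas: a linear warm-up of the KL weight, a weighted combination of losses, and patience-based stopping. The sketch below shows one common way each can be written in plain Python. It is an assumption about the general technique, not scButterfly's actual implementation; in particular, how the library combines the discriminator loss with the reconstruction losses is not documented here. The names `kl_warmup_weight`, `weighted_total`, and `PatienceStopper` are hypothetical.

```python
def kl_warmup_weight(epoch, warmup_epochs):
    """Assumed linear ramp of the KL weight from 0 to 1 over `warmup_epochs`."""
    if warmup_epochs <= 0:
        return 1.0  # no warm-up: full KL weight from the start
    return min(1.0, epoch / warmup_epochs)


def weighted_total(r_loss, a_loss, d_loss, loss_weight=(1, 2, 1)):
    """Weighted sum of three loss values in [r_loss, a_loss, d_loss] order.

    A plain weighted sum is an assumption made for illustration.
    """
    w_r, w_a, w_d = loss_weight
    return w_r * r_loss + w_a * a_loss + w_d * d_loss


class PatienceStopper:
    """Signal a stop once validation loss fails to improve `patience` times in a row."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Under this sketch, a warm-up of 50 epochs gives a KL weight of 0.5 at epoch 25 and 1.0 from epoch 50 onward.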