scButterfly.train_model_cite.Model.train

Model.train(loss_weight, train_id_r, train_id_a, validation_id_r, validation_id_a, R_encoder_lr=0.001, A_encoder_lr=0.001, R_decoder_lr=0.001, A_decoder_lr=0.001, R_translator_lr=0.001, A_translator_lr=0.001, translator_lr=0.001, discriminator_lr=0.005, R2R_pretrain_epoch=100, A2A_pretrain_epoch=100, lock_encoder_and_decoder=False, translator_epoch=200, patience=50, batch_size=64, r_loss=MSELoss(), a_loss=MSELoss(), d_loss=BCELoss(), output_path=None, seed=19193, kl_mean=True, R_pretrain_kl_warmup=50, A_pretrain_kl_warmup=50, translation_kl_warmup=50, load_model=None, logging_path=None)

Train the model. Some parameters depend on the dataset; see the Tutorial for guidance on setting them.

Parameters:
  • loss_weight (list) – list of loss weights for [r_loss, a_loss, d_loss, kl_div_R, kl_div_A, kl_div_all], in that order.

  • train_id_r (list) – list of cell IDs in the RNA data used for training.

  • train_id_a (list) – list of cell IDs in the ADT data used for training.

  • validation_id_r (list) – list of cell IDs in the RNA data used for validation.

  • validation_id_a (list) – list of cell IDs in the ADT data used for validation.

  • R_encoder_lr (float) – learning rate for the RNA encoder, default 0.001.

  • A_encoder_lr (float) – learning rate for the ADT encoder, default 0.001.

  • R_decoder_lr (float) – learning rate for the RNA decoder, default 0.001.

  • A_decoder_lr (float) – learning rate for the ADT decoder, default 0.001.

  • R_translator_lr (float) – learning rate for the translator used in RNA pretraining, default 0.001.

  • A_translator_lr (float) – learning rate for the translator used in ADT pretraining, default 0.001.

  • translator_lr (float) – learning rate for the translator, default 0.001.

  • discriminator_lr (float) – learning rate for the discriminator, default 0.005.

  • R2R_pretrain_epoch (int) – maximum number of epochs for pretraining the RNA autoencoder, default 100.

  • A2A_pretrain_epoch (int) – maximum number of epochs for pretraining the ADT autoencoder, default 100.

  • lock_encoder_and_decoder (bool) – whether to lock (freeze) the pretrained encoders and decoders, default False.

  • translator_epoch (int) – maximum number of epochs for training the translator, default 200.

  • patience (int) – number of epochs to wait for the validation loss to improve before stopping early, default 50.

  • batch_size (int) – batch size for training and validation, default 64.

  • r_loss – loss function for RNA reconstruction, default nn.MSELoss(size_average=True), i.e. mean reduction.

  • a_loss – loss function for ADT reconstruction, default nn.MSELoss(size_average=True), i.e. mean reduction.

  • d_loss – loss function for the discriminator, default nn.BCELoss(size_average=True), i.e. mean reduction.

  • output_path (str) – path for saving model outputs, default None.

  • seed (int) – random seed, default 19193.

  • kl_mean (bool) – whether to average (rather than sum) the KL divergence, default True.

  • R_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during RNA pretraining, default 50.

  • A_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during ADT pretraining, default 50.

  • translation_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during translator training, default 50.

  • load_model (str) – path from which to load a saved model; set to None to skip loading, default None.

  • logging_path (str) – path for saving the training log; set to None to disable saving it, default None.
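The four ID parameters (train_id_r, train_id_a, validation_id_r, validation_id_a) expect lists of cell identifiers for the training and validation sets. The following is a minimal sketch of how such lists could be produced. It rests on three assumptions. First, cells are addressed by integer position. Second, for paired CITE-seq data, RNA and ADT share the same cell order, so the same lists serve both modalities. Third, a 20% validation fraction is used. The helper `split_cell_ids` is hypothetical; the Tutorial describes how the data are actually split.

```python
import random
from typing import List, Tuple


def split_cell_ids(n_cells: int, val_fraction: float = 0.2,
                   seed: int = 19193) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shuffle positional cell indices and split them.

    Assumes cells are identified by integer position (0 .. n_cells - 1).
    """
    ids = list(range(n_cells))
    random.Random(seed).shuffle(ids)  # reproducible shuffle for a fixed seed
    n_val = int(round(n_cells * val_fraction))
    return ids[n_val:], ids[:n_val]  # (train_ids, validation_ids)


train_ids, validation_ids = split_cell_ids(1000)
# Assumption: paired RNA/ADT data share cell order, so both modalities reuse the lists.
train_id_r, train_id_a = train_ids, list(train_ids)
validation_id_r, validation_id_a = validation_ids, list(validation_ids)
```

Reusing one seeded split for both modalities keeps each training cell's RNA and ADT profiles together. This only holds under the shared-order assumption above.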
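The loss_weight parameter supplies one weight per loss term, in the order [r_loss, a_loss, d_loss, kl_div_R, kl_div_A, kl_div_all]. The sketch below shows the role of each list position. It assumes the terms are combined as a weighted sum, which is an assumption: how each training stage actually combines them is internal to scButterfly. The function `weighted_total_loss` and the example numbers are hypothetical.

```python
from typing import Mapping, Sequence

# Order of terms as documented for loss_weight.
LOSS_NAMES = ("r_loss", "a_loss", "d_loss", "kl_div_R", "kl_div_A", "kl_div_all")


def weighted_total_loss(loss_weight: Sequence[float],
                        losses: Mapping[str, float]) -> float:
    """Hypothetical weighted sum of the loss terms present in `losses`.

    Terms absent from `losses` (e.g. not used in the current stage) contribute 0.
    """
    if len(loss_weight) != len(LOSS_NAMES):
        raise ValueError(f"expected {len(LOSS_NAMES)} weights, got {len(loss_weight)}")
    return sum(w * losses.get(name, 0.0) for w, name in zip(loss_weight, LOSS_NAMES))


# Example: only reconstruction terms are present; ADT reconstruction weighted 2x.
total = weighted_total_loss([1, 2, 1, 1, 1, 1], {"r_loss": 0.5, "a_loss": 0.25})
```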
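The three warm-up parameters (R_pretrain_kl_warmup, A_pretrain_kl_warmup, translation_kl_warmup) set how many epochs the KL divergence weight takes to ramp up linearly. Below is a minimal sketch of such a schedule. It assumes the factor starts at 0 in epoch 0 and reaches 1.0 at the warm-up epoch. Whether scButterfly also scales this factor by the matching loss_weight entry is not stated here. The name `kl_warmup_factor` is hypothetical.

```python
def kl_warmup_factor(epoch: int, warmup_epochs: int) -> float:
    """Hypothetical linear KL warm-up factor.

    Grows linearly from 0.0 at epoch 0 to 1.0 at `warmup_epochs`, then stays at 1.0.
    A non-positive warm-up length disables the ramp.
    """
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / warmup_epochs)


# With the default R_pretrain_kl_warmup=50:
factors = [kl_warmup_factor(e, 50) for e in (0, 25, 50, 80)]
```

Warming up the KL term this way lets the autoencoder learn useful reconstructions before the prior-matching penalty reaches full strength.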
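The patience parameter controls early stopping on the validation loss. The sketch below assumes the common rule: stop once the validation loss has failed to beat its best value for `patience` consecutive epochs. The `EarlyStopping` class is hypothetical. scButterfly's exact criterion may differ, for example by using a minimum-improvement threshold.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping on validation loss."""

    def __init__(self, patience: int = 50) -> None:
        self.patience = patience
        self.best = float("inf")  # best validation loss seen so far
        self.bad_epochs = 0       # consecutive epochs without improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
decisions = [stopper.step(loss) for loss in (1.0, 0.9, 0.95, 0.92)]
```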