scButterfly.butterfly.Butterfly.train_model
- Butterfly.train_model(R_encoder_lr=0.001, A_encoder_lr=0.001, R_decoder_lr=0.001, A_decoder_lr=0.001, R_translator_lr=0.001, A_translator_lr=0.001, translator_lr=0.001, discriminator_lr=0.005, R2R_pretrain_epoch=100, A2A_pretrain_epoch=100, lock_encoder_and_decoder=False, translator_epoch=200, patience=50, batch_size=64, r_loss=MSELoss(), a_loss=BCELoss(), d_loss=BCELoss(), loss_weight=[1, 2, 1], output_path=None, seed=19193, kl_mean=True, R_pretrain_kl_warmup=50, A_pretrain_kl_warmup=50, translation_kl_warmup=50, load_model=None, logging_path=None)
Train the model: pretrain the RNA and ATAC autoencoders, then train the translator.
- Parameters:
R_encoder_lr (float) – learning rate for the RNA encoder, default 0.001.
A_encoder_lr (float) – learning rate for the ATAC encoder, default 0.001.
R_decoder_lr (float) – learning rate for the RNA decoder, default 0.001.
A_decoder_lr (float) – learning rate for the ATAC decoder, default 0.001.
R_translator_lr (float) – learning rate for the translator used during RNA pretraining, default 0.001.
A_translator_lr (float) – learning rate for the translator used during ATAC pretraining, default 0.001.
translator_lr (float) – learning rate for the translator, default 0.001.
discriminator_lr (float) – learning rate for the discriminator, default 0.005.
R2R_pretrain_epoch (int) – maximum number of epochs for pretraining the RNA autoencoder, default 100.
A2A_pretrain_epoch (int) – maximum number of epochs for pretraining the ATAC autoencoder, default 100.
lock_encoder_and_decoder (bool) – whether to lock (freeze) the pretrained encoders and decoders, default False.
translator_epoch (int) – maximum number of epochs for training the translator, default 200.
patience (int) – number of epochs without improvement in the validation loss to tolerate before stopping early, default 50.
batch_size (int) – batch size for training and validation, default 64.
r_loss – loss function for RNA reconstruction, default nn.MSELoss(size_average=True), i.e. mean reduction.
a_loss – loss function for ATAC reconstruction, default nn.BCELoss(size_average=True), i.e. mean reduction.
d_loss – loss function for the discriminator, default nn.BCELoss(size_average=True), i.e. mean reduction.
loss_weight (list) – weights for the three losses, in the order [r_loss, a_loss, d_loss], default [1, 2, 1].
output_path (str) – file path for model output, default None.
seed (int) – random seed, default 19193.
kl_mean (bool) – whether to average (rather than sum) the KL divergence, default True.
R_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during RNA pretraining, default 50.
A_pretrain_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during ATAC pretraining, default 50.
translation_kl_warmup (int) – number of epochs over which the KL divergence weight is linearly warmed up during the translation stage, default 50.
load_model (str) – path of a saved model to load, or None to skip loading, default None.
logging_path (str) – path for saving the training log, or None to not save it, default None.
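The `kl_mean` and `*_kl_warmup` parameters control how the KL divergence term is scaled. The sketch below illustrates both ideas under stated assumptions: a diagonal Gaussian latent with a standard-normal prior, averaging over latent dimensions when `kl_mean=True`, and a warm-up weight that rises linearly from 0 to 1 over the given number of epochs. `kl_divergence` and `kl_warmup_weight` are hypothetical helpers, not scButterfly functions, and scButterfly's exact schedule (for example, whether epochs are counted from 0 or 1) may differ.

```python
import math
from typing import Sequence


def kl_divergence(mu: Sequence[float], logvar: Sequence[float], kl_mean: bool = True) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for a diagonal Gaussian.

    Assumption: standard-normal prior. With kl_mean=True the per-dimension
    terms are averaged; with kl_mean=False they are summed.
    """
    terms = [0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mu, logvar)]
    total = sum(terms)
    return total / len(terms) if kl_mean else total


def kl_warmup_weight(epoch: int, warmup_epochs: int) -> float:
    """Hypothetical linear warm-up: 0 at epoch 0, reaching 1 at warmup_epochs, then constant."""
    if warmup_epochs <= 0:
        return 1.0  # no warm-up: full KL weight from the start
    return min(1.0, epoch / warmup_epochs)


# KL weight at a few epochs with the default warm-up length of 50.
for epoch in (0, 25, 50, 80):
    print(epoch, kl_warmup_weight(epoch, 50))
```

Warming the KL weight up from zero lets the reconstruction loss shape the latent space first, which is a common way to reduce posterior collapse in variational autoencoders.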
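`patience` works alongside the epoch limits (`R2R_pretrain_epoch`, `A2A_pretrain_epoch`, `translator_epoch`). A common reading of "patience for loss on validation" is: stop once the validation loss has failed to reach a new best for `patience` consecutive epochs, otherwise run until the epoch limit. The sketch below implements that rule; `EarlyStopper` and `stop_epoch` are hypothetical names, and scButterfly's own bookkeeping (such as whether it restores the best weights) is not shown here.

```python
from typing import Sequence, Tuple


class EarlyStopper:
    """Hypothetical early-stopping helper: stop after `patience` epochs without a new best."""

    def __init__(self, patience: int = 50, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


def stop_epoch(val_losses: Sequence[float], max_epoch: int, patience: int) -> Tuple[int, int]:
    """Return (epoch at which training stops, epoch with the best validation loss)."""
    stopper = EarlyStopper(patience=patience)
    for epoch in range(max_epoch):
        if stopper.step(epoch, val_losses[epoch]):
            return epoch, stopper.best_epoch
    return max_epoch - 1, stopper.best_epoch  # hit the epoch limit without stopping early


# Validation loss improves for three epochs, then plateaus.
losses = [1.0, 0.8, 0.6] + [0.7] * 10
print(stop_epoch(losses, max_epoch=13, patience=5))
```

With a small patience the loop stops a few epochs after the plateau begins; with the default patience of 50 and a short run, the epoch limit is reached first.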
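`loss_weight` gives the weights for `[r_loss, a_loss, d_loss]`, in that order. Assuming the three terms enter as a plain weighted sum (an illustration only; how scButterfly combines them, for instance in separate generator and discriminator steps, is not stated here), the default `[1, 2, 1]` counts the ATAC reconstruction loss twice as heavily as the other two. `combine_losses` is a hypothetical helper.

```python
from typing import Sequence


def combine_losses(r: float, a: float, d: float,
                   loss_weight: Sequence[float] = (1, 2, 1)) -> float:
    """Hypothetical weighted sum of the RNA reconstruction (r), ATAC reconstruction (a)
    and discriminator (d) losses, with weights ordered as [r_loss, a_loss, d_loss]."""
    if len(loss_weight) != 3:
        raise ValueError("loss_weight must have exactly three entries: [r, a, d]")
    w_r, w_a, w_d = loss_weight
    return w_r * r + w_a * a + w_d * d


# Default weights: 1 * 0.5 + 2 * 0.25 + 1 * 0.1
print(combine_losses(0.5, 0.25, 0.1))
```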