Trainer module
In this module, the ensemble is trained on the whole available dataset. This suits active learning, where data points are expensive to obtain and datasets are usually very small. The ensemble weights yielding the smallest loss are saved, which helps reduce oversampling around the existing training points. This can be refined further by expelling the worst-performing ensemble members according to their training loss, controlled by the num_expel_NNs parameter. In the context of active learning, training the ensemble in each iteration can be accelerated significantly by enabling warm starting.
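Keeping the weights with the smallest loss, rather than the weights after the last epoch, can be sketched as follows. This is a minimal pure-Python illustration that assumes the loss is evaluated once per epoch; train_keep_best, loss_fn and step_fn are hypothetical names standing in for one ensemble member's loss evaluation and optimizer update, not functions of this module.

```python
import copy
from typing import Callable, List, Tuple


def train_keep_best(
    weights: List[float],
    loss_fn: Callable[[List[float]], float],
    step_fn: Callable[[List[float]], List[float]],
    num_epochs: int,
) -> Tuple[List[float], float, List[float]]:
    """Run num_epochs updates and return (best_weights, best_loss, loss_history).

    Hypothetical sketch: loss_fn and step_fn are placeholders supplied by the caller.
    """
    best_weights = copy.deepcopy(weights)
    best_loss = loss_fn(weights)
    history = [best_loss]
    for _ in range(num_epochs):
        weights = step_fn(weights)
        loss = loss_fn(weights)
        history.append(loss)
        # Checkpoint only when the loss improves on the best seen so far.
        if loss < best_loss:
            best_loss = loss
            best_weights = copy.deepcopy(weights)
    return best_weights, best_loss, history


# Toy example: the fixed step overshoots the minimum at w = 1.0 on the last epoch,
# so the returned weights come from the second epoch, not the final one.
best, best_loss, history = train_keep_best(
    [0.0], lambda w: (w[0] - 1.0) ** 2, lambda w: [w[0] + 0.6], num_epochs=3
)
```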
num_epochs (int)
Number of epochs used in the training of all the ensemble members.
Default:
800
num_expel_NNs (int)
The ensemble members are sorted by their final training loss, and the num_expel_NNs worst-performing members are removed.
Default:
0
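The expelling step can be sketched as a sort by final training loss followed by dropping the tail. This is an illustration under the assumption that one scalar final loss is available per member; expel_worst is a hypothetical helper, and keeping the survivors in their original order is a design choice of this sketch.

```python
from typing import List, Sequence, TypeVar

Member = TypeVar("Member")


def expel_worst(
    members: Sequence[Member], final_losses: Sequence[float], num_expel_NNs: int = 0
) -> List[Member]:
    """Drop the num_expel_NNs members with the highest final training loss (hypothetical helper)."""
    if len(members) != len(final_losses):
        raise ValueError("one final loss per ensemble member is required")
    if not 0 <= num_expel_NNs < len(members):
        raise ValueError("num_expel_NNs must leave at least one member")
    # Member indices ordered from lowest to highest final training loss.
    order = sorted(range(len(members)), key=lambda i: final_losses[i])
    keep = order[: len(members) - num_expel_NNs]
    # Return the survivors in their original ensemble order.
    return [members[i] for i in sorted(keep)]


survivors = expel_worst(["a", "b", "c", "d"], [0.3, 0.1, 0.5, 0.2], num_expel_NNs=2)
```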
learning_rate (float)
Learning rate hyperparameter used in the loss minimization when cold starting.
Default:
0.006
scale_grad_loss (float)
Hyperparameter used to scale the loss induced by the gradients with respect to the input parameters.
Default:
0.5
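One plausible reading of scale_grad_loss is a weighted sum of a loss on the predicted values and a loss on the predicted input gradients. The sketch below assumes exactly that form, total = loss(values) + scale_grad_loss * loss(gradients), with mean squared error for both terms; the module's actual combination may differ, and total_loss is a hypothetical name.

```python
from typing import List, Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over paired entries."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def total_loss(
    pred_values: Sequence[float],
    obs_values: Sequence[float],
    pred_grads: Sequence[Sequence[float]],
    obs_grads: Sequence[Sequence[float]],
    scale_grad_loss: float = 0.5,
) -> float:
    """Assumed form: value loss plus scale_grad_loss times gradient loss."""
    value_term = mse(pred_values, obs_values)
    # Flatten the per-point gradient vectors so every component counts once.
    flat_pred: List[float] = [g for row in pred_grads for g in row]
    flat_obs: List[float] = [g for row in obs_grads for g in row]
    grad_term = mse(flat_pred, flat_obs)
    return value_term + scale_grad_loss * grad_term


# Value MSE is 0.5, gradient MSE is 1.0, so the total is 0.5 + 0.5 * 1.0.
loss = total_loss([1.0, 2.0], [1.0, 3.0], [[0.0, 0.0], [1.0, 1.0]], [[0.0, 2.0], [1.0, 1.0]])
```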
optimizer (str)
Optimizer used to minimize the loss during ensemble training. AdamW implements weight_decay regularization correctly, decoupling the decay from the adaptive gradient update.
Default:
"Adam"
Choices: "AdamW", "Adam"
weight_decay (float)
Weight decay hyperparameter used with the AdamW optimizer.
Default:
0.0
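The difference between the two optimizer choices for weight_decay can be shown on a single scalar step. This sketch follows the update rules used by PyTorch: torch.optim.Adam adds weight_decay * theta to the gradient (an L2 penalty), while torch.optim.AdamW multiplies the weight by 1 - lr * weight_decay before the Adam update (decoupled weight decay, Loshchilov and Hutter). The function adam_step is hypothetical and is not the module's optimizer code.

```python
import math
from typing import Tuple


def adam_step(
    theta: float,
    grad: float,
    m: float,
    v: float,
    t: int,
    lr: float,
    weight_decay: float = 0.0,
    decoupled: bool = False,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[float, float, float]:
    """One Adam step (decoupled=False) or AdamW step (decoupled=True) for a scalar weight."""
    if decoupled:
        # AdamW: shrink the weight directly, independent of the adaptive scaling.
        theta = theta * (1.0 - lr * weight_decay)
    else:
        # Adam with L2: the decay term enters the moment estimates like any gradient.
        grad = grad + weight_decay * theta
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# First step with a zero loss gradient, so only weight decay acts on theta = 1.0.
adam_theta, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.006, weight_decay=0.1)
adamw_theta, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.006, weight_decay=0.1, decoupled=True)
```

On this first step, AdamW shrinks the weight by exactly lr * weight_decay (to 0.9994), whereas with Adam the decay term is normalized by the adaptive denominator, so the weight moves by roughly lr (to about 0.994) almost regardless of how small weight_decay is.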
loss_function (str)
Loss function which is minimized during ensemble training.
Default:
"MSE"
Choices: "MSE", "L1"
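The two loss choices can be sketched with mean reduction over all entries, which is an assumption here (it matches the default reduction of PyTorch's MSELoss and L1Loss). get_loss and the LOSSES table are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


LOSSES: Dict[str, LossFn] = {"MSE": mse_loss, "L1": l1_loss}


def get_loss(name: str) -> LossFn:
    """Map a loss_function setting to a callable (hypothetical helper)."""
    try:
        return LOSSES[name]
    except KeyError:
        raise ValueError(f"loss_function must be one of {sorted(LOSSES)}, got {name!r}") from None


# A single large residual dominates MSE far more than L1.
pred, target = [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 4.0]
mse_value = get_loss("MSE")(pred, target)
l1_value = get_loss("L1")(pred, target)
```

Because MSE squares the residuals, it penalizes a few large errors much more heavily than L1, which grows only linearly with the residual.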
train_mean_grad (bool)
If True, the gradient averaged over all ensemble members is trained to match the observed gradients. This reduces accuracy, because individual ensemble members may still have gradients that differ from the observations, but it shortens training time significantly.
Default:
true
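The difference between the two settings of train_mean_grad can be sketched as two gradient-matching losses. This is an illustration assuming a squared-error match; mean_grad_loss and per_member_grad_loss are hypothetical names, not the module's implementation.

```python
from typing import List, Sequence


def mean_grad_loss(member_grads: Sequence[Sequence[float]], observed: Sequence[float]) -> float:
    """train_mean_grad=True: compare the ensemble-averaged gradient with the observation."""
    n, dim = len(member_grads), len(observed)
    mean: List[float] = [sum(g[d] for g in member_grads) / n for d in range(dim)]
    return sum((mean[d] - observed[d]) ** 2 for d in range(dim)) / dim


def per_member_grad_loss(member_grads: Sequence[Sequence[float]], observed: Sequence[float]) -> float:
    """train_mean_grad=False: each member's own gradient must match the observation."""
    dim = len(observed)
    per_member = [sum((g[d] - observed[d]) ** 2 for d in range(dim)) / dim for g in member_grads]
    return sum(per_member) / len(per_member)


# Two members with opposite gradients: their mean matches the observation exactly,
# yet neither member does individually.
grads = [[1.0], [-1.0]]
observed = [0.0]
```

The example shows the accuracy trade-off described above: the averaged constraint is satisfied with zero loss even though each member's gradient is off by 1.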
batch_size_train (int)
Batch size used to train the ensemble. If not set, it is determined during the optimization run so that all training data is processed in a single batch.
Default:
None
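The fallback for batch_size_train can be sketched as a batch iterator that treats None as "the whole training set". iter_batches is a hypothetical helper; clamping the size to at least 1 so that an empty dataset does not raise is a choice of this sketch.

```python
from typing import Iterator, Optional, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(data: Sequence[T], batch_size_train: Optional[int] = None) -> Iterator[Sequence[T]]:
    """Yield consecutive batches; None means one batch holding all training data."""
    size = max(len(data), 1) if batch_size_train is None else batch_size_train
    if size <= 0:
        raise ValueError("batch_size_train must be positive")
    for start in range(0, len(data), size):
        yield data[start:start + size]


points = list(range(5))
full_batch = list(iter_batches(points))
mini_batches = list(iter_batches(points, batch_size_train=2))
```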
save_history_path (str)
An absolute path where the training history shall be stored.
Default:
None
num_epochs_warm (int)
Number of epochs used to train all ensemble members when warm starting is applied. Warm starting usually requires far fewer epochs, because training begins from a good initial point in weight space.
Default:
200
learning_rate_warm (float)
Learning rate hyperparameter used in the loss minimization when warm starting.
Default:
0.004
warm_start (bool)
This flag enables warm starting: the weights from the previous iteration of the active learning algorithm are reused, which makes training significantly faster. The shrink and perturb method of J. T. Ash and R. P. Adams is implemented.
Default:
false
shrink (float)
Warm start hyperparameter: the factor by which the previous weights are shrunk toward zero.
Default:
0.9
perturb (float)
Warm start hyperparameter: the scale of the random perturbation added to the shrunk weights.
Default:
0.1
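The shrink and perturb update can be sketched per weight as new_weight = shrink * old_weight + perturb * noise. The exact rule and the noise distribution are assumptions here: this sketch draws standard normal noise, whereas Ash and Adams describe adding a small amount of noise (in practice drawn from the initialization distribution), and the module may differ. shrink_perturb is a hypothetical name.

```python
import random
from typing import List, Optional, Sequence


def shrink_perturb(
    weights: Sequence[float],
    shrink: float = 0.9,
    perturb: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Assumed rule: shrink each weight toward zero, then add perturb * N(0, 1) noise."""
    rng = rng or random.Random()
    return [shrink * w + perturb * rng.gauss(0.0, 1.0) for w in weights]


# With perturb = 0 only the shrinking remains; a seeded generator makes the noise reproducible.
shrunk = shrink_perturb([1.0, -2.0], shrink=0.9, perturb=0.0)
a = shrink_perturb([1.0, -2.0], rng=random.Random(0))
b = shrink_perturb([1.0, -2.0], rng=random.Random(0))
```

Shrinking pulls the previous solution back toward the origin while the noise restores some randomness, which is the mechanism Ash and Adams propose for keeping warm-started networks from generalizing worse than freshly initialized ones.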
cold_start_interval (int)
Number of active learning iterations after which the ensemble is trained again from scratch (cold start) instead of being warm started.
Default:
20
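How the warm start settings could combine in a single active learning iteration is sketched below. The assumptions are that iterations are counted from 0 and that a cold start happens whenever the iteration index is a multiple of cold_start_interval; TrainerSettings and training_plan are hypothetical names whose defaults mirror those listed above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class TrainerSettings:
    # Defaults mirror the documented values; the class itself is hypothetical.
    num_epochs: int = 800
    learning_rate: float = 0.006
    num_epochs_warm: int = 200
    learning_rate_warm: float = 0.004
    warm_start: bool = False
    cold_start_interval: int = 20


def training_plan(settings: TrainerSettings, iteration: int) -> Tuple[bool, int, float]:
    """Return (cold_start, num_epochs, learning_rate) for a 0-based iteration (assumed schedule)."""
    if settings.cold_start_interval <= 0:
        raise ValueError("cold_start_interval must be positive")
    cold = not settings.warm_start or iteration % settings.cold_start_interval == 0
    if cold:
        return True, settings.num_epochs, settings.learning_rate
    return False, settings.num_epochs_warm, settings.learning_rate_warm


warm = TrainerSettings(warm_start=True)
```

With warm starting disabled, every iteration uses the cold start values; with it enabled, iterations 0, 20, 40, ... are trained from scratch and the rest reuse the previous weights with the shorter, lower learning rate schedule.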