Training Callbacks

class pytorch_wrapper.training_callbacks.AbstractCallback

Bases: abc.ABC

Objects of derived classes inject functionality at several points of the training process.
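As a rough illustration of how a derived class might use these hooks, the sketch below counts the batches processed in each epoch. It runs without the library installed, so `CallbackStandIn` and `run_toy_loop` are hypothetical stand-ins for `AbstractCallback` and the real training loop; only the hook names come from this reference, and the order in which the toy loop calls them is an assumption based on those names.

```python
from abc import ABC


class CallbackStandIn(ABC):
    """Hypothetical stand-in mirroring a few documented hooks as no-ops."""

    def on_training_start(self, training_context):
        pass

    def on_epoch_start(self, training_context):
        pass

    def on_batch_end(self, training_context):
        pass

    def on_epoch_end(self, training_context):
        pass

    def on_training_end(self, training_context):
        pass


class BatchCounterCallback(CallbackStandIn):
    """Counts how many batches were processed in each epoch."""

    def __init__(self):
        self.batches_per_epoch = []
        self._count = 0

    def on_epoch_start(self, training_context):
        self._count = 0  # reset the counter for the new epoch

    def on_batch_end(self, training_context):
        self._count += 1

    def on_epoch_end(self, training_context):
        self.batches_per_epoch.append(self._count)


def run_toy_loop(callback, nb_epochs, nb_batches):
    """Hypothetical driver; the hook order is assumed from the hook names."""
    training_context = {}  # the real dict's contents are not listed here
    callback.on_training_start(training_context)
    for _ in range(nb_epochs):
        callback.on_epoch_start(training_context)
        for _ in range(nb_batches):
            callback.on_batch_end(training_context)
        callback.on_epoch_end(training_context)
    callback.on_training_end(training_context)
    return callback


counter = run_toy_loop(BatchCounterCallback(), nb_epochs=2, nb_batches=3)
print(counter.batches_per_epoch)
```

With the real library, the same subclass would presumably derive from `AbstractCallback` and override only the hooks it needs.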

on_batch_end(training_context)

Called after a batch has been processed.

Parameters: training_context – Dict containing information regarding the training process.
on_batch_start(training_context)

Called just before processing a new batch.

Parameters: training_context – Dict containing information regarding the training process.
on_epoch_end(training_context)

Called at the end of an epoch.

Parameters: training_context – Dict containing information regarding the training process.
on_epoch_start(training_context)

Called at the beginning of a new epoch.

Parameters: training_context – Dict containing information regarding the training process.
on_evaluation_end(training_context)

Called at the end of the evaluation step.

Parameters: training_context – Dict containing information regarding the training process.
on_evaluation_start(training_context)

Called at the beginning of the evaluation step.

Parameters: training_context – Dict containing information regarding the training process.
on_training_end(training_context)

Called at the end of the training process.

Parameters: training_context – Dict containing information regarding the training process.
on_training_start(training_context)

Called at the beginning of the training process.

Parameters: training_context – Dict containing information regarding the training process.
post_backward_calculation(training_context)

Called just after backward is called.

Parameters: training_context – Dict containing information regarding the training process.
post_loss_calculation(training_context)

Called just after loss calculation.

Parameters: training_context – Dict containing information regarding the training process.
post_predict(training_context)

Called just after prediction during training time.

Parameters: training_context – Dict containing information regarding the training process.
pre_optimization_step(training_context)

Called just before the optimization step.

Parameters: training_context – Dict containing information regarding the training process.
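To make the per-batch hooks easier to place, here is a minimal sketch that records the order in which they might fire during a single training step. The `process_one_batch` function is a hypothetical stand-in for the library's internal loop, and the relative order of `post_backward_calculation` and `pre_optimization_step` is an assumption inferred from their descriptions. A plain class is used instead of `AbstractCallback` so the sketch runs without the library installed.

```python
class HookTraceCallback:
    """Records the name of each per-batch hook as it fires."""

    def __init__(self):
        self.trace = []

    def on_batch_start(self, training_context):
        self.trace.append("on_batch_start")

    def post_predict(self, training_context):
        self.trace.append("post_predict")

    def post_loss_calculation(self, training_context):
        self.trace.append("post_loss_calculation")

    def post_backward_calculation(self, training_context):
        self.trace.append("post_backward_calculation")

    def pre_optimization_step(self, training_context):
        self.trace.append("pre_optimization_step")

    def on_batch_end(self, training_context):
        self.trace.append("on_batch_end")


def process_one_batch(callback):
    """Hypothetical training step; the real one lives inside the library."""
    training_context = {}  # contents are not specified in this reference
    callback.on_batch_start(training_context)
    # ... the forward pass would run here ...
    callback.post_predict(training_context)
    # ... the loss would be computed here ...
    callback.post_loss_calculation(training_context)
    # ... the backward pass would run here ...
    callback.post_backward_calculation(training_context)
    callback.pre_optimization_step(training_context)
    # ... the optimizer step would run here ...
    callback.on_batch_end(training_context)
    return callback.trace


trace = process_one_batch(HookTraceCallback())
print(" -> ".join(trace))
```

Under these assumptions, `pre_optimization_step` is the natural place for work that must see the computed gradients before the weights change, such as gradient clipping.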
class pytorch_wrapper.training_callbacks.EarlyStoppingCriterionCallback(patience, evaluation_data_loader_key, evaluator_key, tmp_best_state_filepath)

Bases: pytorch_wrapper.training_callbacks.StoppingCriterionCallback

Stops the training process if the results do not improve for a given number of epochs.

Parameters:
  • patience – Number of epochs without improvement to tolerate before training stops.
  • evaluation_data_loader_key – Key of the data-loader dict (provided as an argument to the train method of System) that corresponds to the data-set that the early stopping method considers.
  • evaluator_key – Key of the evaluators dict (provided as an argument to the train method of System) that corresponds to the evaluator that the early stopping method considers.
  • tmp_best_state_filepath – Path where the state of the best model so far will be saved.
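The role of patience can be sketched with a small, library-independent tracker. This is not the library's implementation: `PatienceTracker` is a hypothetical name, higher scores are assumed to be better, and the stop condition (more than `patience` epochs in a row without improvement) is an assumption about how the parameter is interpreted.

```python
class PatienceTracker:
    """Hypothetical sketch of patience-based early stopping."""

    def __init__(self, patience):
        self.patience = patience
        self.best_score = None
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def update(self, epoch, score):
        """Record one evaluation result; return True when training should stop."""
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            # The documented callback would presumably save the model state
            # to tmp_best_state_filepath at this point.
            return False
        self.epochs_without_improvement += 1
        return self.epochs_without_improvement > self.patience


def first_stop_epoch(scores, patience):
    """Return (stop_epoch, best_epoch); stop_epoch is None if never triggered."""
    tracker = PatienceTracker(patience)
    for epoch, score in enumerate(scores):
        if tracker.update(epoch, score):
            return epoch, tracker.best_epoch
    return None, tracker.best_epoch


stop_epoch, best_epoch = first_stop_epoch([0.50, 0.60, 0.55, 0.58, 0.57], patience=2)
print(stop_epoch, best_epoch)
```

In this toy run the best score occurs at epoch 1, and the third non-improving evaluation in a row (epoch 4) triggers the stop.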
on_evaluation_end(training_context)

Called at the end of the evaluation step.

Parameters: training_context – Dict containing information regarding the training process.
on_training_end(training_context)

Called at the end of the training process.

Parameters: training_context – Dict containing information regarding the training process.
on_training_start(training_context)

Called at the beginning of the training process.

Parameters: training_context – Dict containing information regarding the training process.
class pytorch_wrapper.training_callbacks.NumberOfEpochsStoppingCriterionCallback(nb_of_epochs)

Bases: pytorch_wrapper.training_callbacks.StoppingCriterionCallback

Stops the training process after a number of epochs.

Parameters: nb_of_epochs – Number of epochs to train.
on_epoch_end(training_context)

Called at the end of an epoch.

Parameters: training_context – Dict containing information regarding the training process.
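A stopping criterion that fires in `on_epoch_end` could look roughly like the sketch below. The context keys `"epochs_completed"` and `"stop_training"` are hypothetical, since this reference does not list the contents of `training_context`, and `toy_train` is a stand-in for the real training loop.

```python
class EpochLimitSketch:
    """Hypothetical stand-in for a fixed-epoch stopping criterion."""

    def __init__(self, nb_of_epochs):
        self.nb_of_epochs = nb_of_epochs

    def on_epoch_end(self, training_context):
        # Both keys are assumptions made for this sketch only.
        if training_context["epochs_completed"] >= self.nb_of_epochs:
            training_context["stop_training"] = True


def toy_train(callback, hard_cap=100):
    """Hypothetical loop that runs until the callback requests a stop."""
    training_context = {"epochs_completed": 0, "stop_training": False}
    while (not training_context["stop_training"]
           and training_context["epochs_completed"] < hard_cap):
        # ... one epoch of training would run here ...
        training_context["epochs_completed"] += 1
        callback.on_epoch_end(training_context)
    return training_context["epochs_completed"]


epochs_run = toy_train(EpochLimitSketch(nb_of_epochs=3))
print(epochs_run)
```

The `hard_cap` argument only guards the toy loop against running forever if a callback never sets the stop flag.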
class pytorch_wrapper.training_callbacks.StoppingCriterionCallback

Bases: pytorch_wrapper.training_callbacks.AbstractCallback