Loss Wrappers
class pytorch_wrapper.loss_wrappers.AbstractLossWrapper
Bases: abc.ABC
Objects of derived classes wrap a loss module and provide the interface used by the System class.
calculate_loss(batch, output, training_context, last_activation=None)
Calculates the loss for a single batch.
Parameters: - batch – Dict that contains all information needed by the loss wrapper.
- output – Output of the model.
- training_context – Dict containing information regarding the training process.
- last_activation – Last activation provided to the System.
Returns: Output of the loss function/module.
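The contract above can be illustrated with a library-free sketch that uses Python's `abc` module the same way: an abstract `calculate_loss` that every concrete wrapper must implement. The class names `LossWrapperSketch` and `MeanAbsoluteErrorWrapper` are hypothetical, plain lists stand in for tensors, and nothing here is imported from pytorch_wrapper. The documented signatures order `batch` and `output` differently in the abstract class and in the concrete wrappers below, so the usage passes them by keyword.

```python
from abc import ABC, abstractmethod


class LossWrapperSketch(ABC):
    """Hypothetical stand-in mirroring the documented AbstractLossWrapper interface."""

    @abstractmethod
    def calculate_loss(self, batch, output, training_context, last_activation=None):
        """Return the loss for a single batch."""


class MeanAbsoluteErrorWrapper(LossWrapperSketch):
    """Hypothetical concrete wrapper: mean absolute error against batch['target']."""

    def calculate_loss(self, batch, output, training_context, last_activation=None):
        # Plain lists stand in for tensors, so the activation is applied element-wise.
        if last_activation is not None:
            output = [last_activation(o) for o in output]
        targets = batch['target']
        return sum(abs(o - t) for o, t in zip(output, targets)) / len(targets)


wrapper = MeanAbsoluteErrorWrapper()
# Keyword arguments avoid relying on the positional order of batch and output.
loss = wrapper.calculate_loss(batch={'target': [1.0, 2.0]}, output=[1.5, 1.0],
                              training_context={})
print(loss)  # 0.75

# The abstract base cannot be instantiated because calculate_loss is abstract.
try:
    LossWrapperSketch()
except TypeError as err:
    print('abstract:', err)
```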
class pytorch_wrapper.loss_wrappers.GenericPointWiseLossWrapper(loss, model_output_key=None, batch_target_key='target', perform_last_activation=False)
Bases: pytorch_wrapper.loss_wrappers.AbstractLossWrapper
Adapter that wraps a pointwise loss module.
Parameters: - loss – Loss module.
- model_output_key – Key under which the dict returned by the model stores the actual predictions. Leave as None if the model returns only the predictions.
- batch_target_key – Key under which the batch dict stores the target values.
- perform_last_activation – Whether to apply last_activation to the model output before computing the loss.
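The parameter list does not say when perform_last_activation should be enabled. One plausible case for leaving it False: PyTorch losses such as `torch.nn.BCEWithLogitsLoss` (sigmoid plus BCE) and `torch.nn.CrossEntropyLoss` (log-softmax plus NLL) expect raw scores, so also applying a sigmoid last_activation would squash the scores twice. The pure-Python sketch below shows that the double activation changes the loss value; the helpers `sigmoid` and `bce_with_logits` are hypothetical stand-ins, not library functions.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_logits(logit, target):
    """Binary cross-entropy that applies the sigmoid itself, so it expects a raw score."""
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


logit, target = 2.0, 1.0

# perform_last_activation=False: the raw score reaches the loss, sigmoid is applied once.
loss_once = bce_with_logits(logit, target)

# perform_last_activation=True with last_activation=sigmoid: sigmoid is applied twice.
loss_twice = bce_with_logits(sigmoid(logit), target)

print(round(loss_once, 4), round(loss_twice, 4))
```

Whether enabling the flag is correct therefore depends on whether the wrapped loss expects probabilities or raw scores.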
calculate_loss(output, batch, training_context, last_activation=None)
Calculates the loss for a single batch.
Parameters: - batch – Dict that contains all information needed by the loss wrapper.
- output – Output of the model.
- training_context – Dict containing information regarding the training process.
- last_activation – Last activation provided to the System.
Returns: Output of the loss function/module.
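A minimal sketch of how the documented constructor parameters could combine inside calculate_loss: select the predictions, optionally apply the last activation, look up the targets, and call the wrapped loss. This is one reading of the parameter descriptions, not the library's source; plain lists stand in for tensors, and `PointWiseLossWrapperSketch` and `mse` are hypothetical names.

```python
class PointWiseLossWrapperSketch:
    """Hypothetical re-creation of the documented GenericPointWiseLossWrapper behaviour."""

    def __init__(self, loss, model_output_key=None, batch_target_key='target',
                 perform_last_activation=False):
        self._loss = loss
        self._model_output_key = model_output_key
        self._batch_target_key = batch_target_key
        self._perform_last_activation = perform_last_activation

    def calculate_loss(self, output, batch, training_context, last_activation=None):
        # Models that return a dict keep their predictions under model_output_key.
        if self._model_output_key is not None:
            output = output[self._model_output_key]
        # Apply the System's last activation only when the wrapper is asked to
        # (element-wise here because plain lists stand in for tensors).
        if self._perform_last_activation and last_activation is not None:
            output = [last_activation(o) for o in output]
        targets = batch[self._batch_target_key]
        return self._loss(output, targets)


def mse(predictions, targets):
    """Hypothetical pointwise loss: mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


wrapper = PointWiseLossWrapperSketch(mse, model_output_key='logits')
loss = wrapper.calculate_loss(
    output={'logits': [0.5, 2.0], 'attention': None},
    batch={'target': [1.0, 2.0]},
    training_context={},
)
print(loss)  # 0.125
```

With a dict output, model_output_key='logits' picks out only the predictions; a model that returns its predictions directly would keep the default model_output_key=None.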
class pytorch_wrapper.loss_wrappers.TokenLabelingGenericPointWiseLossWrapper(loss, batch_input_sequence_length_idx, batch_input_key='input', model_output_key=None, batch_target_key='target', perform_last_activation=False, end_padded=True)
Bases: pytorch_wrapper.loss_wrappers.AbstractLossWrapper
Adapter that wraps a pointwise loss module. It is used in token labeling tasks to flatten the output and target while discarding the invalid positions introduced by padding.
Parameters: - loss – Loss module.
- batch_input_sequence_length_idx – Index of the input list at which the sequence lengths can be found.
- batch_input_key – Key of the dicts returned by the DataLoader objects that corresponds to the input of the model.
- model_output_key – Key under which the dict returned by the model stores the actual predictions. Leave as None if the model returns only the predictions.
- batch_target_key – Key under which the batch dict stores the target values.
- perform_last_activation – Whether to apply last_activation to the model output before computing the loss.
- end_padded – Whether the sequences are padded at the end (True) rather than at the start (False).
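To make batch_input_sequence_length_idx and end_padded concrete, here is a sketch of how a validity mask can be derived from sequence lengths. It assumes a hypothetical batch layout in which batch['input'] is a list whose element at batch_input_sequence_length_idx holds the lengths; `create_mask_from_lengths` is an illustrative helper, not a documented library function.

```python
def create_mask_from_lengths(lengths, max_len, end_padded=True):
    """Return one boolean row per sequence; True marks real (non-padding) positions."""
    mask = []
    for length in lengths:
        if end_padded:
            # Real tokens first, padding at the end.
            row = [i < length for i in range(max_len)]
        else:
            # Padding first, real tokens at the end.
            row = [i >= max_len - length for i in range(max_len)]
        mask.append(row)
    return mask


# Hypothetical batch: token ids (end-padded with 0) followed by their lengths.
batch = {'input': [[[5, 8, 2, 0], [7, 0, 0, 0]], [3, 1]]}
batch_input_sequence_length_idx = 1
lengths = batch['input'][batch_input_sequence_length_idx]

print(create_mask_from_lengths(lengths, max_len=4))
# [[True, True, True, False], [True, False, False, False]]
print(create_mask_from_lengths([3, 1], max_len=4, end_padded=False))
# [[False, True, True, True], [False, False, False, True]]
```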
calculate_loss(output, batch, training_context, last_activation=None)
Calculates the loss for a single batch.
Parameters: - batch – Dict that contains all information needed by the loss wrapper.
- output – Output of the model.
- training_context – Dict containing information regarding the training process.
- last_activation – Last activation provided to the System.
Returns: Output of the loss function/module.
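The flatten-then-filter step described for this wrapper can be sketched in plain Python, assuming nested [batch][time] lists in place of tensors. The real wrapper presumably operates on tensors; this sketch selects the real tokens by slicing each sequence, which keeps the same positions a length-derived mask would. `flatten_valid` and `squared_error` are hypothetical names.

```python
def flatten_valid(output, targets, lengths, end_padded=True):
    """Flatten [batch][time] predictions and targets, keeping only real (non-padding) tokens."""
    flat_output, flat_targets = [], []
    for seq_output, seq_targets, length in zip(output, targets, lengths):
        # End-padded: real tokens occupy the first `length` steps;
        # start-padded: they occupy the last `length` steps.
        start = 0 if end_padded else len(seq_output) - length
        flat_output.extend(seq_output[start:start + length])
        flat_targets.extend(seq_targets[start:start + length])
    return flat_output, flat_targets


def squared_error(predictions, targets):
    """Hypothetical pointwise loss applied to the flattened tokens."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


# Two end-padded sequences of lengths 2 and 1; 9.9 and -1.0 fill the padded steps.
output = [[0.5, 0.0, 9.9, 9.9], [1.0, 9.9, 9.9, 9.9]]
targets = [[1.0, 0.0, -1.0, -1.0], [1.0, -1.0, -1.0, -1.0]]
flat_output, flat_targets = flatten_valid(output, targets, lengths=[2, 1])
print(flat_output, flat_targets)  # [0.5, 0.0, 1.0] [1.0, 0.0, 1.0]
print(squared_error(flat_output, flat_targets))
```

Because the padded steps never reach the loss, their placeholder values (9.9 and -1.0 here) cannot distort the result.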