Loss Wrappers

class pytorch_wrapper.loss_wrappers.AbstractLossWrapper

Bases: abc.ABC

Objects of derived classes wrap a loss module, providing the interface expected by the System class.

calculate_loss(output, batch, training_context, last_activation=None)

Calculates the loss for a single batch.

Parameters:
  • batch – Dict that contains all information needed by the loss wrapper.
  • output – Output of the model.
  • training_context – Dict containing information regarding the training process.
  • last_activation – Last activation provided to the System.
Returns:

Output of the loss function/module.
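
The contract above can be sketched in plain Python. This is a stand-in re-implementation of the interface, not the library class itself; the `MSELossWrapper` subclass and its list-based loss are hypothetical, shown only to illustrate what a derived class must provide.

```python
from abc import ABC, abstractmethod


# Stand-in for pytorch_wrapper.loss_wrappers.AbstractLossWrapper:
# derived classes implement calculate_loss and are passed to a System.
class AbstractLossWrapper(ABC):

    @abstractmethod
    def calculate_loss(self, output, batch, training_context, last_activation=None):
        raise NotImplementedError


# Hypothetical subclass: mean squared error over plain Python lists.
class MSELossWrapper(AbstractLossWrapper):

    def calculate_loss(self, output, batch, training_context, last_activation=None):
        if last_activation is not None:
            output = last_activation(output)
        target = batch['target']
        return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)


wrapper = MSELossWrapper()
loss = wrapper.calculate_loss([1.0, 2.0], {'target': [1.0, 4.0]}, {})
print(loss)  # 2.0
```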

class pytorch_wrapper.loss_wrappers.GenericPointWiseLossWrapper(loss, model_output_key=None, batch_target_key='target', perform_last_activation=False)

Bases: pytorch_wrapper.loss_wrappers.AbstractLossWrapper

Adapter that wraps a pointwise loss module.

Parameters:
  • loss – Loss module.
  • model_output_key – Key where the dict returned by the model contains the actual predictions. Leave None if the model returns only the predictions.
  • batch_target_key – Key where the dict (batch) contains the target values.
  • perform_last_activation – Whether to apply the last_activation to the output before computing the loss.
calculate_loss(output, batch, training_context, last_activation=None)

Calculates the loss for a single batch.

Parameters:
  • batch – Dict that contains all information needed by the loss wrapper.
  • output – Output of the model.
  • training_context – Dict containing information regarding the training process.
  • last_activation – Last activation provided to the System.
Returns:

Output of the loss function/module.
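
The adapter's dispatch logic can be sketched as follows. This is a simplified stand-in written in plain Python, assuming the real class delegates to the wrapped loss roughly like this; the toy `mae` loss and the `'logits'` key are hypothetical.

```python
# Stand-in sketch of the GenericPointWiseLossWrapper adapter logic:
# optionally unpack the model's dict output, optionally apply the last
# activation, then call the wrapped pointwise loss on output and target.
class GenericPointWiseLossWrapper:

    def __init__(self, loss, model_output_key=None, batch_target_key='target',
                 perform_last_activation=False):
        self._loss = loss
        self._model_output_key = model_output_key
        self._batch_target_key = batch_target_key
        self._perform_last_activation = perform_last_activation

    def calculate_loss(self, output, batch, training_context, last_activation=None):
        if self._model_output_key is not None:
            output = output[self._model_output_key]
        if last_activation is not None and self._perform_last_activation:
            output = last_activation(output)
        return self._loss(output, batch[self._batch_target_key])


# Usage with a toy mean-absolute-error loss and a model that returns a dict.
mae = lambda out, tgt: sum(abs(o - t) for o, t in zip(out, tgt)) / len(tgt)
wrapper = GenericPointWiseLossWrapper(mae, model_output_key='logits')
loss = wrapper.calculate_loss({'logits': [0.5, 1.5]}, {'target': [1.0, 1.0]}, {})
print(loss)  # 0.5
```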

class pytorch_wrapper.loss_wrappers.TokenLabelingGenericPointWiseLossWrapper(loss, batch_input_sequence_length_idx, batch_input_key='input', model_output_key=None, batch_target_key='target', perform_last_activation=False, end_padded=True)

Bases: pytorch_wrapper.loss_wrappers.AbstractLossWrapper

Adapter that wraps a pointwise loss module. It is used in token labeling tasks in order to flatten the output and target while discarding invalid values introduced by padding.

Parameters:
  • loss – Loss module.
  • batch_input_sequence_length_idx – The index of the input list where the lengths of the sequences can be found.
  • batch_input_key – Key of the Dicts returned by the Dataloader objects that corresponds to the input of the model.
  • model_output_key – Key where the dict returned by the model contains the actual predictions. Leave None if the model returns only the predictions.
  • batch_target_key – Key where the dict (batch) contains the target values.
  • perform_last_activation – Whether to apply the last_activation to the output before computing the loss.
  • end_padded – Whether the sequences are end-padded.
calculate_loss(output, batch, training_context, last_activation=None)

Calculates the loss for a single batch.

Parameters:
  • batch – Dict that contains all information needed by the loss wrapper.
  • output – Output of the model.
  • training_context – Dict containing information regarding the training process.
  • last_activation – Last activation provided to the System.
Returns:

Output of the loss function/module.
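
The flattening step this wrapper performs can be sketched in plain Python. This is a stand-in, not the library implementation: per-sequence predictions and targets are flattened, keeping only the valid (non-padded) tokens of each sequence before the pointwise loss is applied. The function name and padding targets below are hypothetical.

```python
# Stand-in sketch of the token-labeling flattening step: keep only the
# first `length` tokens of each end-padded sequence (or the last `length`
# tokens when start-padded), then concatenate across the batch.
def flatten_valid_tokens(output, targets, lengths, end_padded=True):
    flat_out, flat_tgt = [], []
    for seq_out, seq_tgt, length in zip(output, targets, lengths):
        if end_padded:
            valid_out, valid_tgt = seq_out[:length], seq_tgt[:length]
        else:  # start-padded: valid tokens sit at the end of the sequence
            valid_out, valid_tgt = seq_out[-length:], seq_tgt[-length:]
        flat_out.extend(valid_out)
        flat_tgt.extend(valid_tgt)
    return flat_out, flat_tgt


# Two end-padded sequences of true lengths 2 and 1; padded positions are dropped.
out = [[0.9, 0.1, 0.0], [0.4, 0.0, 0.0]]
tgt = [[1, 0, -1], [0, -1, -1]]
flat_out, flat_tgt = flatten_valid_tokens(out, tgt, lengths=[2, 1])
print(flat_out)  # [0.9, 0.1, 0.4]
print(flat_tgt)  # [1, 0, 0]
```

The flattened pairs can then be fed to any pointwise loss, exactly as in the non-sequence case.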