loss
Implements loss functions for PyTorch modules.
Module
Functions
soft_dice_loss
def soft_dice_loss(pred: np.ndarray, targets: np.ndarray, square_nom: bool = False, square_denom: bool = False, weight: Optional[Union[Sequence, torch.Tensor]] = None, smooth: float = 1.0) -> torch.Tensor:
Functional implementation of the SoftDiceLoss.
Arguments
pred
: A numpy array of predictions.
targets
: A numpy array of targets.
square_nom
: Whether to square the numerator.
square_denom
: Whether to square the denominator.
weight
: Additional weighting of individual classes.
smooth
: Smoothing term added to the numerator and denominator.
Returns
A torch tensor with the computed dice loss.
Classes
SoftDiceLoss
class SoftDiceLoss(square_nom: bool = False, square_denom: bool = False, weight: Optional[Union[Sequence, torch.Tensor]] = None, smooth: float = 1.0):
Soft Dice Loss.
The soft dice loss is computed as a fraction of numerator over denominator, where the numerator is 2 times the area of overlap between targets and predictions plus a smoothing factor, and the denominator is the total number of pixels in both images plus the smoothing factor.
If weights are provided, the fraction is multiplied by the provided weight for each class.
If square_nom or square_denom is set, the respective numerator or denominator is raised to the power of 2.
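The computation described above can be sketched roughly as follows. This is a minimal illustration and not the library's implementation: the function name `dice_fraction_sketch` is hypothetical, and it assumes predictions of shape (batch, classes, *spatial), targets already one-hot encoded in the same shape, a reduction over the spatial dimensions, and squaring applied element-wise before summation. The description does not say how the fraction is turned into the final loss value, so the sketch stops at the per-class fraction.

```python
from typing import Optional, Sequence, Union

import torch


def dice_fraction_sketch(
    pred: torch.Tensor,
    target_onehot: torch.Tensor,
    square_nom: bool = False,
    square_denom: bool = False,
    weight: Optional[Union[Sequence[float], torch.Tensor]] = None,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Hypothetical sketch: per-class dice fraction of shape (batch, classes)."""
    target_onehot = target_onehot.to(pred.dtype)
    # Assumption: reduce over every dimension after (batch, classes).
    dims = tuple(range(2, pred.dim()))

    # Numerator: 2 * overlap + smooth.
    overlap = pred * target_onehot
    if square_nom:
        # Assumption: squaring is element-wise, applied before summing.
        overlap = overlap ** 2
    nom = 2 * overlap.sum(dim=dims) + smooth

    # Denominator: total "pixel mass" of both inputs + smooth.
    if square_denom:
        denom = (pred ** 2).sum(dim=dims) + (target_onehot ** 2).sum(dim=dims) + smooth
    else:
        denom = pred.sum(dim=dims) + target_onehot.sum(dim=dims) + smooth

    frac = nom / denom

    # Optional per-class weighting of the fraction.
    if weight is not None:
        frac = frac * torch.as_tensor(weight, dtype=frac.dtype, device=frac.device)
    return frac
```

Under these assumptions, a prediction that matches the one-hot target exactly yields a fraction of 1 for every class, because numerator and denominator both equal twice the foreground area plus `smooth`.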
Arguments
square_nom
: Whether to square the numerator. Optional.
square_denom
: Whether to square the denominator. Optional.
weight
: Additional weighting of individual classes. Optional.
smooth
: Smoothing term for the numerator and denominator. Optional. Defaults to 1.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
Methods
forward
def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
Computes Soft Dice Loss.
Arguments
predictions
: The predictions obtained by the network.
targets
: The targets (ground truth) for the predictions.
Returns
torch.Tensor: The computed loss value.
Raises
ValueError
: If the predictions tensor has fewer than 3 dimensions.
ValueError
: If the targets tensor has fewer than 2 dimensions.
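The two documented error conditions could be checked along these lines. This is only a sketch: `check_dice_inputs` is a hypothetical helper, not part of the module, and the error messages are invented for illustration.

```python
import torch


def check_dice_inputs(predictions: torch.Tensor, targets: torch.Tensor) -> None:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Predictions must have at least 3 dimensions.
    if predictions.dim() < 3:
        raise ValueError(
            f"predictions need at least 3 dimensions, got {predictions.dim()}"
        )
    # Targets must have at least 2 dimensions.
    if targets.dim() < 2:
        raise ValueError(
            f"targets need at least 2 dimensions, got {targets.dim()}"
        )
```

A plausible reading of these limits is that predictions carry batch, class, and at least one spatial dimension, while targets carry batch and at least one spatial dimension, but the documentation does not state the expected layout.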