
nn

Neural Network classes and helper functions for PyTorch.

Module

Functions

get_torchvision_classification_model

def get_torchvision_classification_model(
    model_name: str, pretrained: bool, num_classes: int, **kwargs: Any,
) -> torch.nn.modules.module.Module:

Returns a pre-existing torchvision model.

This function returns the torchvision classification model corresponding to model_name. Importantly, it reshapes the final layer so that its output size matches the number of classes in the task. Because the name and position of the final layer differ between architectures, this reshaping must be hard-coded for each model.

Adapted from the PyTorch documentation and tutorials.

Arguments

  • model_name: The name of the torchvision model to return.
  • pretrained: Whether to use a pretrained model.
  • num_classes: The number of classes to classify.
  • **kwargs: Additional keyword arguments to pass to the torchvision model.

Returns

  • The torchvision model.

Raises

  • ValueError: If the model name is not recognised.
  • ValueError: If the model reshaping is not implemented yet.
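The hard-coded, per-architecture reshaping and the two ValueError cases above can be illustrated with a minimal sketch. This is not the library's implementation: the function final_layer_path, the KNOWN_MODELS set and the FINAL_LAYER_BY_PREFIX table are hypothetical, and the layer locations are taken from the PyTorch fine-tuning tutorial. It only shows how an unknown name and a known name without a reshaping rule lead to the two different errors.

```python
# Hypothetical sketch, NOT the library's actual implementation.
# Names this sketch treats as "recognised" torchvision models.
KNOWN_MODELS = {
    "resnet18", "resnet50", "alexnet", "vgg11", "vgg16",
    "densenet121", "squeezenet1_0",
    "mobilenet_v2",  # recognised, but deliberately given no reshaping rule here
}

# Where the final classification layer lives for each model family
# (locations as in the PyTorch fine-tuning tutorial).
FINAL_LAYER_BY_PREFIX = {
    "resnet": "fc",                   # nn.Linear
    "alexnet": "classifier[6]",       # last nn.Linear in the Sequential
    "vgg": "classifier[6]",           # last nn.Linear in the Sequential
    "densenet": "classifier",         # nn.Linear
    "squeezenet": "classifier[1]",    # nn.Conv2d acting as the classifier
}


def final_layer_path(model_name: str) -> str:
    """Return the attribute path of the layer to reshape for model_name."""
    if model_name not in KNOWN_MODELS:
        raise ValueError(f"Model name not recognised: {model_name}")
    for prefix, path in FINAL_LAYER_BY_PREFIX.items():
        if model_name.startswith(prefix):
            return path
    # Known model, but no hard-coded rule for its architecture yet.
    raise ValueError(f"Model reshaping not implemented yet for: {model_name}")
```

For a ResNet, the layer located this way would then be replaced with one of the right size, for example `model.fc = torch.nn.Linear(model.fc.in_features, num_classes)`.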