nn
Neural Network classes and helper functions for PyTorch.
Module
Functions
get_torchvision_classification_model
def get_torchvision_classification_model(model_name: str, pretrained: bool, num_classes: int, **kwargs: Any) -> torch.nn.modules.module.Module:
Returns a pre-existing torchvision model.
This function returns the torchvision classification model corresponding to
model_name. Importantly, it resizes the final layer so that its output size
matches the number of classes in the task. Because the final layer is named
and structured differently in each architecture, the resizing is hard-coded
per model.
Adapted from the PyTorch docs/tutorials.
Arguments
model_name: The name of the torchvision model to return.
pretrained: Whether to use a pretrained model.
num_classes: The number of classes to classify.
**kwargs: Additional arguments to pass to the torchvision model.
Returns
The torchvision model.
Raises
ValueError: If the model name is not recognised.
ValueError: If the model reshaping is not implemented yet.