
utils

Contains PyTorch-specific utility methods.

Module

Functions

autodetect_gpu

def autodetect_gpu() -> dict:

Detects and returns GPU accelerator and device count.

Returns: A dictionary with the keys 'accelerator' and 'devices', which should be passed to the PyTorch Lightning Trainer.
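
As a rough illustration of the return shape described above, the sketch below builds a dictionary with the same two keys. The name autodetect_gpu_sketch is hypothetical and this is not the library's implementation: it assumes detection is done with torch.cuda.is_available() and torch.cuda.device_count(), and that a single CPU device is the fallback. The values the real function returns may differ.

```python
from typing import Any, Dict


def autodetect_gpu_sketch() -> Dict[str, Any]:
    """Hypothetical stand-in for autodetect_gpu(); not the library's code."""
    try:
        import torch
    except ImportError:
        # Assumption: without PyTorch installed, fall back to a single CPU device.
        return {"accelerator": "cpu", "devices": 1}

    if torch.cuda.is_available():
        # Use every CUDA device that PyTorch can see.
        return {"accelerator": "gpu", "devices": torch.cuda.device_count()}

    # Assumption: no CUDA device means one CPU device.
    return {"accelerator": "cpu", "devices": 1}


trainer_kwargs = autodetect_gpu_sketch()
# The dictionary can then be unpacked into the Trainer, for example:
#   trainer = Trainer(**trainer_kwargs)
```

Unpacking with ** keeps the call site unchanged whether the code runs on a GPU machine or a CPU-only one.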

enhanced_torch_load

def enhanced_torch_load(f: FILE_LIKE, map_location: MAP_LOCATION = None, pickle_module: Any = None, *, weights_only: bool = True, **pickle_load_args: Any) -> Any:

Calls torch.load() with sensible defaults; in particular, weights_only defaults to True.

See the docs of torch.load() for more information.
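
The sketch below shows the kind of call such a wrapper might make. The name load_checkpoint_sketch and the default map_location="cpu" are assumptions for illustration, not part of this module; only torch.load() and its map_location and weights_only parameters are taken from PyTorch itself.

```python
from typing import Any


def load_checkpoint_sketch(path: str, map_location: Any = "cpu") -> Any:
    """Hypothetical helper that mirrors the weights_only=True default."""
    # Imported lazily so that defining this sketch does not require PyTorch.
    import torch

    # weights_only=True restricts unpickling to tensors and other basic
    # container and primitive types, which is safer for untrusted files.
    return torch.load(path, map_location=map_location, weights_only=True)


# Example call (hypothetical file name):
#   state_dict = load_checkpoint_sketch("model.pt")
```

With weights_only=True, a checkpoint that contains arbitrary pickled Python objects is rejected rather than executed on load.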