Skip to main content

utils

Utility functions for training and validating Hugging Face TIMM (pytorch-image-models) models.

Module

Functions

get_device_for_model

def get_device_for_model():

train_one_epoch

def train_one_epoch(
    epoch: int,
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    args: TIMMTrainingConfig,
    device: torch.device,
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    saver: Optional[utils.CheckpointSaver] = None,
    output_dir: Optional[str] = None,
    amp_autocast: AbstractContextManager = contextlib.suppress,
    loss_scaler: Optional[ApexScaler] = None,
    model_ema: Optional[utils.ModelEmaV2] = None,
    mixup_fn: Optional[Mixup] = None,
) -> dict:

Performs one epoch of training and returns a dict of training metrics, including the loss.

Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)

validate

def validate(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    args: TIMMTrainingConfig,
    device: torch.device,
    amp_autocast: AbstractContextManager = contextlib.suppress,
    log_suffix: str = '',
) -> dict:

Performs validation of the model and returns a dict of evaluation metrics.

Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)

Classes

TIMMTrainingConfig

class TIMMTrainingConfig(
    pretrained: bool = True,
    initial_checkpoint: str = '',
    num_classes: Optional[int] = None,
    gp: Optional[str] = None,
    img_size: Optional[int] = None,
    in_chans: Optional[int] = None,
    input_size: Optional[tuple[int, int, int]] = None,
    crop_pct: Optional[float] = None,
    mean: Optional[list[float]] = None,
    std: Optional[list[float]] = None,
    interpolation: str = '',
    batch_size: int = 16,
    validation_batch_size: Optional[int] = None,
    channels_last: bool = False,
    fuser: str = '',
    grad_accum_steps: int = 1,
    grad_checkpointing: bool = False,
    fast_norm: bool = False,
    model_kwargs: dict[str, Any] = {},
    head_init_scale: Optional[float] = None,
    head_init_bias: Optional[float] = None,
    torchscript: bool = False,
    torchcompile: Optional[str] = None,
    opt: str = 'sgd',
    opt_eps: Optional[float] = None,
    opt_betas: Optional[list[float]] = None,
    momentum: float = 0.9,
    weight_decay: float = 0.05,
    clip_grad: Optional[float] = None,
    clip_mode: str = 'norm',
    layer_decay: Optional[float] = 0.65,
    opt_kwargs: dict[str, Any] = {},
    sched: str = 'constant_with_warmup',
    sched_on_updates: bool = False,
    lr: Optional[float] = 1e-05,
    lr_base: float = 0.005,
    lr_base_size: int = 256,
    lr_base_scale: str = '',
    lr_noise: Optional[list[float]] = None,
    lr_noise_pct: float = 0.67,
    lr_noise_std: float = 1.0,
    lr_cycle_mul: float = 1.0,
    lr_cycle_decay: float = 0.5,
    lr_cycle_limit: int = 1,
    lr_k_decay: float = 1.0,
    warmup_lr: float = 1e-05,
    min_lr: float = 0,
    epochs: int = 300,
    epoch_repeats: float = 0.0,
    start_epoch: Optional[int] = None,
    decay_milestones: list[int] = [90, 180, 270],
    decay_epochs: float = 90,
    warmup_epochs: int = 5,
    warmup_prefix: bool = False,
    cooldown_epochs: int = 0,
    patience_epochs: int = 10,
    decay_rate: float = 1.0,
    aug_splits: int = 0,
    jsd_loss: bool = False,
    bce_loss: bool = False,
    bce_target_thresh: Optional[float] = None,
    resplit: bool = False,
    mixup: float = 0.0,
    cutmix: float = 0.0,
    cutmix_minmax: Optional[list[float]] = None,
    mixup_prob: float = 1.0,
    mixup_switch_prob: float = 0.5,
    mixup_mode: str = 'batch',
    mixup_off_epoch: int = 0,
    smoothing: float = 0.1,
    drop: float = 0.0,
    drop_connect: Optional[float] = None,
    drop_path: Optional[float] = 0.2,
    drop_block: Optional[float] = None,
    bn_momentum: Optional[float] = None,
    bn_eps: Optional[float] = None,
    sync_bn: bool = False,
    dist_bn: str = 'reduce',
    split_bn: bool = False,
    model_ema: bool = False,
    model_ema_force_cpu: bool = False,
    model_ema_decay: float = 0.9998,
    seed: int = 42,
    log_interval: int = 50,
    recovery_interval: int = 0,
    checkpoint_hist: int = 10,
    workers: int = 4,
    save_images: bool = False,
    amp: bool = False,
    amp_dtype: str = 'float16',
    amp_impl: str = 'native',
    no_ddp_bb: bool = False,
    synchronize_step: bool = False,
    no_prefetcher: bool = False,
    eval_metric: str = 'top1',
    tta: int = 0,
    local_rank: int = 0,
):

Configuration for training a TIMM model.

Variables

  • static amp : bool
  • static amp_dtype : str
  • static amp_impl : str
  • static aug_splits : int
  • static batch_size : int
  • static bce_loss : bool
  • static bce_target_thresh : Optional[float]
  • static bn_eps : Optional[float]
  • static bn_momentum : Optional[float]
  • static channels_last : bool
  • static checkpoint_hist : int
  • static clip_grad : Optional[float]
  • static clip_mode : str
  • static cooldown_epochs : int
  • static crop_pct : Optional[float]
  • static cutmix : float
  • static cutmix_minmax : Optional[list]
  • static decay_epochs : float
  • static decay_milestones : list
  • static decay_rate : float
  • static dist_bn : str
  • static drop : float
  • static drop_block : Optional[float]
  • static drop_connect : Optional[float]
  • static drop_path : Optional[float]
  • static epoch_repeats : float
  • static epochs : int
  • static eval_metric : str
  • static fast_norm : bool
  • static fuser : str
  • static gp : Optional[str]
  • static grad_accum_steps : int
  • static grad_checkpointing : bool
  • static head_init_bias : Optional[float]
  • static head_init_scale : Optional[float]
  • static img_size : Optional[int]
  • static in_chans : Optional[int]
  • static initial_checkpoint : str
  • static input_size : Optional[tuple]
  • static interpolation : str
  • static jsd_loss : bool
  • static layer_decay : Optional[float]
  • static local_rank : int
  • static log_interval : int
  • static lr : Optional[float]
  • static lr_base : float
  • static lr_base_scale : str
  • static lr_base_size : int
  • static lr_cycle_decay : float
  • static lr_cycle_limit : int
  • static lr_cycle_mul : float
  • static lr_k_decay : float
  • static lr_noise : Optional[list]
  • static lr_noise_pct : float
  • static lr_noise_std : float
  • static mean : Optional[list]
  • static min_lr : float
  • static mixup : float
  • static mixup_mode : str
  • static mixup_off_epoch : int
  • static mixup_prob : float
  • static mixup_switch_prob : float
  • static model_ema : bool
  • static model_ema_decay : float
  • static model_ema_force_cpu : bool
  • static model_kwargs : dict
  • static momentum : float
  • static no_ddp_bb : bool
  • static no_prefetcher : bool
  • static num_classes : Optional[int]
  • static opt : str
  • static opt_betas : Optional[list]
  • static opt_eps : Optional[float]
  • static opt_kwargs : dict
  • static patience_epochs : int
  • static pretrained : bool
  • static recovery_interval : int
  • static resplit : bool
  • static save_images : bool
  • static sched : str
  • static sched_on_updates : bool
  • static seed : int
  • static smoothing : float
  • static split_bn : bool
  • static start_epoch : Optional[int]
  • static std : Optional[list]
  • static sync_bn : bool
  • static synchronize_step : bool
  • static torchcompile : Optional[str]
  • static torchscript : bool
  • static tta : int
  • static validation_batch_size : Optional[int]
  • static warmup_epochs : int
  • static warmup_lr : float
  • static warmup_prefix : bool
  • static weight_decay : float
  • static workers : int