utils
Utility functions for Hugging Face models.
Module
Functions
get_device_for_model
def get_device_for_model():
train_one_epoch
def train_one_epoch( epoch: int, model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, loss_fn: nn.Module, args: TIMMTrainingConfig, device: torch.device, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, saver: Optional[utils.CheckpointSaver] = None, output_dir: Optional[str] = None, amp_autocast: AbstractContextManager = contextlib.suppress, loss_scaler: Optional[ApexScaler] = None, model_ema: Optional[utils.ModelEmaV2] = None, mixup_fn: Optional[Mixup] = None,) ‑> dict[str, typing.Any]:
Performs one epoch of training and returns a dictionary containing the loss.
Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)
validate
def validate( model: nn.Module, loader: DataLoader, loss_fn: nn.Module, args: TIMMTrainingConfig, device: torch.device, amp_autocast: AbstractContextManager = contextlib.suppress, log_suffix: str = '',) ‑> dict[str, typing.Any]:
Performs validation of the model and returns a dictionary of metrics.
Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)
Classes
TIMMTrainingConfig
class TIMMTrainingConfig( pretrained: bool = True, initial_checkpoint: str = '', num_classes: Optional[int] = None, gp: Optional[str] = None, img_size: Optional[int] = None, in_chans: Optional[int] = None, input_size: Optional[tuple[int, int, int]] = None, crop_pct: Optional[float] = None, mean: Optional[list[float]] = None, std: Optional[list[float]] = None, interpolation: str = '', batch_size: int = 16, validation_batch_size: Optional[int] = None, channels_last: bool = False, fuser: str = '', grad_accum_steps: int = 1, grad_checkpointing: bool = False, fast_norm: bool = False, model_kwargs: dict[str, Any] = {}, head_init_scale: Optional[float] = None, head_init_bias: Optional[float] = None, torchscript: bool = False, torchcompile: Optional[str] = None, opt: str = 'sgd', opt_eps: Optional[float] = None, opt_betas: Optional[list[float]] = None, momentum: float = 0.9, weight_decay: float = 0.05, clip_grad: Optional[float] = None, clip_mode: str = 'norm', layer_decay: Optional[float] = 0.65, opt_kwargs: dict[str, Any] = {}, sched: str = 'constant_with_warmup', sched_on_updates: bool = False, lr: Optional[float] = 1e-05, lr_base: float = 0.005, lr_base_size: int = 256, lr_base_scale: str = '', lr_noise: Optional[list[float]] = None, lr_noise_pct: float = 0.67, lr_noise_std: float = 1.0, lr_cycle_mul: float = 1.0, lr_cycle_decay: float = 0.5, lr_cycle_limit: int = 1, lr_k_decay: float = 1.0, warmup_lr: float = 1e-05, min_lr: float = 0, epochs: int = 300, epoch_repeats: float = 0.0, start_epoch: Optional[int] = None, decay_milestones: list[int] = [90, 180, 270], decay_epochs: float = 90, warmup_epochs: int = 5, warmup_prefix: bool = False, cooldown_epochs: int = 0, patience_epochs: int = 10, decay_rate: float = 1.0, aug_splits: int = 0, jsd_loss: bool = False, bce_loss: bool = False, bce_target_thresh: Optional[float] = None, resplit: bool = False, mixup: float = 0.0, cutmix: float = 0.0, cutmix_minmax: Optional[list[float]] = None, mixup_prob: float = 1.0, 
mixup_switch_prob: float = 0.5, mixup_mode: str = 'batch', mixup_off_epoch: int = 0, smoothing: float = 0.1, drop: float = 0.0, drop_connect: Optional[float] = None, drop_path: Optional[float] = 0.2, drop_block: Optional[float] = None, bn_momentum: Optional[float] = None, bn_eps: Optional[float] = None, sync_bn: bool = False, dist_bn: str = 'reduce', split_bn: bool = False, model_ema: bool = False, model_ema_force_cpu: bool = False, model_ema_decay: float = 0.9998, seed: int = 42, log_interval: int = 50, recovery_interval: int = 0, checkpoint_hist: int = 10, workers: int = 4, save_images: bool = False, amp: bool = False, amp_dtype: str = 'float16', amp_impl: str = 'native', no_ddp_bb: bool = False, synchronize_step: bool = False, no_prefetcher: bool = False, eval_metric: str = 'top1', tta: int = 0, local_rank: int = 0,):
Configuration for training a TIMM model.
Variables
- static
amp : bool
- static
amp_dtype : str
- static
amp_impl : str
- static
aug_splits : int
- static
batch_size : int
- static
bce_loss : bool
- static
bce_target_thresh : Optional[float]
- static
bn_eps : Optional[float]
- static
bn_momentum : Optional[float]
- static
channels_last : bool
- static
checkpoint_hist : int
- static
clip_grad : Optional[float]
- static
clip_mode : str
- static
cooldown_epochs : int
- static
crop_pct : Optional[float]
- static
cutmix : float
- static
cutmix_minmax : Optional[list[float]]
- static
decay_epochs : float
- static
decay_milestones : list[int]
- static
decay_rate : float
- static
dist_bn : str
- static
drop : float
- static
drop_block : Optional[float]
- static
drop_connect : Optional[float]
- static
drop_path : Optional[float]
- static
epoch_repeats : float
- static
epochs : int
- static
eval_metric : str
- static
fast_norm : bool
- static
fuser : str
- static
gp : Optional[str]
- static
grad_accum_steps : int
- static
grad_checkpointing : bool
- static
head_init_bias : Optional[float]
- static
head_init_scale : Optional[float]
- static
img_size : Optional[int]
- static
in_chans : Optional[int]
- static
initial_checkpoint : str
- static
input_size : Optional[tuple[int, int, int]]
- static
interpolation : str
- static
jsd_loss : bool
- static
layer_decay : Optional[float]
- static
local_rank : int
- static
log_interval : int
- static
lr : Optional[float]
- static
lr_base : float
- static
lr_base_scale : str
- static
lr_base_size : int
- static
lr_cycle_decay : float
- static
lr_cycle_limit : int
- static
lr_cycle_mul : float
- static
lr_k_decay : float
- static
lr_noise : Optional[list[float]]
- static
lr_noise_pct : float
- static
lr_noise_std : float
- static
mean : Optional[list[float]]
- static
min_lr : float
- static
mixup : float
- static
mixup_mode : str
- static
mixup_off_epoch : int
- static
mixup_prob : float
- static
mixup_switch_prob : float
- static
model_ema : bool
- static
model_ema_decay : float
- static
model_ema_force_cpu : bool
- static
model_kwargs : dict[str, typing.Any]
- static
momentum : float
- static
no_ddp_bb : bool
- static
no_prefetcher : bool
- static
num_classes : Optional[int]
- static
opt : str
- static
opt_betas : Optional[list[float]]
- static
opt_eps : Optional[float]
- static
opt_kwargs : dict[str, typing.Any]
- static
patience_epochs : int
- static
pretrained : bool
- static
recovery_interval : int
- static
resplit : bool
- static
save_images : bool
- static
sched : str
- static
sched_on_updates : bool
- static
seed : int
- static
smoothing : float
- static
split_bn : bool
- static
start_epoch : Optional[int]
- static
std : Optional[list[float]]
- static
sync_bn : bool
- static
synchronize_step : bool
- static
torchcompile : Optional[str]
- static
torchscript : bool
- static
tta : int
- static
validation_batch_size : Optional[int]
- static
warmup_epochs : int
- static
warmup_lr : float
- static
warmup_prefix : bool
- static
weight_decay : float
- static
workers : int