timm_fine_tuning
Hugging Face TIMM fine-tuning algorithm.
Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)
Classes
TIMMFineTuning
class TIMMFineTuning(datastructure: DataStructure, model_id: str, labels: Optional[list[str]] = None, args: Optional[TIMMTrainingConfig] = None, batch_transformations: Optional[Union[list[Union[str, _JSONDict]], dict[_TimmBatchTransformationStep, list[Union[str, _JSONDict]]]]] = None, return_weights: bool = False, save_path: Optional[Union[str, os.PathLike]] = None):
Hugging Face TIMM fine-tuning algorithm.
Arguments
- **kwargs: Additional keyword arguments passed to the Worker side.
- args: The training configuration.
- batch_transformations: The batch transformations to apply to the batches. This can be a list of strings or dictionaries, applied to both training and validation, or a dictionary with the keys "train" and "validation", each mapped to a list of strings or dictionaries that specifies the transformations for that individual step. They are only applied if datastructure is not passed. Defaults to applying DEFAULT_IMAGE_TRANSFORMATIONS to both training and validation.
- datastructure: The datastructure relating to the dataset to be trained on. Defaults to None.
- labels: The labels of the target column. Defaults to None.
- model_id: The Hugging Face model ID.
- return_weights: Whether to return the weights of the model.
- save_path: The path to save the model to.
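The two accepted shapes of batch_transformations (one shared list, or a per-step dictionary) can be illustrated with a small normalizer that expands either form into one list per step. This is a minimal sketch under stated assumptions, not the library's code: normalize_batch_transformations is a hypothetical helper, DEFAULT_IMAGE_TRANSFORMATIONS is a placeholder because its real contents are not listed on this page, and giving a step that is omitted from the dictionary an empty list is an assumption.

```python
from typing import Optional, Union

# Hypothetical type aliases for this sketch: a transformation is a name or a JSON-like dict.
TransformSpec = Union[str, dict]
Steps = dict[str, list[TransformSpec]]

# Placeholder only: the real DEFAULT_IMAGE_TRANSFORMATIONS contents are not shown here.
DEFAULT_IMAGE_TRANSFORMATIONS: list[TransformSpec] = ["<default image transformations>"]

VALID_STEPS = ("train", "validation")


def normalize_batch_transformations(
    batch_transformations: Optional[Union[list[TransformSpec], Steps]] = None,
) -> Steps:
    """Expand the accepted argument shapes into one list per step (hypothetical helper)."""
    if batch_transformations is None:
        # No argument: the defaults apply to both training and validation.
        return {step: list(DEFAULT_IMAGE_TRANSFORMATIONS) for step in VALID_STEPS}
    if isinstance(batch_transformations, list):
        # A plain list is shared by both steps (each step gets its own copy).
        return {step: list(batch_transformations) for step in VALID_STEPS}
    if isinstance(batch_transformations, dict):
        unknown = set(batch_transformations) - set(VALID_STEPS)
        if unknown:
            raise ValueError(f"Unknown step(s): {sorted(unknown)}")
        # Assumption: a step missing from the dict gets no transformations.
        return {step: list(batch_transformations.get(step, [])) for step in VALID_STEPS}
    raise TypeError("batch_transformations must be a list, a dict, or None")
```

For example, under this sketch a dictionary that only has a "train" key (with hypothetical transformation names such as "HypotheticalFlip") yields an empty "validation" list, while a plain list is copied into both steps. Copying per step keeps a later change to one step's list from silently affecting the other.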
Attributes
- class_name: The name of the algorithm class.
- fields_dict: A dictionary mapping every attribute that will be serialized in the class to its marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
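The registry lookup that nested_fields describes can be sketched with the standard library alone: on serialization a nested object records its class_name, and on deserialization that name is looked up in the registry for the attribute to rebuild the right class. Everything below is a hypothetical stand-in (DataStructureStub, AlgorithmStub, serialize, deserialize); the per-field marshmallow validation that fields_dict provides is deliberately omitted.

```python
from typing import Any


class DataStructureStub:
    # Hypothetical stand-in for a datastructure class.
    def __init__(self, target: str) -> None:
        self.target = target


# Hypothetical registry: class names mapped to the respective classes.
datastructure_registry: dict[str, type] = {"DataStructureStub": DataStructureStub}


class AlgorithmStub:
    # Maps each nested attribute to the registry used to rebuild it.
    nested_fields: dict[str, dict[str, type]] = {"datastructure": datastructure_registry}

    def __init__(self, datastructure: Any, model_id: str) -> None:
        self.datastructure = datastructure
        self.model_id = model_id


def serialize(obj: Any) -> dict[str, Any]:
    """Dump instance attributes, recursing into attributes listed in nested_fields."""
    nested = getattr(type(obj), "nested_fields", {})
    out: dict[str, Any] = {"class_name": type(obj).__name__}
    for name, value in vars(obj).items():
        out[name] = serialize(value) if name in nested else value
    return out


def deserialize(data: dict[str, Any], cls: type) -> Any:
    """Rebuild cls, resolving nested objects through their registries by class_name."""
    nested = getattr(cls, "nested_fields", {})
    kwargs: dict[str, Any] = {}
    for name, value in data.items():
        if name == "class_name":
            continue
        if name in nested:
            nested_cls = nested[name][value["class_name"]]
            kwargs[name] = deserialize(value, nested_cls)
        else:
            kwargs[name] = value
    return cls(**kwargs)
```

Storing the class name rather than the class itself keeps the serialized form plain data, and the registry limits deserialization to known classes instead of importing arbitrary names.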
Ancestors
- BaseNonModelAlgorithmFactory
- BaseAlgorithmFactory
- abc.ABC
- bitfount.federated.roles._RolesMixIn
- bitfount.types._BaseSerializableObjectMixIn
Variables
- static fields_dict: ClassVar[dict[str, marshmallow.fields.Field]]
Methods
create
def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
Create an instance representing the role specified.
modeller
def modeller(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.base._HFModellerSide:
Returns the modeller side of the TIMMFineTuning algorithm.
worker
def worker(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_fine_tuning._WorkerSide:
Returns the worker side of the TIMMFineTuning algorithm.
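The relationship between create, modeller and worker can be pictured as a role dispatch: create accepts either a role value or its string name and forwards its keyword arguments to the matching side. This is a minimal sketch, assuming a Role enum with "modeller" and "worker" values; Role, ModellerSideStub, WorkerSideStub and RoleDispatchSketch are hypothetical stand-ins, not the library's classes.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union


class Role(Enum):
    # Assumed member names and values for this sketch.
    MODELLER = "modeller"
    WORKER = "worker"


@dataclass
class ModellerSideStub:
    # Stand-in for the object returned by modeller().
    kwargs: dict[str, Any] = field(default_factory=dict)


@dataclass
class WorkerSideStub:
    # Stand-in for the object returned by worker().
    kwargs: dict[str, Any] = field(default_factory=dict)


class RoleDispatchSketch:
    def modeller(self, **kwargs: Any) -> ModellerSideStub:
        return ModellerSideStub(kwargs)

    def worker(self, **kwargs: Any) -> WorkerSideStub:
        return WorkerSideStub(kwargs)

    def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
        # Accept either an enum member or its string value, case-insensitively.
        if isinstance(role, str):
            role = Role(role.lower())  # raises ValueError for unknown names
        if role is Role.MODELLER:
            return self.modeller(**kwargs)
        if role is Role.WORKER:
            return self.worker(**kwargs)
        raise ValueError(f"Unsupported role: {role!r}")
```

In this sketch, create("worker", ...) and create(Role.WORKER, ...) return the same kind of object, so callers that only know the role as a string (for example from a config file) and callers holding the enum take one code path.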