timm_inference
Hugging Face TIMM inference algorithm.
Adapted from: https://github.com/huggingface/api-inference-community/
Classes
TIMMInference
class TIMMInference(datastructure: DataStructure, model_id: str, num_classes: Optional[int] = None, batch_transformations: Optional[list[dict[str, _JSONDict]]] = None, batch_size: int = 1, checkpoint_path: Optional[Union[os.PathLike, str]] = None, class_outputs: Optional[list[str]] = None):
Hugging Face TIMM inference algorithm.
Arguments
- `**kwargs`: Additional keyword arguments.
- `checkpoint_path`: The path to a checkpoint file local to the Pod. Defaults to None.
- `class_outputs`: A list of explicit class outputs to use as labels. Defaults to None.
- `datastructure`: The data structure to use for the algorithm.
- `model_id`: The model id to use from the Hugging Face Hub.
- `num_classes`: The number of classes in the model. Defaults to None.
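As an illustration of how `class_outputs` and `num_classes` could relate to a model's raw output, the following minimal sketch maps per-class scores to labels. The helper `label_predictions` and the sample scores are hypothetical and are not part of the library; the actual worker-side behaviour may differ.

```python
from typing import Optional, Sequence


def label_predictions(
    scores: Sequence[Sequence[float]],
    class_outputs: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
) -> list[str]:
    """Map each row of per-class scores to a label (hypothetical helper)."""
    labels: list[str] = []
    for row in scores:
        # Assumed consistency checks between the arguments and the scores.
        if num_classes is not None and len(row) != num_classes:
            raise ValueError(f"expected {num_classes} scores, got {len(row)}")
        if class_outputs is not None and len(class_outputs) != len(row):
            raise ValueError("class_outputs must have one entry per class")
        # Index of the highest score in this row.
        best = max(range(len(row)), key=row.__getitem__)
        # Use the explicit label if given, otherwise fall back to the index.
        labels.append(class_outputs[best] if class_outputs else str(best))
    return labels


batch = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
named = label_predictions(batch, class_outputs=["cat", "dog", "bird"], num_classes=3)
indexed = label_predictions(batch)
print(named)    # ['dog', 'cat']
print(indexed)  # ['1', '0']
```

In this sketch, omitting `class_outputs` leaves the predictions identified only by class index, which is one plausible reading of the `None` default.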
Attributes
- `checkpoint_path`: The path to a checkpoint file local to the Pod. Defaults to None.
- `class_name`: The name of the algorithm class.
- `class_outputs`: A list of explicit class outputs to use as labels. Defaults to None.
- `fields_dict`: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. `fields_dict = {"class_name": fields.Str()}`).
- `model_id`: The model id to use from the Hugging Face Hub.
- `nested_fields`: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. `nested_fields = {"datastructure": datastructure.registry}`).
- `num_classes`: The number of classes in the model. Defaults to None.
Ancestors
- BaseNonModelAlgorithmFactory
- BaseAlgorithmFactory
- abc.ABC
- bitfount.federated.roles._RolesMixIn
- bitfount.types._BaseSerializableObjectMixIn
Variables
- static fields_dict : ClassVar[dict[str, marshmallow.fields.Field]]
Methods
create
def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
Create an instance representing the role specified.
modeller
def modeller(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.base._HFModellerSide:
Returns the modeller side of the TIMMInference algorithm.
worker
def worker(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_inference._WorkerSide:
Returns the worker side of the TIMMInference algorithm.
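The relationship between `create`, `modeller` and `worker` can be pictured as a role-based dispatch. The sketch below is a self-contained stand-in, not the library's implementation: the `Role` enum values, the `SketchAlgorithmFactory` class, the placeholder model id and the dictionary return values are all assumptions made for illustration.

```python
from enum import Enum
from typing import Any, Union


class Role(Enum):
    """Stand-in for the library's Role type; the values are assumed."""

    MODELLER = "modeller"
    WORKER = "worker"


class SketchAlgorithmFactory:
    """Hypothetical factory mirroring the create/modeller/worker layout."""

    def __init__(self, model_id: str) -> None:
        self.model_id = model_id

    def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
        # Normalise strings to Role, then call the method named after the role.
        role = Role(role) if isinstance(role, str) else role
        return getattr(self, role.value)(**kwargs)

    def modeller(self, **kwargs: Any) -> dict:
        # A real factory would return a modeller-side object here.
        return {"side": "modeller", "model_id": self.model_id, **kwargs}

    def worker(self, **kwargs: Any) -> dict:
        # A real factory would return a worker-side object here.
        return {"side": "worker", "model_id": self.model_id, **kwargs}


factory = SketchAlgorithmFactory(model_id="example/model-id")  # placeholder id
modeller_side = factory.create("modeller")
worker_side = factory.create(Role.WORKER, batch_size=4)
print(modeller_side["side"], worker_side["side"], worker_side["batch_size"])
```

Under this reading, `create` lets calling code stay role-agnostic, while `modeller` and `worker` remain available for callers that already know which side they need.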