timm_inference

Hugging Face TIMM inference algorithm.

Adapted from: https://github.com/huggingface/api-inference-community/

Classes

TIMMInference

class TIMMInference(
    model_id: str,
    image_column_name: str,
    num_classes: Optional[int] = None,
    batch_transformations: Optional[list[dict[str, _JSONDict]]] = None,
    batch_size: int = 1,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    class_outputs: Optional[list[str]] = None,
):

Hugging Face TIMM inference algorithm.

Arguments

  • checkpoint_path: The path to a checkpoint file local to the Pod. Defaults to None.
  • class_outputs: A list of explicit class outputs to use as labels. Defaults to None.
  • image_column_name: The name of the column containing the image paths.
  • model_id: The model id to use from the Hugging Face Hub.
  • num_classes: The number of classes in the model. Defaults to None.
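
As a rough illustration of how these arguments fit together, the sketch below collects them in a dictionary and passes them to the constructor. The import path is inferred from the module path shown for the worker side on this page; the model id and column name are illustrative placeholders, not values taken from this documentation. The import is guarded so the snippet still runs where the library is not installed.

```python
from typing import Any

# Constructor arguments as documented above. The concrete values are
# illustrative assumptions, not recommendations.
timm_kwargs: dict[str, Any] = {
    "model_id": "timm/resnet50.a1_in1k",  # assumed example of a Hub model id
    "image_column_name": "image_path",    # assumed name of the image-path column
    "batch_size": 1,                      # documented default
    "checkpoint_path": None,              # documented default
    "class_outputs": None,                # documented default
}

try:
    # Module path inferred from the `_WorkerSide` return type on this page.
    from bitfount.federated.algorithms.hugging_face_algorithms.timm_inference import (
        TIMMInference,
    )

    algorithm = TIMMInference(**timm_kwargs)
except ImportError:
    # Library not available in this environment; nothing is constructed.
    algorithm = None
```

Keeping the arguments in a plain dictionary makes it easy to log or serialize the configuration separately from the algorithm object.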

Attributes

  • checkpoint_path: The path to a checkpoint file local to the Pod. Defaults to None.
  • class_name: The name of the algorithm class.
  • class_outputs: A list of explicit class outputs to use as labels. Defaults to None.
  • fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
  • image_column_name: The name of the column containing the image paths.
  • model_id: The model id to use from the Hugging Face Hub.
  • nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes. (e.g. nested_fields = {"datastructure": datastructure.registry})
  • num_classes: The number of classes in the model. Defaults to None.

Methods


create

def create(self, role: Union[str, Role], **kwargs: Any) -> Any:

Create an instance representing the role specified.

modeller

def modeller(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.base._HFModellerSide:

Returns the modeller side of the TIMMInference algorithm.

worker

def worker(self, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_inference._WorkerSide:

Returns the worker side of the TIMMInference algorithm.
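
The three methods above suggest a role-dispatch pattern: create takes a role and returns the matching side, while modeller and worker build each side directly. The sketch below shows one way such dispatch could work. Everything in it is a hypothetical stand-in: the Role enum, the SketchAlgorithm class, and the string return values are assumptions for illustration and are not the library's implementation.

```python
from enum import Enum
from typing import Any, Union


class Role(Enum):
    """Hypothetical stand-in for the library's Role type."""

    MODELLER = "modeller"
    WORKER = "worker"


class SketchAlgorithm:
    """Stand-in factory; not the real TIMMInference class."""

    def modeller(self, **kwargs: Any) -> str:
        # A real implementation would return a modeller-side object.
        return "modeller-side"

    def worker(self, **kwargs: Any) -> str:
        # A real implementation would return a worker-side object.
        return "worker-side"

    def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
        # Accept either the enum member or its string value.
        resolved = Role(role) if isinstance(role, str) else role
        # Dispatch to the method named after the role.
        return getattr(self, resolved.value)(**kwargs)


algo = SketchAlgorithm()
modeller_side = algo.create("modeller")
worker_side = algo.create(Role.WORKER)
```

Dispatching through a single create method lets calling code treat the role as data, for example when it is read from configuration, rather than branching on it by hand.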