base_models
Base models and helper classes using PyTorch as the backend.
Classes
PyTorchClassifierMixIn
class PyTorchClassifierMixIn(multilabel: bool = False, param_clipping: Optional[dict[str, int]] = None):
MixIn for PyTorch classification problems.
PyTorch classification models must have this class in their inheritance hierarchy.
Arguments
multilabel
: Whether the problem is a multi-label problem, i.e. each datapoint can belong to multiple classes.
param_clipping
: Arguments for clipping BatchNorm parameters, used for federated models with secure aggregation. It should be a dictionary containing the SecureShare variables and the number of workers, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}.
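The mixin pattern described above can be illustrated with a minimal, library-free sketch. `ClassifierMixInSketch`, `BaseModelSketch`, `MyClassifier` and the `epochs` argument are hypothetical stand-ins, not Bitfount classes. The sketch only shows how a mixin that accepts `multilabel` and `param_clipping` can sit in a model's inheritance hierarchy and forward the remaining keyword arguments with `super()`; it does not reproduce how Bitfount actually applies `param_clipping`.

```python
from typing import Any, Optional


class ClassifierMixInSketch:
    """Hypothetical stand-in for a classifier mixin (not Bitfount's code)."""

    def __init__(
        self,
        multilabel: bool = False,
        param_clipping: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        # Forward any remaining keyword arguments along the MRO so that
        # the other base classes are initialised too.
        super().__init__(**kwargs)
        self.multilabel = multilabel
        self.param_clipping = param_clipping
        # Left unset here; a real classifier fills this in from the schema.
        self.n_classes: Optional[int] = None


class BaseModelSketch:
    """Hypothetical stand-in for a base model class."""

    def __init__(self, epochs: int = 1, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.epochs = epochs


# Hypothetical model: the mixin is listed first so its __init__ runs first
# and passes the leftover arguments (here `epochs`) on to BaseModelSketch.
class MyClassifier(ClassifierMixInSketch, BaseModelSketch):
    pass


model = MyClassifier(
    multilabel=True,
    param_clipping={"prime_q": 13, "precision": 10**3, "num_workers": 2},
    epochs=5,
)
```

Because every `__init__` in the chain accepts `**kwargs` and calls `super().__init__`, each class only consumes the arguments it knows about, which is what lets a mixin be combined with different base model classes.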
Attributes
fields_dict
: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
multilabel
: Whether the problem is a multi-label problem.
n_classes
: Number of classes in the problem.
nested_fields
: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
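The registry idea behind `nested_fields` can be sketched in plain Python. Everything here is hypothetical: the two datastructure classes, `datastructure_registry`, the `load_nested` helper, and the assumption that a serialized nested object stores its class name under a `"class_name"` key. Bitfount's real (de)serialization goes through marshmallow and may work differently.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TabularDataStructure:  # hypothetical example class
    table: str


@dataclass
class ImageDataStructure:  # hypothetical example class
    image_cols: list


# A registry maps class names to the classes themselves.
datastructure_registry: dict[str, type] = {
    cls.__name__: cls for cls in (TabularDataStructure, ImageDataStructure)
}

# nested_fields maps a nested attribute name to the registry used to rebuild it.
nested_fields = {"datastructure": datastructure_registry}


def load_nested(attr: str, serialized: dict[str, Any]) -> Any:
    """Hypothetical helper: rebuild a nested attribute from {"class_name": ..., **kwargs}."""
    data = dict(serialized)  # copy so the caller's dict is left untouched
    cls = nested_fields[attr][data.pop("class_name")]
    return cls(**data)


ds = load_nested(
    "datastructure", {"class_name": "TabularDataStructure", "table": "patients"}
)
```

Storing the class name alongside the data is what allows the registry lookup to pick the right concrete class when several subclasses can appear under the same attribute.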
Ancestors
- ClassifierMixIn
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
Variables
- static datastructure : DataStructure
  Set in _BaseModel.
- static schema : BitfountSchema
  Set in _BaseModel.
Methods
set_number_of_classes
def set_number_of_classes(self, schema: TableSchema) -> None:
Inherited from: ClassifierMixIn.set_number_of_classes
Sets the target number of classes for the classifier.
If the data is a multi-label problem, the number of classes is set to the number of target columns as specified in the DataStructure. Otherwise, the number of classes is set to the number of unique values in the target column as specified in the BitfountSchema. The value is stored in the n_classes attribute.
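The logic of `set_number_of_classes` can be sketched as follows. `DataStructureSketch`, `TableSchemaSketch` and `ClassifierSketch` are hypothetical, simplified stand-ins for `DataStructure`, the schema and the mixin; in particular, representing the schema as a mapping from column name to its set of distinct values is an assumption made for illustration, not Bitfount's schema format.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataStructureSketch:  # hypothetical: only the target column(s)
    target: list[str]


@dataclass
class TableSchemaSketch:  # hypothetical: column name -> distinct values seen
    categorical_values: dict[str, set] = field(default_factory=dict)


class ClassifierSketch:  # hypothetical stand-in for the classifier mixin
    def __init__(self, datastructure: DataStructureSketch, multilabel: bool = False) -> None:
        self.datastructure = datastructure
        self.multilabel = multilabel
        self.n_classes: Optional[int] = None

    def set_number_of_classes(self, schema: TableSchemaSketch) -> None:
        if self.multilabel:
            # Multi-label: one class per target column.
            self.n_classes = len(self.datastructure.target)
        else:
            # Single target column (unpacking raises ValueError otherwise):
            # count its distinct values in the schema.
            (target_col,) = self.datastructure.target
            self.n_classes = len(schema.categorical_values[target_col])


schema = TableSchemaSketch(
    {"diagnosis": {"a", "b", "c"}, "has_x": {0, 1}, "has_y": {0, 1}}
)

multiclass = ClassifierSketch(DataStructureSketch(["diagnosis"]))
multiclass.set_number_of_classes(schema)  # 3 distinct diagnosis values

multilabel_model = ClassifierSketch(
    DataStructureSketch(["has_x", "has_y"]), multilabel=True
)
multilabel_model.set_number_of_classes(schema)  # 2 target columns
```

Note that in the multi-label case the binary columns `has_x` and `has_y` each contribute one class, regardless of how many distinct values they hold.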