base_models
Defines abstract models, mixins, and other common backend-agnostic classes.
Implementations of these abstract models should be located in bitfount.models.models or in the models subpackage of a backend.
Classes
ClassifierMixIn
class ClassifierMixIn(multilabel: bool = False, param_clipping: Optional[dict[str, int]] = None):
MixIn for classification problems.
Classification models must have this class in their inheritance hierarchy.
Arguments
multilabel
: Whether the problem is a multi-label problem, i.e., each datapoint can belong to multiple classes.
param_clipping
: Arguments for clipping BatchNorm parameters. Used for federated models with secure aggregation. It should contain the SecureShare variables and the number of workers in a dictionary, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}.
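A param_clipping dictionary of the documented shape can be checked before it is passed in. The following is a minimal sketch: validate_param_clipping and EXPECTED_CLIPPING_KEYS are hypothetical names, treating the three keys from the example as required is an assumption, and treating None as "no clipping" is inferred from the default rather than confirmed by the source.

```python
from typing import Optional

# Keys taken from the documented example; treating them as required is an assumption.
EXPECTED_CLIPPING_KEYS = {"prime_q", "precision", "num_workers"}


def validate_param_clipping(param_clipping: Optional[dict[str, int]]) -> None:
    """Hypothetical check that a param_clipping dict has the documented shape."""
    if param_clipping is None:
        # Assumed: None (the default) means no BatchNorm parameter clipping.
        return
    missing = EXPECTED_CLIPPING_KEYS - set(param_clipping)
    if missing:
        raise ValueError(f"param_clipping is missing keys: {sorted(missing)}")
    for key, value in param_clipping.items():
        # The signature types the values as int.
        if not isinstance(value, int):
            raise TypeError(
                f"param_clipping[{key!r}] must be an int, got {type(value).__name__}"
            )


param_clipping = {"prime_q": 13, "precision": 10**3, "num_workers": 2}
validate_param_clipping(param_clipping)  # passes silently
```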
Attributes
multilabel
: Whether the problem is a multi-label problem.
n_classes
: Number of classes in the problem.
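The requirement that classification models have this mixin in their inheritance hierarchy can be illustrated with plain Python. This is a sketch under stated assumptions: ClassifierMixInSketch, BaseModelSketch, MyClassifier and the epochs argument are hypothetical stand-ins rather than Bitfount classes, and the cooperative super().__init__(**kwargs) chain is one common way to combine a mixin with a model base class, not necessarily how the library does it.

```python
from typing import Any, Optional


class ClassifierMixInSketch:
    """Hypothetical stand-in mirroring the documented ClassifierMixIn arguments and attributes."""

    def __init__(
        self,
        multilabel: bool = False,
        param_clipping: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        # Pass remaining keyword arguments along the MRO so other bases initialise too.
        super().__init__(**kwargs)
        self.multilabel = multilabel
        self.param_clipping = param_clipping
        # In this sketch, n_classes stays unset until it is derived from the data.
        self.n_classes: Optional[int] = None


class BaseModelSketch:
    """Hypothetical backend model base class."""

    def __init__(self, epochs: int = 1, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.epochs = epochs


class MyClassifier(ClassifierMixInSketch, BaseModelSketch):
    """A classifier: the mixin is part of its inheritance hierarchy."""


model = MyClassifier(multilabel=True, epochs=5)
print(isinstance(model, ClassifierMixInSketch), model.multilabel, model.epochs)
# prints: True True 5
```

Listing the mixin first keeps its __init__ ahead of the model base in the method resolution order, so each class consumes its own keyword arguments and forwards the rest.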
Ancestors
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
Subclasses
Variables
- datastructure : DataStructure (static; set in _BaseModel)
- fields_dict : ClassVar[dict[str, marshmallow.fields.Field]] (static)
- nested_fields : ClassVar[dict[str, collections.abc.Mapping[str, Any]]] (static)
- schema : BitfountSchema (static; set in _BaseModel)
Methods
set_number_of_classes
def set_number_of_classes(self, schema: BitfountSchema) ‑> None:
Sets the target number of classes for the classifier.
If the data is a multi-label problem, the number of classes is set to the number of target columns as specified in the DataStructure. Otherwise, the number of classes is set to the number of unique values in the target column as specified in the BitfountSchema. The value is stored in the n_classes attribute.
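The rule above can be restated as a small pure function. This is a hypothetical sketch, not the method's implementation: number_of_classes is an illustrative name, and a plain target column name (or list of names) plus a dict of column name to distinct values stand in for the DataStructure and BitfountSchema objects.

```python
from typing import Union


def number_of_classes(
    multilabel: bool,
    target: Union[str, list[str]],
    unique_values: dict[str, set],
) -> int:
    """Hypothetical restatement of the documented rule for n_classes.

    target: target column(s), standing in for the DataStructure's target.
    unique_values: column name -> distinct values, standing in for the BitfountSchema.
    """
    targets = [target] if isinstance(target, str) else list(target)
    if multilabel:
        # Multi-label: one class per target column.
        return len(targets)
    # Single-label: the number of distinct values in the target column.
    return len(unique_values[targets[0]])


schema_values = {"label": {"cat", "dog", "bird"}, "has_cat": {0, 1}, "has_dog": {0, 1}}
print(number_of_classes(False, "label", schema_values))  # 3
print(number_of_classes(True, ["has_cat", "has_dog"], schema_values))  # 2
```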
LoggerConfig
class LoggerConfig(name: str, save_dir: Optional[Path] = PosixPath('bitfount_logs'), params: Optional[_StrAnyDict] = {}):
Configuration for the logger.
The configured logger will log training events, metrics, model checkpoints, etc. to your chosen platform. If no logger configuration is provided, the default logger is a TensorBoard logger.
Arguments
name
: The name of the logger. Should be one of the loggers supported by the chosen backend.
save_dir
: The directory to save the logs to. Defaults to config.settings.paths.logs_dir.
params
: A dictionary of keyword arguments to pass to the logger. Defaults to an empty dictionary.
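How name, save_dir and params might fit together when a logger is built can be sketched as follows. Everything here is hypothetical: LoggerConfigSketch only mirrors the documented fields, and FakeLogger, LOGGER_REGISTRY, build_logger and the "fake" logger name are illustrative assumptions, not Bitfount's API or a real backend's logger names. The save_dir default copies the rendered signature.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional


@dataclass
class LoggerConfigSketch:
    """Hypothetical stand-in with the same fields as LoggerConfig."""

    name: str
    save_dir: Optional[Path] = Path("bitfount_logs")
    # default_factory gives each instance its own empty dict.
    params: Optional[dict[str, Any]] = field(default_factory=dict)


class FakeLogger:
    """Hypothetical logger that just records how it was constructed."""

    def __init__(self, save_dir: Optional[Path], **kwargs: Any) -> None:
        self.save_dir = save_dir
        self.kwargs = kwargs


# Hypothetical mapping from logger names to constructors; a real backend defines its own.
LOGGER_REGISTRY: dict[str, Callable[..., Any]] = {"fake": FakeLogger}


def build_logger(config: LoggerConfigSketch) -> Any:
    # Look up the logger by name, then unpack params as keyword arguments.
    factory = LOGGER_REGISTRY[config.name]
    return factory(save_dir=config.save_dir, **(config.params or {}))


logger = build_logger(LoggerConfigSketch(name="fake", params={"version": 1}))
print(logger.save_dir, logger.kwargs)  # prints: bitfount_logs {'version': 1}
```

The sketch uses field(default_factory=dict) instead of a literal {} default because dataclasses reject mutable defaults and this avoids sharing one dict between instances.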
Variables
- name : str (static; same as argument)
- params : Optional[dict[str, typing.Any]] (static; same as argument)
- save_dir : Optional[pathlib.Path] (static; same as argument)
RegressorMixIn
class RegressorMixIn():
MixIn for regression problems.
Currently used only for tagging purposes.
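An empty mixin used as a tag lets other code branch on a model's problem type with isinstance. The following is a minimal sketch of that pattern; RegressorMixInSketch, ClassifierTag, LinearModel and problem_type are hypothetical names, and the source does not say how Bitfount itself consumes the tag.

```python
class RegressorMixInSketch:
    """Hypothetical empty tag class, mirroring how RegressorMixIn is described."""


class ClassifierTag:
    """Hypothetical tag for classifiers."""


class LinearModel(RegressorMixInSketch):
    """A model tagged as a regressor purely through inheritance."""


def problem_type(model: object) -> str:
    # Dispatch on which tag mixins appear in the model's inheritance hierarchy.
    if isinstance(model, RegressorMixInSketch):
        return "regression"
    if isinstance(model, ClassifierTag):
        return "classification"
    return "unknown"


print(problem_type(LinearModel()))  # regression
```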
Ancestors
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn