
bitfount_model

Contains PyTorch implementations of the BitfountModel paradigm.

Classes

PyTorchBitfountModel

class PyTorchBitfountModel(datastructure: DataStructure, schema: BitfountSchema, batch_size: int = 32, epochs: Optional[int] = None, steps: Optional[int] = None, seed: Optional[int] = None, param_clipping: Optional[dict[str, int]] = None):

Blueprint for a custom PyTorch model in the Lightning v1 format.

This class must be subclassed in its own module. A Path to the module containing the subclass can then be passed to BitfountModelReference and on to your Algorithm of choice which will send the model to Bitfount Hub.

To get started, just implement the abstract methods in this class. More advanced users can override any variables or methods in their subclass.

Take a look at the pytorch-lightning documentation on how to properly create a LightningModule:

https://pytorch-lightning.readthedocs.io/en/1.9.5/common/lightning_module.html

caution

Set self.metrics in the __init__ method of your subclass so that the metrics are appropriate for your model. Otherwise, Bitfount will attempt to set them for you, but there is no guarantee it will get them right.

Arguments

  • **kwargs: Any additional arguments to pass to parent constructors.
  • batch_size: The batch size to use for training. Defaults to 32.
  • datastructure: DataStructure to be passed to the model when initialised
  • epochs: The number of epochs to train for.
  • param_clipping: Arguments for clipping for BatchNorm parameters. Used for federated models with secure aggregation. It should contain the SecureShare variables and the number of workers in a dictionary, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}. Defaults to None.
  • schema: The BitfountSchema object associated with the datasource on which the model will be trained.
  • seed: Random number seed. Used for setting random seed for all libraries. Defaults to None.
  • steps: The number of steps to train for.

Attributes

  • batch_size: The batch size to use for training.
  • epochs: The number of epochs to train for.
  • fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type. (e.g. fields_dict = {"class_name": fields.Str()}).
  • nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes. (e.g. nested_fields = {"datastructure": datastructure.registry})
  • preds: The predictions from the most recent test run.
  • steps: The number of steps to train for.
  • target: The targets from the most recent test run.
  • val_stats: Metrics from the validation set during training.

Raises

  • ValueError: If both epochs and steps are specified.
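The mutual exclusion between epochs and steps can be illustrated with a stand-alone check. This is a minimal sketch only: check_training_iterations is a hypothetical helper, not part of the Bitfount API, and the error message is an assumption.

```python
from typing import Optional


def check_training_iterations(epochs: Optional[int], steps: Optional[int]) -> None:
    """Hypothetical stand-in for the constructor's check: reject epochs and steps together."""
    if epochs is not None and steps is not None:
        raise ValueError("Specify either epochs or steps, not both.")


check_training_iterations(epochs=5, steps=None)    # accepted
check_training_iterations(epochs=None, steps=200)  # accepted

try:
    check_training_iterations(epochs=5, steps=200)  # both given
except ValueError as err:
    print(f"Rejected: {err}")
```

In practice this means choosing one unit of training length per model instance.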

Ancestors

Static methods


diff_params

def diff_params(old_params: _Weights, new_params: _Weights) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

See base class.

serialize_model_source_code

def serialize_model_source_code(filename: Union[str, os.PathLike], extra_imports: Optional[list[str]] = None) -> None:

Inherited from:

BitfountModel.serialize_model_source_code:

Serializes the source code of the model to file.

This is required so that the model source code can be uploaded to Bitfount Hub.

Arguments

  • filename: The filename to save the source code to.
  • extra_imports: A list of extra import statements to include in the source code.

Methods


apply_weight_updates

def apply_weight_updates(self, weight_updates: Sequence[_Weights]) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

See base class.

deserialize

def deserialize(self, content: Union[str, os.PathLike, bytes], weights_only: bool = True, **kwargs: Any) -> None:

Deserialize model.

danger

If weights_only is set to False, this should not be used on a model file that has been received across a trust boundary due to underlying use of pickle by torch.

Arguments

  • content: Path to file containing serialized model.
  • weights_only: If True, only load the weights of the model. If False, load the entire model. Defaults to True.
  • **kwargs: Keyword arguments provided to torch.load under the hood.
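To see why unpickling untrusted data is risky, the following sketch uses only the standard library pickle module (it does not call torch or Bitfount). The Payload class is hypothetical; it shows that a pickled object can instruct the loader to call a function of the file author's choosing.

```python
import pickle


class Payload:
    """Object whose pickled form tells the loader to call a function."""

    def __reduce__(self):
        # On load, pickle calls len("abc") and returns the result in place of a Payload.
        # A hostile file could name any importable callable here instead of len.
        return (len, ("abc",))


blob = pickle.dumps(Payload())
loaded = pickle.loads(blob)  # invokes len("abc") while loading

print(type(loaded).__name__, loaded)
```

Loading returns the integer 3 rather than a Payload instance, which shows that the call happened during loading. Keeping weights_only at its default of True is the safer choice for files from outside your trust boundary.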

deserialize_params

def deserialize_params(self, serialized_weights: _SerializedWeights) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

Convert serialized model params to tensors.

evaluate

def evaluate(self) -> EvaluateReturnType:

Runs evaluation on the datasource.

caution

For remote evaluation, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

fit

def fit(self, data: BaseSource, metrics: Optional[dict[str, Metric]] = None, **kwargs: Any) -> Optional[dict[str, str]]:

Fits model locally.

caution

For remote training, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

Arguments

  • data: Datasource for training. Defaults to None.
  • metrics: Metrics to calculate for validation. Defaults to None.
  • **kwargs: Optional keyword arguments passed to the federated fit method. Any unrecognized keyword arguments will be interpreted as custom model hyperparameters.

Returns A dictionary of metrics and their values. Optional.

Raises

  • ValueError: If neither pod_identifiers are provided for federated training nor data is provided for local training.

get_param_states

def get_param_states(self) -> dict[str, bitfount.types._TensorLike]:

See base class.

Wrapping the state dictionary with dict ensures we return a dict rather than an OrderedDict.
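The effect of the dict wrapping can be shown with a stand-in state dictionary. This is an illustrative sketch: plain lists take the place of tensors, and the key names are made up.

```python
from collections import OrderedDict

# Stand-in for a state dictionary; plain lists take the place of tensors here.
state = OrderedDict([("layer.weight", [0.1, 0.2]), ("layer.bias", [0.0])])

# dict() makes a shallow copy: same keys, same order, same value objects.
plain = dict(state)

print(type(state).__name__)
print(type(plain).__name__)
print(list(plain) == list(state))
```

Because the copy is shallow, the values in the returned dict are the same objects as in the original mapping.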

initialise_model

def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) -> None:

Any initialisation of models/dataloaders to be done here.

Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of training, serialization or deserialization should be done here.

Arguments

  • data: The datasource for model training. Defaults to None.
  • data_splitter: The splitter to use for the data. Defaults to None.
  • context: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.

log_

def log_(self, name: str, value: Any, **kwargs: Any) -> Any:

Simple wrapper around the pytorch lightning log method.

predict

def predict(self, data: BaseSource, **kwargs: Any) -> PredictReturnType:

Runs inference on the datasource.

caution

For remote inference, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

Arguments

  • data: Datasource to run inference on when running locally. Defaults to None.
  • **kwargs: Optional keyword arguments passed to the federated predict method.

reset_trainer

def reset_trainer(self) -> None:

See base class.

serialize

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Serialize model to file with provided filename.

Arguments

  • filename: Path to file to save serialized model.

serialize_params

def serialize_params(self, weights: _Weights) -> collections.abc.Mapping[str, list[float]]:

Serialize model params.

set_model_training_iterations

def set_model_training_iterations(self, iterations: int) -> None:

See base class.

skip_training_batch

def skip_training_batch(self, batch_idx: int) -> bool:

Checks if the current batch from the training set should be skipped.

This is a workaround for the fact that PyTorch Lightning starts the Dataloader iteration from the beginning every time fit is called. This means that if we are training in steps, we are always training on the same batches. So this method needs to be called at the beginning of every training_step to skip to the right batch index.

Arguments

  • batch_idx: the index of the batch from training_step.

Returns True if the batch should be skipped, otherwise False.
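The idea behind the workaround can be sketched with a toy simulation. This is not Bitfount's implementation: SteppedTrainingSim, its attributes, and the skipping rule are all assumptions made for illustration, and no PyTorch Lightning code is involved.

```python
class SteppedTrainingSim:
    """Toy stand-in for step-based training; not Bitfount's implementation."""

    def __init__(self, batches, steps):
        self.batches = batches   # fixed, ordered stand-in for a dataloader
        self.steps = steps       # number of batches to train on per fit() call
        self.batches_done = 0    # batches trained on across all fit() calls
        self.resume_at = 0       # first batch index to train on in this fit()
        self.seen = []           # record of the batches actually trained on

    def skip_training_batch(self, batch_idx):
        # Skip every batch that precedes the resume position.
        return batch_idx < self.resume_at

    def fit(self):
        self.resume_at = self.batches_done % len(self.batches)
        trained = 0
        # Iteration restarts at index 0 on every call, as described above.
        for batch_idx, batch in enumerate(self.batches):
            if self.skip_training_batch(batch_idx):
                continue
            self.seen.append(batch)
            trained += 1
            if trained == self.steps:
                break
        self.batches_done += trained


sim = SteppedTrainingSim(batches=["b0", "b1", "b2", "b3", "b4", "b5"], steps=2)
sim.fit()
sim.fit()
print(sim.seen)
```

After two calls the simulation has trained on b0 to b3 rather than on b0 and b1 twice, which is the behaviour the skip check is meant to restore.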

tensor_precision

def tensor_precision(self) -> +T_DTYPE:

Returns tensor dtype used by Pytorch Lightning Trainer.

note

Currently only 32-bit training is supported.

Returns Pytorch tensor dtype.

test_dataloader

def test_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns test dataloader.

test_epoch_end

def test_epoch_end(self, outputs: list[_TEST_STEP_OUTPUT]) -> None:

Aggregates the predictions and targets from the test set.

caution

If you are overwriting this method, ensure you set self._test_preds to maintain compatibility with self._predict_local unless you are overwriting both of them.

Arguments

  • outputs: list of outputs from each test step.

test_step

def test_step(self, batch: Any, batch_idx: int) -> bitfount.backends.pytorch.models.base_models._TEST_STEP_OUTPUT:

Operates on a single batch of data from the test set.

Arguments

  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the test dataloader.

Returns A dictionary of predictions and targets, with the dictionary keys being "predictions" and "targets" for each of them, respectively. These will be passed to the test_epoch_end method.
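The hand-off from test_step to test_epoch_end can be pictured as concatenating the per-batch dictionaries in order. The sketch below is an assumption about the general shape of that aggregation, not the library's code: aggregate_test_outputs is a hypothetical helper, and plain lists stand in for tensors.

```python
def aggregate_test_outputs(outputs):
    """Concatenate per-batch predictions and targets in batch order."""
    preds, targets = [], []
    for out in outputs:
        preds.extend(out["predictions"])
        targets.extend(out["targets"])
    return preds, targets


# One dictionary per test batch, using the keys described above.
step_outputs = [
    {"predictions": [0.9, 0.2], "targets": [1, 0]},
    {"predictions": [0.4], "targets": [1]},
]

all_preds, all_targets = aggregate_test_outputs(step_outputs)
print(all_preds, all_targets)
```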

train_dataloader

def train_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns training dataloader.

trainer_init

def trainer_init(self) -> pytorch_lightning.trainer.trainer.Trainer:

Initialises the Lightning Trainer for this model.

Documentation for pytorch-lightning trainer can be found here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html

tip

Override this method to choose your own Trainer arguments.

Returns The pytorch lightning trainer.

training_step

def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, typing.Any]:

Training step.

caution

If iterations have been specified in terms of steps, the default behaviour of pytorch lightning is to train on the first n steps of the dataloader every time fit is called. This default behaviour is not desirable and has been dealt with in the built-in Bitfount models but, until this bug gets fixed by the pytorch lightning team, this needs to be implemented by the user for custom models.

tip

Take a look at the skip_training_batch method for one way to deal with this. It can be used as follows at the start of training_step:

    if self.skip_training_batch(batch_idx):
        return None

Arguments

  • batch: The batch to be trained on.
  • batch_idx: The index of the batch to be trained on from the train dataloader.

Returns The loss from this batch as a torch.Tensor, or a dictionary that includes the key loss with the loss as a torch.Tensor.

update_params

def update_params(self, new_model_params: _Weights) -> None:

See base class.

val_dataloader

def val_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns validation dataloader.

validation_epoch_end

def validation_epoch_end(self, outputs: list[_StrAnyDict]) -> None:

Called at the end of the validation epoch with all validation step outputs.

Ensures that the average metrics from a validation epoch are stored. Logs results and also appends to self.val_stats.

Arguments

  • outputs: list of outputs from each validation step.

validation_step

def validation_step(self, batch: Any, batch_idx: int) -> dict[str, typing.Any]:

Validation step.

Arguments

  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the validation dataloader.

Returns A dictionary of strings and values that should be averaged at the end of every epoch and logged e.g. {"validation_loss": loss}. These will be passed to the validation_epoch_end method.
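The per-epoch averaging of validation step outputs can be sketched as follows. This is a minimal illustration under stated assumptions: average_validation_outputs is a hypothetical helper, plain floats stand in for tensors, and every step is assumed to return the same keys.

```python
from statistics import fmean


def average_validation_outputs(outputs):
    """Average each metric over all validation steps of one epoch."""
    keys = outputs[0].keys()
    return {key: fmean(out[key] for out in outputs) for key in keys}


# One dictionary per validation batch, as returned by validation_step.
epoch_outputs = [
    {"validation_loss": 0.5},
    {"validation_loss": 0.25},
    {"validation_loss": 0.75},
]

epoch_average = average_validation_outputs(epoch_outputs)
print(epoch_average)
```

A result like this is what would then be logged and appended to val_stats at the end of the epoch.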

PyTorchBitfountModelv2

class PyTorchBitfountModelv2(datastructure: DataStructure, schema: BitfountSchema, batch_size: int = 32, epochs: Optional[int] = None, steps: Optional[int] = None, seed: Optional[int] = None, param_clipping: Optional[dict[str, int]] = None):

Blueprint for a custom PyTorch model in the Lightning v2+ format.

This class must be subclassed in its own module. A Path to the module containing the subclass can then be passed to BitfountModelReference and on to your Algorithm of choice which will send the model to Bitfount Hub.

To get started, just implement the abstract methods in this class. More advanced users can override any variables or methods in their subclass.

Take a look at the pytorch-lightning documentation on how to properly create a LightningModule:

https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html

caution

Set self.metrics in the __init__ method of your subclass so that the metrics are appropriate for your model. Otherwise, Bitfount will attempt to set them for you, but there is no guarantee it will get them right.

Arguments

  • **kwargs: Any additional arguments to pass to parent constructors.
  • batch_size: The batch size to use for training. Defaults to 32.
  • datastructure: DataStructure to be passed to the model when initialised
  • epochs: The number of epochs to train for.
  • param_clipping: Arguments for clipping for BatchNorm parameters. Used for federated models with secure aggregation. It should contain the SecureShare variables and the number of workers in a dictionary, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}. Defaults to None.
  • schema: The BitfountSchema object associated with the datasource on which the model will be trained.
  • seed: Random number seed. Used for setting random seed for all libraries. Defaults to None.
  • steps: The number of steps to train for.

Attributes

  • batch_size: The batch size to use for training.
  • epochs: The number of epochs to train for.
  • fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type. (e.g. fields_dict = {"class_name": fields.Str()}).
  • nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes. (e.g. nested_fields = {"datastructure": datastructure.registry})
  • preds: The predictions from the most recent test run.
  • steps: The number of steps to train for.
  • target: The targets from the most recent test run.
  • val_stats: Metrics from the validation set during training.

Raises

  • ValueError: If both epochs and steps are specified.

Ancestors

Static methods


diff_params

def diff_params(old_params: _Weights, new_params: _Weights) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

See base class.

serialize_model_source_code

def serialize_model_source_code(filename: Union[str, os.PathLike], extra_imports: Optional[list[str]] = None) -> None:

Inherited from:

BitfountModel.serialize_model_source_code:

Serializes the source code of the model to file.

This is required so that the model source code can be uploaded to Bitfount Hub.

Arguments

  • filename: The filename to save the source code to.
  • extra_imports: A list of extra import statements to include in the source code.

Methods


apply_weight_updates

def apply_weight_updates(self, weight_updates: Sequence[_Weights]) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

See base class.

deserialize

def deserialize(self, content: Union[str, os.PathLike, bytes], weights_only: bool = True, **kwargs: Any) -> None:

Deserialize model.

danger

If weights_only is set to False, this should not be used on a model file that has been received across a trust boundary due to underlying use of pickle by torch.

Arguments

  • content: Path to file containing serialized model.
  • weights_only: If True, only load the weights of the model. If False, load the entire model. Defaults to True.
  • **kwargs: Keyword arguments provided to torch.load under the hood.

deserialize_params

def deserialize_params(self, serialized_weights: _SerializedWeights) -> collections.abc.Mapping[str, bitfount.types._TensorLike]:

Convert serialized model params to tensors.

evaluate

def evaluate(self) -> EvaluateReturnType:

Runs evaluation on the datasource.

caution

For remote evaluation, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

fit

def fit(self, data: BaseSource, metrics: Optional[dict[str, Metric]] = None, **kwargs: Any) -> Optional[dict[str, str]]:

Fits model locally.

caution

For remote training, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

Arguments

  • data: Datasource for training. Defaults to None.
  • metrics: Metrics to calculate for validation. Defaults to None.
  • **kwargs: Optional keyword arguments passed to the federated fit method. Any unrecognized keyword arguments will be interpreted as custom model hyperparameters.

Returns A dictionary of metrics and their values. Optional.

Raises

  • ValueError: If neither pod_identifiers are provided for federated training nor data is provided for local training.

get_param_states

def get_param_states(self) -> dict[str, bitfount.types._TensorLike]:

See base class.

Wrapping the state dictionary with dict ensures we return a dict rather than an OrderedDict.

initialise_model

def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) -> None:

Any initialisation of models/dataloaders to be done here.

Initialises the dataloaders and sets self._model to be the output from self.create_model. Any initialisation ahead of training, serialization or deserialization should be done here.

Arguments

  • data: The datasource for model training. Defaults to None.
  • data_splitter: The splitter to use for the data. Defaults to None.
  • context: Indicates if the model is running as a modeller or worker. If None, there is no difference between modeller and worker.

log_

def log_(self, name: str, value: Any, **kwargs: Any) -> Any:

Simple wrapper around the pytorch lightning log method.

on_test_epoch_end

def on_test_epoch_end(self) -> None:

Called at the end of the test epoch.

Aggregates the predictions and targets from the test set.

caution

If you are overwriting this method, ensure you set self._test_preds to maintain compatibility with self._predict_local unless you are overwriting both of them.

on_train_epoch_end

def on_train_epoch_end(self) -> None:

Called at the end of the training epoch.

By default, this only clears the stored training step outputs.

on_validation_epoch_end

def on_validation_epoch_end(self) -> None:

Called at the end of the validation epoch.

Ensures that the average metrics from a validation epoch are stored. Logs results and also appends to self.val_stats.

predict

def predict(self, data: BaseSource, **kwargs: Any) -> PredictReturnType:

Runs inference on the datasource.

caution

For remote inference, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.

Arguments

  • data: Datasource to run inference on when running locally. Defaults to None.
  • **kwargs: Optional keyword arguments passed to the federated predict method.

reset_trainer

def reset_trainer(self) -> None:

See base class.

serialize

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Serialize model to file with provided filename.

Arguments

  • filename: Path to file to save serialized model.

serialize_params

def serialize_params(self, weights: _Weights) -> collections.abc.Mapping[str, list[float]]:

Serialize model params.

set_model_training_iterations

def set_model_training_iterations(self, iterations: int) -> None:

See base class.

skip_training_batch

def skip_training_batch(self, batch_idx: int) -> bool:

Checks if the current batch from the training set should be skipped.

This is a workaround for the fact that PyTorch Lightning starts the Dataloader iteration from the beginning every time fit is called. This means that if we are training in steps, we are always training on the same batches. So this method needs to be called at the beginning of every training_step to skip to the right batch index.

Arguments

  • batch_idx: the index of the batch from training_step.

Returns True if the batch should be skipped, otherwise False.

tensor_precision

def tensor_precision(self) -> +T_DTYPE:

Returns tensor dtype used by Pytorch Lightning Trainer.

note

Currently only 32-bit training is supported.

Returns Pytorch tensor dtype.

test_dataloader

def test_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns test dataloader.

test_step

def test_step(self, batch: Any, batch_idx: int) -> bitfount.backends.pytorch.models.base_models._TEST_STEP_OUTPUT:

Operates on a single batch of data from the test set.

This is a wrapper around the _test_step() method which handles output storing.

Arguments

  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the test dataloader.

Returns A dictionary of predictions and targets, with the dictionary keys being "predictions" and "targets" for each of them, respectively. These are stored for use in the on_test_epoch_end method.

train_dataloader

def train_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns training dataloader.

trainer_init

def trainer_init(self) -> pytorch_lightning.trainer.trainer.Trainer:

Initialises the Lightning Trainer for this model.

Documentation for pytorch-lightning trainer can be found here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html

tip

Override this method to choose your own Trainer arguments.

Returns The pytorch lightning trainer.

training_step

def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, typing.Any]:

Training step.

This is a wrapper around the _training_step() method which handles output storing.

caution

If iterations have been specified in terms of steps, the default behaviour of pytorch lightning is to train on the first n steps of the dataloader every time fit is called. This default behaviour is not desirable and has been dealt with in the built-in Bitfount models but, until this bug gets fixed by the pytorch lightning team, this needs to be implemented by the user for custom models.

tip

Take a look at the skip_training_batch method for one way to deal with this. It can be used as follows at the start of your training step:

    if self.skip_training_batch(batch_idx):
        return None

Arguments

  • batch: The batch to be trained on.
  • batch_idx: The index of the batch to be trained on from the train dataloader.

Returns The loss from this batch as a torch.Tensor, or a dictionary that includes the key loss with the loss as a torch.Tensor.

update_params

def update_params(self, new_model_params: _Weights) -> None:

See base class.

val_dataloader

def val_dataloader(self) -> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:

Returns validation dataloader.

validation_step

def validation_step(self, batch: Any, batch_idx: int) -> dict[str, typing.Any]:

Validation step.

This is a wrapper around the _validation_step() method which handles output storing.

Arguments

  • batch: The batch to be evaluated.
  • batch_idx: The index of the batch to be evaluated from the validation dataloader.

Returns A dictionary of strings and values that should be averaged at the end of every epoch and logged e.g. {"validation_loss": loss}. These are stored for use in the on_validation_epoch_end method.
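The wrapper pattern used in the v2 class (a public step method that stores the output of an underscore-prefixed method, plus an epoch-end hook that consumes the store) can be sketched in plain Python. Everything here is a stand-in: StoredOutputsSketch and the _validation_step_outputs attribute name are hypothetical, and the "loss" is simply the mean of the batch.

```python
class StoredOutputsSketch:
    """Toy illustration of the store-then-aggregate pattern; not Bitfount code."""

    def __init__(self):
        self._validation_step_outputs = []  # hypothetical store for step outputs
        self.val_stats = []                 # one averaged dictionary per epoch

    def _validation_step(self, batch, batch_idx):
        # User-defined logic would go here; the batch mean stands in for a loss.
        return {"validation_loss": sum(batch) / len(batch)}

    def validation_step(self, batch, batch_idx):
        # Wrapper: run the user logic, then remember its output for epoch end.
        output = self._validation_step(batch, batch_idx)
        self._validation_step_outputs.append(output)
        return output

    def on_validation_epoch_end(self):
        # The hook takes no arguments, so it reads the stored outputs instead.
        outputs = self._validation_step_outputs
        averaged = {
            key: sum(out[key] for out in outputs) / len(outputs)
            for key in outputs[0]
        }
        self.val_stats.append(averaged)
        outputs.clear()  # start the next epoch with an empty store


model = StoredOutputsSketch()
for idx, batch in enumerate([[1.0, 3.0], [4.0, 4.0]]):
    model.validation_step(batch, idx)
model.on_validation_epoch_end()
print(model.val_stats)
```

Clearing the store at the end of each epoch keeps outputs from one epoch out of the next epoch's average.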