bitfount_model
Contains PyTorch implementations of the BitfountModel paradigm.
Classes
PyTorchBitfountModel
class PyTorchBitfountModel( datastructure: DataStructure, schema: BitfountSchema, batch_size: int = 32, epochs: Optional[int] = None, steps: Optional[int] = None, seed: Optional[int] = None, param_clipping: Optional[dict[str, int]] = None,):
Blueprint for a custom PyTorch model in the PyTorch Lightning format.
This class must be subclassed in its own module. A Path to the module containing the subclass can then be passed to BitfountModelReference and on to your Algorithm of choice, which will send the model to Bitfount Hub.
To get started, implement the abstract methods in this class. More advanced users can override any variables or methods in their subclass.
See the PyTorch Lightning documentation for how to properly create a LightningModule:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
Set self.metrics in the __init__ method of your subclass so that the metrics are appropriate for your model. If you do not, Bitfount will attempt to choose suitable metrics for you, but there is no guarantee they will be correct.
Arguments
- **kwargs: Any additional arguments to pass to parent constructors.
- batch_size: The batch size to use for training. Defaults to 32.
- datastructure: DataStructure to be passed to the model when initialised.
- epochs: The number of epochs to train for.
- param_clipping: Arguments for clipping BatchNorm parameters. Used for federated models with secure aggregation. It should be a dictionary containing the SecureShare variables and the number of workers, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}. Defaults to None.
- schema: The BitfountSchema object associated with the datasource on which the model will be trained.
- seed: Random number seed. Used for setting the random seed for all libraries. Defaults to None.
- steps: The number of steps to train for.
Attributes
- batch_size: The batch size to use for training.
- epochs: The number of epochs to train for.
- fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
- preds: The predictions from the most recent test run.
- steps: The number of steps to train for.
- target: The targets from the most recent test run.
- val_stats: Metrics from the validation set during training.
Raises
ValueError: If both epochs and steps are specified.
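The rule that epochs and steps are mutually exclusive can be illustrated with a minimal sketch. validate_iterations is a hypothetical helper, not part of Bitfount; it encodes only the documented rule, and whether the real constructor also rejects the case where neither value is set is not stated here.

```python
from typing import Optional


def validate_iterations(epochs: Optional[int], steps: Optional[int]) -> None:
    """Hypothetical check mirroring the documented rule: epochs and steps are exclusive."""
    if epochs is not None and steps is not None:
        raise ValueError("Specify either `epochs` or `steps`, not both.")


# Passing only one of the two satisfies the documented check.
validate_iterations(epochs=2, steps=None)
validate_iterations(epochs=None, steps=100)

try:
    validate_iterations(epochs=2, steps=100)
except ValueError as err:
    print(f"Rejected: {err}")
```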
Ancestors
- bitfount.backends.pytorch.federated.mixins._PyTorchDistributedModelMixIn
- bitfount.federated.mixins._DistributedModelMixIn
- BitfountModel
- bitfount.models.base_models._BaseModel
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
- abc.ABC
- pytorch_lightning.core.module.LightningModule
- typing.Generic
- lightning_fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.saving.ModelIO
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
Variables
- static fields_dict: ClassVar[T_FIELDS_DICT]
- static train_dl: _BasePyTorchBitfountDataLoader
Static methods
diff_params
def diff_params(old_params: _Weights, new_params: _Weights) ‑> collections.abc.Mapping:
See base class.
serialize_model_source_code
def serialize_model_source_code( filename: Union[str, os.PathLike], extra_imports: Optional[list[str]] = None,) ‑> None:
Inherited from: BitfountModel.serialize_model_source_code
Serializes the source code of the model to file.
This is required so that the model source code can be uploaded to Bitfount Hub.
Arguments
- filename: The filename to save the source code to.
- extra_imports: A list of extra import statements to include in the source code.
Methods
apply_weight_updates
def apply_weight_updates( self, weight_updates: Sequence[_Weights],) ‑> collections.abc.Mapping:
See base class.
configure_optimizers
def configure_optimizers( self,) ‑> Union[torch.optim.optimizer.Optimizer, tuple[list[torch.optim.optimizer.Optimizer], list[torch.optim.lr_scheduler._LRScheduler]]]:
Configures the optimizer(s) and scheduler(s) for backpropagation.
Returns
Either the optimizer of your choice, or a tuple containing a list of optimizers and a list of learning rate schedulers.
create_model
def create_model(self) ‑> torch.nn.modules.module.Module:
Creates and returns the underlying pytorch model.
Returns
The underlying PyTorch model. This is set to self._model.
deserialize
def deserialize( self, content: Union[str, os.PathLike, bytes], weights_only: bool = True, **kwargs: Any,) ‑> None:
Deserialize model.
If weights_only is set to False, do not use this method on a model file that has been received across a trust boundary, because torch uses pickle under the hood.
Arguments
- content: Path to the file containing the serialized model.
- weights_only: If True, only load the weights of the model. If False, load the entire model. Defaults to True.
- **kwargs: Keyword arguments provided to torch.load under the hood.
deserialize_params
def deserialize_params( self, serialized_weights: _SerializedWeights,) ‑> collections.abc.Mapping:
Convert serialized model params to tensors.
evaluate
def evaluate( self, test_dl: Optional[BitfountDataLoader] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any,) ‑> Union[EvaluateReturnType, dict[str, float]]:
Evaluates the model either locally or in a federated manner. pod_identifiers must be provided for federated evaluation.
For remote evaluation, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.
Arguments
- test_dl: Optional dataloader to run inference on, which takes precedence over the dataloader returned by self.test_dataloader.
- pod_identifiers: List of pod identifiers. If provided, the model will be evaluated in a federated manner. Defaults to None.
- **kwargs: Optional keyword arguments passed to the federated evaluate method.
fit
def fit( self, data: Optional[BaseSource] = None, metrics: Optional[dict[str, Metric]] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any,) ‑> Optional[dict[str, str]]:
Fits the model either locally or in a federated manner. pod_identifiers must be provided for federated training.
For remote training, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.
Arguments
- data: Datasource for training. Defaults to None.
- metrics: Metrics to calculate for validation. Defaults to None.
- pod_identifiers: List of pod identifiers. If provided, the model will be trained in a federated manner. Defaults to None.
- **kwargs: Optional keyword arguments passed to the federated fit method. Any unrecognized keyword arguments will be interpreted as custom model hyperparameters.
Returns
A dictionary of metrics and their values, or None.
Raises
ValueError: If neither pod_identifiers (for federated training) nor data (for local training) is provided.
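The documented dispatch between federated and local training can be sketched as follows. choose_fit_mode is a hypothetical helper, not Bitfount's implementation; treating pod_identifiers as taking precedence when both arguments are given is an assumption based on the wording "if this is provided".

```python
from typing import Any, Optional


def choose_fit_mode(
    data: Optional[Any] = None,
    pod_identifiers: Optional[list[str]] = None,
) -> str:
    """Hypothetical helper encoding the documented dispatch rules for fit()."""
    if pod_identifiers is not None:
        # pod_identifiers given: train in a federated manner on those pods (assumed precedence).
        return "federated"
    if data is not None:
        # No pods, but a local datasource: train locally.
        return "local"
    # Neither given: the documented ValueError case.
    raise ValueError(
        "Provide pod_identifiers for federated training or data for local training."
    )


print(choose_fit_mode(pod_identifiers=["example-user/example-pod"]))  # federated
print(choose_fit_mode(data=[1, 2, 3]))  # local
```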
forward
def forward(self, x: Any) ‑> Any:
Forward method of the model, just like in a regular torch.nn.Module class.
This will depend on your model but could be as simple as:
return self._model(x)
Arguments
- x: Input to the model.
Returns
Output of the model.
get_param_states
def get_param_states(self) ‑> dict:
See base class.
Wrapping the state dictionary with dict ensures we return a dict rather than an OrderedDict.
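The effect of wrapping the state dictionary with dict can be shown with the standard library alone. Plain lists stand in for tensors here, and the parameter names are illustrative; PyTorch's Module.state_dict() itself returns an OrderedDict of parameter names to tensors.

```python
from collections import OrderedDict

# Stand-in for Module.state_dict(): an OrderedDict mapping parameter names
# to values (plain lists are used here instead of tensors).
state = OrderedDict([("layer.weight", [0.1, 0.2]), ("layer.bias", [0.0])])

# dict(...) makes a shallow copy as a plain dict; key order is preserved.
params = dict(state)

print(type(state).__name__, type(params).__name__)  # OrderedDict dict
print(list(params))  # ['layer.weight', 'layer.bias']
# The values are shared with the original, not copied.
print(params["layer.bias"] is state["layer.bias"])  # True
```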
initialise_model
def initialise_model( self, data: Optional[BaseSource] = None, context: Optional[TaskContext] = None,) ‑> None:
Performs any initialisation of models and dataloaders.
Initialises the dataloaders and sets self._model to the output of self.create_model. Any initialisation required ahead of training, serialization or deserialization should be done here.
Arguments
- data: The datasource for model training. Defaults to None.
- context: Indicates whether the model is running as a modeller or worker. If None, there is no difference between modeller and worker.
log_
def log_(self, name: str, value: Any, **kwargs: Any) ‑> Any:
Simple wrapper around the PyTorch Lightning log method.
predict
def predict( self, data: Optional[BaseSource] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any,) ‑> Union[PredictReturnType, dict[str, list[np.ndarray]]]:
Runs inference with the model either locally or in a federated manner. pod_identifiers must be provided for federated inference.
For remote inference, this method does not use the existing parameters of the model unless they are serialized and provided using the pretrained_file keyword argument.
Arguments
- data: Optional datasource to run inference on when running locally. Defaults to None.
- pod_identifiers: List of pod identifiers. If provided, inference will be run in a federated manner. Defaults to None.
- **kwargs: Optional keyword arguments passed to the federated predict method.
reset_trainer
def reset_trainer(self) ‑> None:
See base class.
serialize
def serialize(self, filename: Union[str, os.PathLike]) ‑> None:
Serialize the model to a file with the provided filename.
Arguments
- filename: Path to the file in which to save the serialized model.
serialize_params
def serialize_params(self, weights: _Weights) ‑> collections.abc.Mapping:
Serialize model params.
set_datastructure_identifier
def set_datastructure_identifier(self, datastructure_identifier: str) ‑> None:
Sets the datastructure identifier for the model.
This is used to identify which elements of the datastructure refer to the current runner. It is normally the pod identifier (for single-datasource pods) or the "logical pod" identifier for pods with multiple datasources that have registered as different pods.
This must be called on the Pod/Worker side in distributed training, because the model needs it to extract the relevant information from the datastructure sent by the Modeller.
Arguments
- datastructure_identifier: The datastructure identifier for the model.
set_model_training_iterations
def set_model_training_iterations(self, iterations: int) ‑> None:
See base class.
skip_training_batch
def skip_training_batch(self, batch_idx: int) ‑> bool:
Checks if the current batch from the training set should be skipped.
This is a workaround for the fact that PyTorch Lightning starts the dataloader iteration from the beginning every time fit is called. This means that if we are training in steps, we are always training on the same batches. This method therefore needs to be called at the beginning of every training_step to skip to the right batch index.
Arguments
- batch_idx: The index of the batch from training_step.
Returns
True if the batch should be skipped, otherwise False.
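The skipping idea described above can be sketched without Lightning. StepResumer and simulate_fit are hypothetical stand-ins, not Bitfount's implementation: the sketch remembers how many batches earlier fit calls consumed and skips that many at the start of the next call. For simplicity it does not roll over into a new epoch within a single call.

```python
class StepResumer:
    """Hypothetical model stand-in showing how skipping avoids retraining on the same batches."""

    def __init__(self, num_batches_per_epoch: int) -> None:
        self.num_batches_per_epoch = num_batches_per_epoch
        self.total_batches_trained = 0  # counted across all fit() calls
        self._offset = 0

    def on_fit_start(self) -> None:
        # Remember where the previous fit() call stopped within the epoch.
        self._offset = self.total_batches_trained % self.num_batches_per_epoch

    def skip_training_batch(self, batch_idx: int) -> bool:
        # Skip batches already trained on by an earlier fit() call.
        return batch_idx < self._offset

    def training_step(self, batch_idx: int):
        if self.skip_training_batch(batch_idx):
            return None  # Lightning skips to the next batch when None is returned
        self.total_batches_trained += 1
        return 0.0  # placeholder for a real loss tensor


def simulate_fit(model: StepResumer, steps: int) -> list[int]:
    """Mimic one fit() call: iteration restarts at batch 0 and stops after `steps` real steps."""
    model.on_fit_start()
    trained: list[int] = []
    for batch_idx in range(model.num_batches_per_epoch):
        if len(trained) == steps:
            break
        if model.training_step(batch_idx) is not None:
            trained.append(batch_idx)
    return trained


model = StepResumer(num_batches_per_epoch=10)
print(simulate_fit(model, steps=3))  # [0, 1, 2]
print(simulate_fit(model, steps=3))  # [3, 4, 5]
```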
tensor_precision
def tensor_precision(self) ‑> +T_DTYPE:
Returns the tensor dtype used by the PyTorch Lightning Trainer.
Currently only 32-bit training is supported.
Returns
PyTorch tensor dtype.
test_dataloader
def test_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns test dataloader.
test_epoch_end
def test_epoch_end( self, outputs: list[dict[str, Union[torch.Tensor, list[str]]]],) ‑> None:
Aggregates the predictions and targets from the test set.
If you are overriding this method, ensure you set self._test_preds to maintain compatibility with self._predict_local, unless you are overriding both of them.
Arguments
- outputs: List of outputs from each test step.
test_step
def test_step(self, batch: Any, batch_idx: int) ‑> Union[torch.Tensor, dict[str, Any]]:
Operates on a single batch of data from the test set.
Arguments
- batch: The batch to be evaluated.
- batch_idx: The index of the batch to be evaluated from the test dataloader.
Returns
A dictionary of predictions and targets, with the keys "predictions" and "targets" respectively. These will be passed to the test_epoch_end method.
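The hand-off from test_step to test_epoch_end can be sketched with plain lists. make_step_output and aggregate_test_outputs are hypothetical names; a real implementation would hold tensors and concatenate them (for example with torch.cat), and would store the result on self._test_preds as noted for test_epoch_end.

```python
from typing import Any


def make_step_output(predictions: list[float], targets: list[int]) -> dict[str, Any]:
    """Per-batch dictionary in the shape test_step is documented to return."""
    return {"predictions": predictions, "targets": targets}


def aggregate_test_outputs(
    outputs: list[dict[str, Any]],
) -> tuple[list[float], list[int]]:
    """Concatenate per-batch predictions and targets into flat, epoch-level lists."""
    preds: list[float] = []
    targets: list[int] = []
    for out in outputs:
        preds.extend(out["predictions"])
        targets.extend(out["targets"])
    return preds, targets


outputs = [make_step_output([0.9, 0.2], [1, 0]), make_step_output([0.7], [1])]
preds, targets = aggregate_test_outputs(outputs)
print(preds, targets)  # [0.9, 0.2, 0.7] [1, 0, 1]
```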
train_dataloader
def train_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns training dataloader.
trainer_init
def trainer_init(self) ‑> pytorch_lightning.trainer.trainer.Trainer:
Initialises the Lightning Trainer for this model.
Documentation for the PyTorch Lightning Trainer can be found here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
Override this method to choose your own Trainer arguments.
Returns
The PyTorch Lightning trainer.
training_step
def training_step( self, batch: Any, batch_idx: int,) ‑> Union[torch.Tensor, dict[str, Any]]:
Training step.
If iterations have been specified in terms of steps, the default behaviour of PyTorch Lightning is to train on the first n steps of the dataloader every time fit is called. This behaviour is not desirable and has been dealt with in the built-in Bitfount models but, until this bug is fixed by the PyTorch Lightning team, it needs to be handled by the user for custom models.
Take a look at the skip_training_batch method for one way to deal with this. It can be used as follows:
if self.skip_training_batch(batch_idx):
    return None
Arguments
- batch: The batch to be trained on.
- batch_idx: The index of the batch to be trained on from the train dataloader.
Returns
The loss from this batch as a torch.Tensor, or a dictionary that includes the key loss with the loss as a torch.Tensor.
update_params
def update_params(self, new_model_params: _Weights) ‑> None:
See base class.
val_dataloader
def val_dataloader( self,) ‑> bitfount.backends.pytorch.data.dataloaders._BasePyTorchBitfountDataLoader:
Returns validation dataloader.
validation_epoch_end
def validation_epoch_end(self, outputs: list[_StrAnyDict]) ‑> None:
Called at the end of the validation epoch with all validation step outputs.
Ensures that the average metrics from a validation epoch are stored. Logs the results and also appends them to self.val_stats.
Arguments
- outputs: List of outputs from each validation step.
validation_step
def validation_step(self, batch: Any, batch_idx: int) ‑> dict:
Validation step.
Arguments
- batch: The batch to be evaluated.
- batch_idx: The index of the batch to be evaluated from the validation dataloader.
Returns
A dictionary of strings and values that should be averaged at the end of every epoch and logged, e.g. {"validation_loss": loss}. These will be passed to the validation_epoch_end method.
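The averaging performed by validation_epoch_end over the dictionaries returned by validation_step can be sketched as follows. ValidationTracker is a hypothetical stand-in: it uses floats instead of tensors, assumes every step reports the same metric keys, and omits the logging the real method does.

```python
from statistics import fmean


class ValidationTracker:
    """Hypothetical stand-in for the averaging done by validation_epoch_end."""

    def __init__(self) -> None:
        self.val_stats: list[dict[str, float]] = []

    def validation_epoch_end(self, outputs: list[dict[str, float]]) -> None:
        if not outputs:
            return  # nothing to average
        # Average every metric key over all validation steps of the epoch.
        averaged = {key: fmean(out[key] for out in outputs) for key in outputs[0]}
        self.val_stats.append(averaged)


tracker = ValidationTracker()
tracker.validation_epoch_end([{"validation_loss": 0.5}, {"validation_loss": 0.25}])
print(tracker.val_stats)  # [{'validation_loss': 0.375}]
```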