inference_model
PyTorch inference models for Bitfount.
Classes
PytorchInferenceModel
class PytorchInferenceModel(**kwargs: Any):
Simple PyTorch inference model for Bitfount.
This class provides a minimal implementation for inference-only models, without requiring PyTorch Lightning or complex inheritance.
Users only need to implement:
- create_model() - Return the PyTorch nn.Module to use
All other methods have sensible defaults for inference.
Inference model for PyTorch.
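A minimal subclassing sketch: only create_model() is implemented, as described above. The import path is inferred from the module name on this page and the nn.Module returned is purely illustrative.

```python
import torch.nn as nn

# Import path inferred from this page's module name; adjust if needed.
from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class MyInferenceModel(PytorchInferenceModel):
    """Inference-only model: only create_model() needs to be provided."""

    def create_model(self) -> nn.Module:
        # Any torch.nn.Module works here; this small classifier is illustrative.
        return nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )
```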
Ancestors
- bitfount.backends.pytorch.models.inference_model._BaseInferenceModel
- InferrableModelProtocol
- ModelProtocol
- BaseModelProtocol
- typing.Protocol
- typing.Generic
Variables
initialised : bool - Should return True if initialise_model has been called.
Methods
deserialize
def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) ‑> None:
Inherited from: InferrableModelProtocol.deserialize
Deserialises the model.
forward
def forward(self, x: Any) ‑> Any:
Forward pass through the model.
initialise_model
def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) ‑> None:
Initialize model and prepare dataloaders for inference.
Arguments
- data: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
- data_splitter: Optional splitter to use instead of _InferenceSplitter.
- context: Optional execution context (unused).
predict
def predict(self, data: Optional[BaseSource] = None, **_: Any) ‑> PredictReturnType:
Run inference and return predictions.
Arguments
- data: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.
Returns
PredictReturnType containing predictions and optional data keys.
Raises
ValueError: If no test dataloader is available.
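A usage sketch continuing the MyInferenceModel example above. CSVSource, its import path, and the PredictReturnType attribute names (preds, keys) are assumptions, not confirmed by this page.

```python
from bitfount import CSVSource  # import path is an assumption

datasource = CSVSource("inference_data.csv")

model = MyInferenceModel()
model.initialise_model(data=datasource)  # builds an inference-only test dataloader
results = model.predict()                # or model.predict(data=datasource)

print(results.preds)  # predictions
print(results.keys)   # optional data keys, e.g. filenames for file-based sources
```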
PytorchLightningInferenceModel
class PytorchLightningInferenceModel(*, datastructure: DataStructure, schema: BitfountSchema, batch_size: int = 32, **kwargs: Any):
PyTorch Lightning inference model for Bitfount.
Inference model for PyTorch.
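A constructor sketch under stated assumptions: the DataStructure and BitfountSchema import paths and constructor arguments below are illustrative, and these objects would normally come from your task or pod configuration rather than being built inline.

```python
from bitfount import BitfountSchema, DataStructure  # import path is an assumption

# Import path inferred from this page's module name; adjust if needed.
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel

schema = BitfountSchema()                        # usually loaded from a pod, not built empty
datastructure = DataStructure(table="my_table")  # the "table" argument is illustrative

model = PytorchLightningInferenceModel(
    datastructure=datastructure,
    schema=schema,
    batch_size=64,  # defaults to 32
)
```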
Ancestors
- pytorch_lightning.core.module.LightningModule
- lightning_fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
- bitfount.backends.pytorch.models.inference_model._BaseInferenceModel
- InferrableModelProtocol
- ModelProtocol
- BaseModelProtocol
- typing.Protocol
- typing.Generic
Variables
initialised : bool - Should return True if initialise_model has been called.
Methods
deserialize
def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) ‑> None:
Inherited from: InferrableModelProtocol.deserialize
Deserialises the model.
forward
def forward(self, x: Any) ‑> Any:
Forward pass through the model.
initialise_model
def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) ‑> None:
Initialise model and prepare dataloaders for inference.
Arguments
- data: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
- data_splitter: Optional splitter to use instead of _InferenceSplitter.
- context: Optional execution context (unused).
on_test_epoch_end
def on_test_epoch_end(self) ‑> None:
Called at the end of the test epoch.
Aggregates the predictions and targets from the test set.
If you are overriding this method, ensure you set self._test_preds to maintain compatibility with self._predict_local, unless you override both of them.
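A hedged override sketch, continuing the import sketch above: the per-batch buffer (self._outputs) is hypothetical and assumed to be filled in test_step; the point is only that self._test_preds is set before the hook returns, as required above.

```python
import torch


class MyLightningInferenceModel(PytorchLightningInferenceModel):
    def on_test_epoch_end(self) -> None:
        # self._outputs is a hypothetical list of per-batch prediction tensors
        # collected in test_step; setting self._test_preds is the required part.
        self._test_preds = torch.cat(self._outputs).cpu().numpy()
```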
predict
def predict(self, data: Optional[BaseSource] = None, **_: Any) ‑> PredictReturnType:
Run inference and return predictions.
Arguments
- data: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.
Returns
PredictReturnType containing predictions and optional data keys. Data keys must be present if the datasource is file-based.
Raises
ValueError: If no test dataloader is available.
test_step
def test_step(self, batch: Any, batch_idx: int) ‑> bitfount.backends.pytorch.models.base_models._TEST_STEP_OUTPUT:
Process a single batch during testing/inference.
Override this step as required.
Arguments
- batch: The batch data
- batch_idx: Index of the batch
Returns
Dictionary with predictions and targets.
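A hedged test_step override, continuing the sketch above and assuming the dataloader yields (features, targets) pairs; the batch layout and dictionary keys are illustrative.

```python
from typing import Any, Dict

import torch


class MyLightningInferenceModel(PytorchLightningInferenceModel):
    def test_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
        x, y = batch     # batch layout is an assumption; adapt to your dataloader
        preds = self(x)  # dispatches to forward()
        return {"predictions": preds, "targets": y}
```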
trainer_init
def trainer_init(self) ‑> pytorch_lightning.trainer.trainer.Trainer:
Initialize PyTorch Lightning trainer.
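A sketch of overriding trainer_init with an inference-oriented configuration, continuing the sketch above; it uses only standard pytorch_lightning.Trainer arguments, and whether you need to override it at all depends on your setup.

```python
import pytorch_lightning as pl


class MyLightningInferenceModel(PytorchLightningInferenceModel):
    def trainer_init(self) -> pl.Trainer:
        # Inference-only trainer: single device, no logging or checkpointing.
        return pl.Trainer(
            accelerator="auto",
            devices=1,
            logger=False,
            enable_checkpointing=False,
        )
```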