Inference-Only Models: Custom Models for Federated Inference
Welcome to building custom models for federated inference! With Bitfount's inference model classes, you can create powerful inference models with minimal code and complexity.
Overview
Bitfount provides two base classes specifically designed for inference-only tasks:
- PytorchLightningInferenceModel - PyTorch Lightning-based inference model
- PytorchInferenceModel - Simple PyTorch inference model without Lightning dependencies
These classes eliminate the need to implement training-related methods, making it easy to deploy pre-trained models for federated inference.
Prerequisites
```python
!pip install bitfount
```

Getting Started with Inference Models
With Bitfount's PyTorch inference base classes, you only need to implement one method:
```python
from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class MySimpleInferenceModel(PytorchInferenceModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_model(self):
        """The only method you need to implement!"""
        ...
```

Image Classification with PyTorch
Let's set up the environment and create a complete inference model for image classification:
```python
import logging
from pathlib import Path

import nest_asyncio
import torch
import torch.nn as nn
from PIL import Image

from bitfount import (
    BitfountModelReference,
    BitfountSchema,
    ImageSource,
    DataStructure,
    ModelInference,
    ResultsOnly,
    get_pod_schema,
    setup_loggers,
)
from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
    PytorchInferenceModel,
)

nest_asyncio.apply()  # Needed because Jupyter also has an asyncio loop
loggers = setup_loggers([logging.getLogger("bitfount")])
```

Now let's create our inference model by first saving it to a file:
```python
import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class ResNetInferenceModel(PytorchInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model
```

Testing Locally with Image Data
Let's test our model locally first:
```python
# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)

# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference, can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
    # schema_requirements="full"  # Optional
)

# Initialize our model
model = ResNetInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_results = model.predict(data=datasource)
print(f"Local inference completed! Got {len(local_results.preds)} predictions")
```

Run Inference on a Pod
Now use your simple model with the existing Bitfount infrastructure:
```python
# Use the image dataset pod
pod_identifier = "image-datasource"
schema = get_pod_schema(pod_identifier)

# Create model reference
model_ref = BitfountModelReference(
    model_ref=Path("ResNetInferenceModel.py"),  # Your simple model file
    datastructure=datastructure,
    schema=schema,
)

# Run federated inference
protocol = ResultsOnly(algorithm=ModelInference(model=model_ref))
results = protocol.run(pod_identifiers=[pod_identifier])

print("Inference completed!")
print(f"Results: {results}")
```

A similar approach can be used to create a model that inherits from PytorchLightningInferenceModel.
```python
import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
)


class ResNetLightningInferenceModel(PytorchLightningInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model
```

Let's test the Lightning model locally:
```python
# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)

# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference, can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
)

# Initialize our model
model = ResNetLightningInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_lightning_results = model.predict(data=datasource)
print(
    f"Local inference completed! Got {len(local_lightning_results.preds)} predictions"
)
```

Understanding the Two Base Classes
PytorchLightningInferenceModel vs PytorchInferenceModel
Both classes provide the same core functionality but with different underlying architectures:
| Feature | PytorchLightningInferenceModel | PytorchInferenceModel | 
|---|---|---|
| Dependencies | Requires PyTorch Lightning | Pure PyTorch only | 
| Execution | Uses Lightning Trainer | Direct PyTorch execution | 
| GPU/Device Handling | Lightning's automatic device management | Custom device detection | 
| Extensibility | Full Lightning ecosystem (callbacks, loggers) | Simple, direct control | 
Key Architectural Differences
The main difference lies in how the predict() method is implemented.
PytorchLightningInferenceModel:
```python
# Uses PyTorch Lightning under the hood
def predict(self, data=None, **kwargs):
    # Uses pl.Trainer.test() internally
    self._pl_trainer.test(model=self, dataloaders=self.test_dl)
    return PredictReturnType(preds=self._test_preds, keys=self._test_keys)
```

PytorchInferenceModel:
```python
# Direct PyTorch execution (simplified: batch unpacking and key handling omitted)
def predict(self, data=None, **kwargs):
    # Direct batch processing with torch.no_grad()
    with torch.no_grad():
        for batch in self.test_dl:
            predictions = self.forward(batch_data)
            # Process predictions...
    return PredictReturnType(preds=all_predictions, keys=all_keys)
```

When to Use Which Base Class
Use PytorchLightningInferenceModel when:
- You want full PyTorch Lightning integration
- You need Lightning's advanced features (callbacks, logging, etc.)
- You want to leverage our dataloaders and datasets
- You prefer Lightning's structured approach to model organization
Use PytorchInferenceModel when:
- You want minimal dependencies and faster startup
- You prefer simple, direct PyTorch code
- You're building lightweight inference services
- You need fine-grained control over the inference loop
- You're deploying in resource-constrained environments
Advanced Customization: Overriding Methods
While you only need to implement create_model(), you can override other methods for custom behavior:
Customizing Your Inference Models
Both base classes share common functionality but have different advanced hooks available.
Common Base Functionality (Both Classes)
create_model() - Required Method
Every inference model must implement this abstract method:
- Return your PyTorch model architecture (nn.Module)
- Called automatically during model initialization
- The model will be moved to the appropriate device and set to evaluation mode
Shared Public Methods Available for Override:
initialise_model(data, data_splitter, context) - Model Setup
Default behavior:
- Prepares the model for inference
- Creates data loaders from the provided datasource
- Calls create_model() to instantiate your model
- Sets up the inference pipeline
When to override:
- Custom model initialization logic (see the sketch below)
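For example, here is a minimal sketch of custom initialisation, building on the ResNetInferenceModel defined above. It assumes the base implementation can be delegated to via super() (with the keyword defaults shown here being an assumption), with any extra setup happening afterwards:

```python
class CustomInitInferenceModel(ResNetInferenceModel):
    """Adds custom setup on top of the default initialisation."""

    def initialise_model(self, data=None, data_splitter=None, context=None):
        # Let the base class create the dataloaders, call create_model(),
        # and set up the inference pipeline.
        super().initialise_model(
            data=data, data_splitter=data_splitter, context=context
        )
        # Custom initialisation logic goes here, e.g. loading extra
        # assets or verifying model weights (illustrative only).
        print("Model initialised and ready for inference")
```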
forward(x) - Model Forward Pass
Default behavior:
- Handles single and multi-image column scenarios
- Runs input through your created model
- Returns model predictions
When to override:
- Custom input preprocessing (see the sketch below)
- Multi-model ensemble logic
- Special output formatting needs
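As a sketch of the preprocessing case, here is a subclass of the ResNetInferenceModel defined earlier that rescales inputs before delegating to the base forward pass. The scaling constant is illustrative, and this assumes x arrives as a single image tensor:

```python
import torch


class NormalisingInferenceModel(ResNetInferenceModel):
    """Applies custom preprocessing before the standard forward pass."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Illustrative preprocessing: scale raw pixel values into [0, 1].
        x = x.float() / 255.0
        # Delegate to the base forward pass, which runs the created model.
        return super().forward(x)
```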
Shared Utility Methods:
split_dataloader_output(data) - Data Parsing
Purpose: Properly extracts input data from dataloader output.
When to use: Processing batch data in custom methods instead of manual parsing (see the sketch below).
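A hedged sketch of the utility in a custom method; treating the return value as the ready-to-use model input is an assumption, since the exact return structure isn't documented in this tutorial, and collect_predictions() is a hypothetical helper, not part of Bitfount's API:

```python
import torch


class ManualLoopInferenceModel(ResNetInferenceModel):
    """Illustrates using the shared batch-parsing utility."""

    def collect_predictions(self):
        # Hypothetical helper method for custom batch processing.
        outputs = []
        with torch.no_grad():
            for batch in self.test_dl:
                # Use the shared utility instead of unpacking the raw batch.
                model_input = self.split_dataloader_output(batch)
                outputs.append(self.forward(model_input))
        return outputs
```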
serialize(filename) and deserialize(content) - Model Persistence
Purpose: Save and load trained model weights.
Usage: Standard model checkpointing and deployment (see the example below).
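Usage might look like the following sketch; the file name is arbitrary, and passing the raw file contents to deserialize() is an assumption based on the content parameter named above:

```python
from pathlib import Path

# Save the model weights after local testing.
model.serialize("resnet_inference_model.pt")

# Later, load the weights back into a model instance.
content = Path("resnet_inference_model.pt").read_bytes()
model.deserialize(content)
```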
PytorchLightningInferenceModel Customization
Lightning-Specific Override Methods:
test_step(batch, batch_idx) - Per-Batch Processing
Default behavior:
- Processes each batch during inference
- Extracts data and optional keys from batch
- Runs forward pass and collects results
- Handles prediction aggregation automatically
When to override:
- Custom preprocessing per batch
- Ensemble predictions across multiple models
- Custom metrics or logging during inference (sketched below)
- Special batch result formatting
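For instance, a minimal sketch that adds per-batch logging on top of the default behaviour, using the ResNetLightningInferenceModel defined earlier; delegating the actual batch handling to super() is an assumption about the base implementation:

```python
class LoggingLightningInferenceModel(ResNetLightningInferenceModel):
    """Adds simple progress logging during inference."""

    def test_step(self, batch, batch_idx):
        # Delegate data extraction, forward pass, and result collection
        # to the base class.
        output = super().test_step(batch, batch_idx)
        # Custom per-batch logic: log progress every 10 batches.
        if batch_idx % 10 == 0:
            print(f"Processed batch {batch_idx}")
        return output
```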
on_test_epoch_end() - End-of-Inference Processing
Default behavior:
- Aggregates all batch results
- Prepares final prediction outputs
- Handles key-prediction alignment
When to override:
- Custom result aggregation logic
- Post-inference processing steps (see the sketch below)
- Custom validation or filtering
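A sketch of custom post-processing at the end of inference; it assumes the aggregated predictions are available on self._test_preds, as in the predict() snippet shown earlier:

```python
class PostProcessingLightningInferenceModel(ResNetLightningInferenceModel):
    """Runs a custom step after all batch results are aggregated."""

    def on_test_epoch_end(self):
        # Let the base class aggregate results and align keys first.
        super().on_test_epoch_end()
        # Custom post-inference step (illustrative); _test_preds is
        # assumed to hold the aggregated predictions.
        print(f"Inference finished with {len(self._test_preds)} predictions")
```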
predict(data, **kwargs) - Complete Pipeline Control
Default behavior:
- Uses Lightning trainer for inference execution
- Manages the complete inference workflow
- Returns formatted prediction results
Lightning Benefits:
- Automatic device management through trainer
- Built-in logging and metrics capabilities
- Structured approach with hooks and callbacks
- Easy integration with Lightning ecosystem
PytorchInferenceModel Customization
Inference Model Override Methods:
predict(data, **kwargs) - Direct Inference Control
Default behavior:
- Manual batch processing loop with torch.no_grad()
- Direct device management and model evaluation
- Explicit prediction collection and formatting
- No Lightning trainer dependency
When to override:
- Fine-grained control over the inference loop (see the sketch below)
- Custom batch processing logic
- Memory-efficient streaming inference
- Integration with non-Lightning workflows
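As a minimal sketch, here is a predict() override that keeps the base class's no-grad loop and device handling by delegating to super(), then applies illustrative post-processing to the returned predictions. The argmax step and the mutation of preds on the returned object are assumptions for illustration, not part of Bitfount's API:

```python
import torch


class ThresholdedInferenceModel(ResNetInferenceModel):
    """Post-processes the default predictions with custom logic."""

    def predict(self, data=None, **kwargs):
        # Run the default batch-processing loop under torch.no_grad().
        results = super().predict(data=data, **kwargs)
        # Illustrative post-processing: reduce class probabilities to the
        # predicted class index for each sample.
        results.preds = [torch.as_tensor(p).argmax(-1) for p in results.preds]
        return results
```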
Simple Model Benefits:
- No PyTorch Lightning dependency
- Direct PyTorch control and transparency
- Explicit device and memory management
- Faster startup and execution
Method Override Guidelines
Start Simple:
- Implement only create_model()
- Test basic inference functionality
- Add method overrides only when needed
Lightning Model Progression:
- Override test_step() for batch-level customization
- Override on_test_epoch_end() for result aggregation
- Override predict() for complete pipeline control
Simple Model Progression:
- Override forward() for input/output processing
- Override initialise_model() for setup customization
- Override predict() for complete pipeline control
Best Practices
- Choose the Right Base: Lightning for research, Simple for production
- Always call model.eval() in your create_model() method
- Start Minimal: Begin with just create_model(), add complexity incrementally
- Use Utilities: Leverage split_dataloader_output() for robust data handling
- Test Locally: Validate all customizations before federated deployment
- Handle Edge Cases: Consider different input formats and error conditions
- Document Changes: Comment custom logic for team collaboration
You've now learned how to create simple, powerful inference models for federated learning with Bitfount!
Contact our support team at support@bitfount.com if you have any questions.