shim
PyTorch implementation of the tensor shim.
Classes
PyTorchBackendTensorShim
class PyTorchBackendTensorShim():
PyTorch backend shim/bridge for converting from/to PyTorch tensors.
Ancestors
- BackendTensorShim
- abc.ABC
- bitfount.types._BaseSerializableObjectMixIn
Variables
- static
fields_dict : ClassVar[dict[str, marshmallow.fields.Field]]
- static
nested_fields : ClassVar[dict[str, collections.abc.Mapping[str, Any]]]
Static methods
clamp_params
def clamp_params(p: _TensorLike, prime_q: int, precision: int, num_workers: int) ‑> bitfount.types._TensorLike:
Method for clipping params for secure sharing.
Constrains the parameters to the range required for secure sharing. Used only when steps_between_parameter_updates is 1.
Arguments
p
: The tensor to be constrained.
prime_q
: The prime used for secret aggregation.
precision
: The precision used for secret aggregation.
num_workers
: The number of workers taking part in the secure aggregation.
Returns
The clamped parameters.
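The clamping idea can be illustrated with a minimal, backend-agnostic sketch. It uses plain Python lists instead of tensors, so it runs without PyTorch. The helper name clamp_values and the bound formula prime_q / (precision * (num_workers + 1)) are assumptions made for illustration only; the documentation above does not state the exact range that clamp_params uses.

```python
from typing import List


def clamp_values(
    values: List[float], prime_q: int, precision: int, num_workers: int
) -> List[float]:
    """Hypothetical helper: clamp each value to the range [-bound, bound]."""
    # ASSUMPTION: the bound below is illustrative and is not taken from
    # the library documentation.
    bound = prime_q / (precision * (num_workers + 1))
    # Element-wise clamp: values below -bound become -bound,
    # values above bound become bound, everything else is unchanged.
    return [max(-bound, min(bound, v)) for v in values]


# Example inputs (arbitrary): one value below, one inside, one above the range.
clamped = clamp_values([-5.0, 0.25, 7.5], prime_q=101, precision=10, num_workers=4)
print(clamped)
```

With a PyTorch tensor, the same element-wise operation is available as torch.clamp(p, min, max).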
is_tensor
def is_tensor(p: Any) ‑> bool:
See base class.
to_list
def to_list(p: Union[np.ndarray, _TensorLike]) ‑> list[float]:
See base class.
to_numpy
def to_numpy(t: Union[_TensorLike, list[float]]) ‑> numpy.ndarray:
See base class.
to_tensor
def to_tensor(p: Sequence, **kwargs: Any) ‑> bitfount.types._TensorLike:
See base class.