utils.module_utils

Common utility functions for working with ppl.Module and its subclasses, such as counting the total number of parameters in a module, making a module static, and handling module attributes.

borch.utils.module_utils.copy_module_attributes(original, new)

Copy attributes from one module to another. Specifically, ensure that every tensor in new's _parameters has the same class as the corresponding tensor in original and retains its attributes.

Parameters
  • original – Original module to copy attributes from.

  • new – New module to copy attributes to.

Returns

The new module, with its attributes updated to match original.
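
A minimal sketch of the attribute-copying idea, not borch's implementation. It assumes the attributes worth preserving are plain Python attributes stored in each parameter's `__dict__`, and it only checks the parameter class rather than converting it. The helper name `copy_parameter_attributes_sketch` and the `tag` attribute are hypothetical.

```python
import torch
from torch import nn


def copy_parameter_attributes_sketch(original: nn.Module, new: nn.Module) -> nn.Module:
    """Hypothetical helper: copy Python-level attributes of each directly
    registered parameter of `original` onto the same-named parameter of `new`."""
    for name, param in original._parameters.items():
        target = new._parameters.get(name)
        if param is None or target is None:
            continue  # nothing to copy for missing or unset parameters
        if type(target) is not type(param):
            # Assumption: converting between parameter classes is out of scope here.
            raise TypeError(f"parameter {name!r} has a different class in the new module")
        # Attributes set via `param.tag = ...` live in the tensor's __dict__.
        target.__dict__.update(param.__dict__)
    return new


original = nn.Linear(2, 2)
original.weight.tag = "prior"  # hypothetical attribute attached to a parameter
new = copy_parameter_attributes_sketch(original, nn.Linear(2, 2))
print(new.weight.tag)  # prints "prior"
```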

borch.utils.module_utils.get_nested_module(module: torch.nn.modules.module.Module, index: Iterable[str]) → torch.nn.modules.module.Module

Get a (potentially) nested child module from a module.

Parameters
  • module – Parent module to index into.

  • index – Sequence of attribute names leading from module to the desired child, e.g. ("block", "1").

Returns

An extracted module.

Example

>>> from torch import nn
>>> from borch.utils.module_utils import get_nested_module
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.block = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))
>>>
>>> net = Net()
>>> get_nested_module(net, ("block", "1"))
Linear(in_features=3, out_features=4, bias=True)
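
Conceptually, fetching a nested child amounts to applying getattr once per name in index: nn.Module resolves registered children by attribute name, and nn.Sequential registers its children under the string keys "0", "1", and so on. A minimal sketch under that assumption; the helper name `get_nested_module_sketch` is hypothetical and borch's own implementation may differ. Plain namespaces stand in for modules so the sketch runs without torch.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Iterable


def get_nested_module_sketch(module: Any, index: Iterable[str]) -> Any:
    """Hypothetical helper: walk down the attribute path given by `index`."""
    # reduce(getattr, ["a", "b"], m) evaluates getattr(getattr(m, "a"), "b")
    return reduce(getattr, index, module)


# Stand-ins for modules: "block" holds children under the keys "0" and "1".
leaf = SimpleNamespace(name="leaf")
root = SimpleNamespace(block=SimpleNamespace(**{"0": None, "1": leaf}))
print(get_nested_module_sketch(root, ("block", "1")).name)  # prints "leaf"
```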

borch.utils.module_utils.get_nested_modules(module: torch.nn.modules.module.Module, indices: Iterable[Iterable[str]]) → Tuple[torch.nn.modules.module.Module]

Get multiple (potentially) nested child modules from a module.

Parameters
  • module – Parent module to index into.

  • indices – Iterable of attribute-name sequences, one per module to fetch.

Returns

A tuple of extracted modules.

Example

>>> from torch import nn
>>> from borch.utils.module_utils import get_nested_modules
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.block = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))
>>>
>>> net = Net()
>>> modules = get_nested_modules(net, [("block", "0"), ("block", "1")])
>>> modules[0]
Linear(in_features=2, out_features=3, bias=True)
>>> modules[1]
Linear(in_features=3, out_features=4, bias=True)
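
Fetching several modules can be sketched as resolving each attribute path independently and collecting the results in a tuple, in the same order as indices. The helper name `get_nested_modules_sketch` is hypothetical, and namespaces stand in for modules so the sketch runs without torch.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Iterable, Tuple


def get_nested_modules_sketch(module: Any, indices: Iterable[Iterable[str]]) -> Tuple[Any, ...]:
    """Hypothetical helper: resolve every attribute path in `indices`."""
    # Each path is walked with getattr, exactly as for a single nested module.
    return tuple(reduce(getattr, index, module) for index in indices)


first, second = SimpleNamespace(), SimpleNamespace()
root = SimpleNamespace(block=SimpleNamespace(**{"0": first, "1": second}))
modules = get_nested_modules_sketch(root, [("block", "0"), ("block", "1")])
print(modules[0] is first, modules[1] is second)  # prints "True True"
```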

borch.utils.module_utils.load_state_dict(module: torch.nn.modules.module.Module, state_dict: dict, strict_names: bool = True, strict_shapes: bool = True)

Loads state_dict into module.

Optionally, parameters that are missing or superfluous, and/or parameters with mismatched shapes, can be ignored instead of raising an error.

Parameters
  • module – Module to load state_dict into.

  • state_dict – State dict to load.

  • strict_names – If True, an error is raised if any names differ between state_dict and the module contents.

  • strict_shapes – If True, an error is raised if any parameter shapes differ between state_dict and the module contents.

Example

If an architecture is expected to have differently sized weights in its final layer, a state dict can still be loaded forgivingly as follows:

>>> from torch import nn
>>> from borch.utils.module_utils import load_state_dict
>>> class Network(nn.Sequential):
...     def __init__(self, n_out):
...         super().__init__()
...         self.one = nn.Linear(3, 4)
...         self.two = nn.Linear(4, 5)
...         self.thr = nn.Linear(5, n_out)
>>>
>>> net1 = Network(n_out=10)
>>> net2 = Network(n_out=20)
>>> state_dict = net1.state_dict()
>>> _ = load_state_dict(net2, state_dict, strict_shapes=False)
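
One way to approximate strict_shapes=False with plain PyTorch is to drop every entry whose name is unknown to the target module or whose shape disagrees with it, then call nn.Module.load_state_dict with strict=False so the dropped keys are tolerated. This sketch assumes that filtering approach; it is not necessarily how borch implements it, and the helper name `load_matching_shapes_sketch` is hypothetical.

```python
import torch
from torch import nn


def load_matching_shapes_sketch(module: nn.Module, state_dict: dict) -> list:
    """Hypothetical helper: load only entries whose names and shapes match."""
    own = module.state_dict()
    kept = {k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape}
    module.load_state_dict(kept, strict=False)  # tolerate the dropped keys
    # Report which incoming keys were skipped.
    return sorted(set(state_dict) - set(kept))


net1 = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 10))
net2 = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 20))
skipped = load_matching_shapes_sketch(net2, net1.state_dict())
print(skipped)  # prints ['1.bias', '1.weight']
```

The first layer is copied from net1, while the mismatched final layer of net2 keeps its own initialisation.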

borch.utils.module_utils.parameters_named(module, include_name)

Return all parameters whose name, as yielded by module.named_parameters(), contains the string include_name.

Notes

If include_name = ‘scale’ and the yielded name is linear.weight.scale.u_tensor, then this parameter will be yielded by this function.

Parameters
  • module (torch.nn.Module) – the network to take parameters from

  • include_name (str) – the substring a parameter name must contain to be included

Returns

Generator over the parameters whose names contain include_name.
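
A minimal sketch of the filtering described above, assuming (as "Return all parameters" suggests) that only the parameter objects are yielded, not their names. The helper name `parameters_named_sketch` and the `FakeModule` stand-in, which mimics `named_parameters()` so the sketch runs without torch, are hypothetical.

```python
from typing import Any, Iterator


def parameters_named_sketch(module: Any, include_name: str) -> Iterator[Any]:
    """Hypothetical helper: yield parameters whose name contains `include_name`."""
    for name, par in module.named_parameters():
        if include_name in name:  # plain substring match on the dotted name
            yield par


class FakeModule:
    """Stand-in exposing named_parameters() like torch.nn.Module."""

    def named_parameters(self):
        yield "linear.weight.scale.u_tensor", "w_scale"
        yield "linear.weight.loc", "w_loc"


print(list(parameters_named_sketch(FakeModule(), "scale")))  # prints ['w_scale']
```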

borch.utils.module_utils.parameters_not_named(module, remove_name)

Return all parameters whose name, as yielded by module.named_parameters(), does not contain the string remove_name.

Notes

If remove_name = ‘scale’ and the yielded name is linear.weight.scale.u_tensor, then this parameter will not be yielded by this function.

Parameters
  • module (torch.nn.Module) – the network to take parameters from

  • remove_name (str) – the substring whose presence in a parameter name excludes that parameter

Returns

Generator over the parameters whose names do not contain remove_name.
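
A typical use for this kind of filter is building optimizer parameter groups, for example applying weight decay to everything except biases. The sketch below assumes parameters_not_named yields bare parameters; the helper `parameters_not_named_sketch` is a hypothetical re-implementation, while the list-of-dicts form of `torch.optim.SGD` is standard PyTorch.

```python
import torch
from torch import nn


def parameters_not_named_sketch(module: nn.Module, remove_name: str):
    """Hypothetical helper: yield parameters whose name lacks `remove_name`."""
    for name, par in module.named_parameters():
        if remove_name not in name:
            yield par


net = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2))
biases = [p for n, p in net.named_parameters() if "bias" in n]
optimizer = torch.optim.SGD(
    [
        # Weight decay only on parameters whose name does not contain "bias".
        {"params": list(parameters_not_named_sketch(net, "bias")), "weight_decay": 1e-4},
        {"params": biases, "weight_decay": 0.0},
    ],
    lr=0.1,
)
print([len(group["params"]) for group in optimizer.param_groups])  # prints [2, 2]
```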

borch.utils.module_utils.total_parameters(module)

Return the total number of scalar parameters (tensor elements) in a torch.nn.Module (typically a neural network).

Parameters

module (torch.nn.Module) – The network for which the number of parameters should be calculated.

Returns

Number of parameters.

Return type

int

Examples

>>> from torch.nn import Linear, Sequential, Sigmoid
>>> from borch.utils.module_utils import total_parameters
>>> net = Sequential(
...     Linear(3, 4, bias=True), Sigmoid(),
...     Linear(4, 5, bias=True), Sigmoid()
... )
>>> total_parameters(net)
41
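
The count above can be reproduced by hand: Linear(i, o) with a bias holds an i×o weight matrix plus o bias entries, so the network has (3·4 + 4) + (4·5 + 5) = 16 + 25 = 41 parameters, and Sigmoid has none. A minimal sketch of such a count, summing `numel()` over `module.parameters()`; the helper name `total_parameters_sketch` is hypothetical and borch's implementation may differ.

```python
import torch
from torch import nn


def total_parameters_sketch(module: nn.Module) -> int:
    """Hypothetical helper: sum the element counts of all parameters."""
    # numel() is the number of scalar entries in a tensor; parameters()
    # skips duplicates, so a module reused twice is counted once.
    return sum(p.numel() for p in module.parameters())


net = nn.Sequential(nn.Linear(3, 4), nn.Sigmoid(), nn.Linear(4, 5), nn.Sigmoid())
print(total_parameters_sketch(net))  # prints 41
```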

borch.utils.module_utils.yield_named(named_parameters, include_name)

Yield all named parameters whose names contain the string include_name.

Notes

If include_name = ‘scale’ and the yielded name is linear.weight.scale.u_tensor, then this parameter will be yielded by this function.

Parameters
  • named_parameters (iterable) – the parameters to filter, as (name, par) pairs such as those yielded by module.named_parameters()

  • include_name (str) – the substring a name must contain to be included

Returns

Generator over the (name, par) pairs whose names contain include_name.
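
Unlike parameters_named, this function works on (name, par) pairs directly. A minimal sketch, assuming the pairs are passed through unchanged; the helper name `yield_named_sketch` is hypothetical, and integers stand in for parameter tensors.

```python
from typing import Any, Iterable, Iterator, Tuple


def yield_named_sketch(
    named_parameters: Iterable[Tuple[str, Any]], include_name: str
) -> Iterator[Tuple[str, Any]]:
    """Hypothetical helper: keep (name, par) pairs whose name contains include_name."""
    for name, par in named_parameters:
        if include_name in name:
            yield name, par


pairs = [("linear.weight.scale.u_tensor", 1), ("linear.weight.loc", 2)]
print(list(yield_named_sketch(pairs, "scale")))
# prints [('linear.weight.scale.u_tensor', 1)]
```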

borch.utils.module_utils.yield_not_named(named_parameters, remove_name)

Yield all named parameters whose names do not contain the string remove_name.

Notes

If remove_name = ‘scale’ and the yielded name is linear.weight.scale.u_tensor, then this parameter will not be yielded by this function.

Parameters
  • named_parameters (iterable) – the parameters to filter, as (name, par) pairs such as those yielded by module.named_parameters()

  • remove_name (str) – the substring whose presence in a name excludes that pair

Yields

Generator over the (name, par) pairs whose names do not contain remove_name.
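
Because the filter is a lazy generator over (name, par) pairs, several exclusions can be chained by feeding one call's output into the next. A minimal sketch under that assumption; the helper name `yield_not_named_sketch` is hypothetical, and integers stand in for parameter tensors.

```python
from typing import Any, Iterable, Iterator, Tuple


def yield_not_named_sketch(
    named_parameters: Iterable[Tuple[str, Any]], remove_name: str
) -> Iterator[Tuple[str, Any]]:
    """Hypothetical helper: drop (name, par) pairs whose name contains remove_name."""
    for name, par in named_parameters:
        if remove_name not in name:
            yield name, par


pairs = [
    ("linear.weight.scale.u_tensor", 1),
    ("linear.weight.loc", 2),
    ("linear.bias", 3),
]
# Chain two exclusions: first drop "scale", then drop "bias".
kept = yield_not_named_sketch(yield_not_named_sketch(pairs, "scale"), "bias")
print(list(kept))  # prints [('linear.weight.loc', 2)]
```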