utils.module_utils
Common utility functions for use with ppl.Module and its inheritors, such as getting the total number of parameters in a Module, making a module static, or handling module attributes.
-
borch.utils.module_utils.copy_module_attributes(original, new)
Copy attributes from one module to another. Specifically, ensure that all tensors in _parameters are assigned the correct class and retain their attributes.
- Parameters
original – Original module to copy attributes from.
new – New module to copy attributes to.
- Returns
The module copy but with all attributes updated according to original.
-
borch.utils.module_utils.get_nested_module(module: torch.nn.modules.module.Module, index: Iterable[str]) → torch.nn.modules.module.Module
Get a (potentially) nested child module from a module.
- Parameters
module – Parent module to index into.
index – Index to fetch module by.
- Returns
An extracted module.
Example
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.block = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))
>>>
>>> net = Net()
>>> get_nested_module(net, ("block", "1"))
Linear(in_features=3, out_features=4, bias=True)
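The lookup above can be pictured as a left-to-right chain of attribute accesses over the module tree. A minimal pure-Python sketch of that idea (a hypothetical re-implementation for illustration, not the library's actual code; the `Node` class is invented here to stand in for `nn.Module`):

```python
from functools import reduce

class Node:
    """Stand-in for a module tree; children are stored as attributes."""
    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)

def get_nested(module, index):
    # Descend one child per index entry, left to right.
    return reduce(getattr, index, module)

leaf = Node()
inner = Node()
setattr(inner, "0", Node())  # Sequential children are named "0", "1", ...
setattr(inner, "1", leaf)
root = Node(block=inner)
assert get_nested(root, ("block", "1")) is leaf
```

With an empty index the sketch returns the module itself, which mirrors how `reduce` behaves with an empty iterable and an initial value.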
-
borch.utils.module_utils.get_nested_modules(module: torch.nn.modules.module.Module, indices: Iterable[Iterable[str]]) → Tuple[torch.nn.modules.module.Module]
Get multiple (potentially) nested child modules from a module.
- Parameters
module – Parent module to index into.
indices – Indices to fetch modules by.
- Returns
A tuple of extracted modules.
Example
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.block = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 4))
>>>
>>> net = Net()
>>> modules = get_nested_modules(net, [("block", "0"), ("block", "1")])
>>> modules[0]
Linear(in_features=2, out_features=3, bias=True)
>>> modules[1]
Linear(in_features=3, out_features=4, bias=True)
-
borch.utils.module_utils.load_state_dict(module: torch.nn.modules.module.Module, state_dict: dict, strict_names: bool = True, strict_shapes: bool = True)
Load state_dict into module. We can optionally ignore any parameters which are missing or superfluous, and/or any parameters which have mismatched shapes.
- Parameters
module – Module to load state_dict into.
state_dict – State dict to load.
strict_names – If True, an error will be raised if there are any mismatched names between state_dict and the module contents.
strict_shapes – If True, an error will be raised if there are any mismatched parameter shapes between state_dict and the module contents.
Example
If we have an architecture in which the weight sizes are expected to differ on a final layer, we can still forgivingly load a state dict as follows:
>>> from io import BytesIO
>>> class Network(nn.Sequential):
...     def __init__(self, n_out):
...         super().__init__()
...         self.one = nn.Linear(3, 4)
...         self.two = nn.Linear(4, 5)
...         self.thr = nn.Linear(5, n_out)
>>>
>>> net1 = Network(n_out=10)
>>> net2 = Network(n_out=20)
>>> state_dict = net1.state_dict()
>>> _ = load_state_dict(net2, state_dict, strict_shapes=False)
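The forgiving behaviour can be sketched with plain dictionaries mapping parameter names to shapes. This is hypothetical logic for illustration only (the function name `filter_state_dict` and the shape-tuple representation are inventions here, not borch's actual implementation):

```python
def filter_state_dict(target_shapes, state_dict,
                      strict_names=True, strict_shapes=True):
    """Keep only entries that fit the target module's parameters."""
    filtered = {}
    for name, shape in state_dict.items():
        if name not in target_shapes:
            if strict_names:
                raise KeyError(f"unexpected key: {name}")
            continue  # forgiving: silently skip superfluous keys
        if target_shapes[name] != shape:
            if strict_shapes:
                raise ValueError(f"shape mismatch for {name}")
            continue  # forgiving: silently skip mismatched shapes
        filtered[name] = shape
    return filtered

# Mirrors the doctest: the final layer's shape differs, so it is dropped.
target = {"one.weight": (4, 3), "thr.weight": (20, 5)}
incoming = {"one.weight": (4, 3), "thr.weight": (10, 5)}
assert filter_state_dict(target, incoming, strict_shapes=False) == {"one.weight": (4, 3)}
```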
-
borch.utils.module_utils.parameters_named(module, include_name)
Return all parameters where include_name is part of the name yielded by net.named_parameters.
Notes
If include_name = 'scale' and the yielded name is linear.weight.scale.u_tensor, then this parameter will be yielded from this function.
- Parameters
module (torch.nn.Module) – the net you want parameters from
include_name (str) – the name you want to include
- Returns
A generator yielding only the parameters whose names contain include_name.
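The filtering amounts to a substring test on each dotted parameter name. A hypothetical sketch with a stand-in module (the names `FakeModule` and `parameters_named_sketch` are inventions for illustration, not borch's implementation):

```python
class FakeModule:
    """Minimal stand-in for torch.nn.Module exposing named_parameters()."""
    def __init__(self, pairs):
        self._pairs = pairs  # list of (dotted_name, parameter) tuples

    def named_parameters(self):
        yield from self._pairs

def parameters_named_sketch(module, include_name):
    # Keep only parameters whose dotted name contains include_name.
    for name, par in module.named_parameters():
        if include_name in name:
            yield par

net = FakeModule([("linear.weight.scale.u_tensor", "scale_par"),
                  ("linear.weight.loc.u_tensor", "loc_par")])
assert list(parameters_named_sketch(net, "scale")) == ["scale_par"]
```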
-
borch.utils.module_utils.parameters_not_named(module, remove_name)
Return all parameters where remove_name is not part of the name yielded by net.named_parameters.
Notes
If remove_name = 'scale' and the yielded name is linear.weight.scale.u_tensor, then this parameter will not be yielded from this function.
- Parameters
module (torch.nn.Module) – the net you want parameters from
remove_name (str) – the name you don't want to include
- Returns
A generator yielding only the parameters whose names do not contain remove_name.
-
borch.utils.module_utils.total_parameters(module)
Return the total number of parameters on a torch.nn.Module (typically a neural network).
- Parameters
module (torch.nn.Module) – The network for which the number of parameters should be calculated.
- Returns
Number of parameters.
- Return type
int
Examples
>>> from torch.nn import Linear, Sequential, Sigmoid
>>> net = Sequential(
...     Linear(3, 4, bias=True), Sigmoid(),
...     Linear(4, 5, bias=True), Sigmoid()
... )
>>> total_parameters(net)
41
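The count of 41 in the example follows from the layer sizes: each Linear(n_in, n_out) with bias contributes n_in * n_out weights plus n_out biases, and Sigmoid layers have no parameters. Checking by hand (plain Python, no torch required; `linear_params` is a helper invented here):

```python
def linear_params(n_in, n_out, bias=True):
    """Parameter count of a fully-connected layer."""
    return n_in * n_out + (n_out if bias else 0)

# Linear(3, 4): 12 weights + 4 biases; Linear(4, 5): 20 weights + 5 biases.
total = linear_params(3, 4) + linear_params(4, 5)
assert total == 41  # matches total_parameters(net) in the doctest
```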
-
borch.utils.module_utils.yield_named(named_parameters, include_name)
Yield all named parameters whose names do contain the string include_name.
Notes
If include_name = 'scale' and the yielded name is linear.weight.scale.u_tensor, then this parameter will be yielded from this function.
- Parameters
named_parameters (iterable) – parameters you want to filter; should yield (name, par) pairs
include_name (str) – the name you want to include
- Returns
A generator yielding only the parameters whose names contain include_name.
-
borch.utils.module_utils.yield_not_named(named_parameters, remove_name)
Yield all named parameters whose names do not contain the string remove_name.
Notes
If remove_name = 'scale' and the yielded name is linear.weight.scale.u_tensor, then this parameter will not be yielded from this function.
- Parameters
named_parameters (iterable) – parameters you want to filter; should yield (name, par) pairs
remove_name (str) – the name you don't want to include
- Yields
A generator yielding only the parameters whose names do not contain remove_name.
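Both yield_* helpers amount to complementary substring filters over an iterable of (name, parameter) pairs. A hypothetical sketch on fake data (the `_sketch` names and the assumption that parameter objects rather than pairs are yielded are mine, for illustration only):

```python
def yield_named_sketch(named_parameters, include_name):
    # Keep parameters whose dotted name contains the substring.
    for name, par in named_parameters:
        if include_name in name:
            yield par

def yield_not_named_sketch(named_parameters, remove_name):
    # Complement: drop parameters whose dotted name contains the substring.
    for name, par in named_parameters:
        if remove_name not in name:
            yield par

pairs = [("linear.weight.scale.u_tensor", "scale_par"),
         ("linear.weight.loc.u_tensor", "loc_par")]
assert list(yield_named_sketch(pairs, "scale")) == ["scale_par"]
assert list(yield_not_named_sketch(pairs, "scale")) == ["loc_par"]
```

Together the two filters partition the input: every parameter appears in exactly one of the two result streams for a given filter string.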