utils.state_dict
Functionality to operate on state_dicts
-
borch.utils.state_dict.add_state_dict_to_state(module, state)
Saves the current state of module by adding its state_dict to state.
-
borch.utils.state_dict.copy_state_dict(state_dict)
Takes a copy of a state_dict by serializing and deserializing it with pickle.
- Parameters
state_dict – state_dict to copy
- Returns
copy of state_dict
- Return type
state_dict
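The pickle round trip described above can be sketched as follows. This is a minimal, standard-library-only illustration, not borch's actual implementation, and the name copy_state_dict_sketch is hypothetical. Because every key and value is serialized and rebuilt, the copy shares no mutable objects with the original, and pickle preserves the container type.

```python
import pickle
from collections import OrderedDict


def copy_state_dict_sketch(state_dict):
    """Deep-copy a state_dict by pickling it to bytes and unpickling it.

    Hypothetical helper illustrating the technique; not borch's own code.
    """
    # Serialize every key and value into a single bytes object...
    payload = pickle.dumps(state_dict)
    # ...then rebuild a fully independent object graph from those bytes.
    return pickle.loads(payload)


original = OrderedDict(weight=[1.0, 2.0], bias=[0.5])
copied = copy_state_dict_sketch(original)
copied["weight"][0] = 99.0  # mutating the copy...
print(original["weight"])   # ...leaves the original untouched: [1.0, 2.0]
```

Since the input here is an OrderedDict, the copy is an OrderedDict as well.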
-
borch.utils.state_dict.sample_state(module, state, idx=None)
Restores one of the saved states.
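The save/restore pair formed by add_state_dict_to_state and sample_state might work roughly like the sketch below. It rests on assumptions the reference does not confirm: that state is a plain list of saved state_dicts and that idx=None means "pick a saved state uniformly at random". ToyModule, add_state_sketch and sample_state_sketch are hypothetical names used only for illustration.

```python
import copy
import random


class ToyModule:
    """Hypothetical stand-in: anything with state_dict()/load_state_dict()."""

    def __init__(self):
        self.weight = 0.0

    def state_dict(self):
        return {"weight": self.weight}

    def load_state_dict(self, state_dict):
        self.weight = state_dict["weight"]


def add_state_sketch(module, state):
    # Assumption: `state` is a list; store a deep copy so later changes
    # to the module cannot alter the saved snapshot.
    state.append(copy.deepcopy(module.state_dict()))


def sample_state_sketch(module, state, idx=None):
    # Assumption: idx=None means "choose a saved state at random".
    if idx is None:
        idx = random.randrange(len(state))
    module.load_state_dict(state[idx])
    return idx


net, saved = ToyModule(), []
for step in range(3):
    net.weight = float(step)
    add_state_sketch(net, saved)

sample_state_sketch(net, saved, idx=1)
print(net.weight)  # 1.0
```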
-
borch.utils.state_dict.saveable_state_dict(state_dict)
Copies a state_dict, moves the copy to the CPU, and unobserves all variables in it.
- Parameters
state_dict – the state_dict to save
- Returns
a copy of state_dict on the CPU with all variables unobserved
- Return type
state_dict
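The copy-then-move-to-CPU part of saveable_state_dict can be sketched in plain Python as below. This is only an illustration under stated assumptions: FakeTensor is a hypothetical stand-in whose cpu() method mirrors torch.Tensor.cpu(), copy.deepcopy is used for the copy, and the unobserve step is omitted because it depends on borch internals.

```python
import copy


class FakeTensor:
    """Hypothetical stand-in for a tensor that records its device."""

    def __init__(self, data, device):
        self.data, self.device = data, device

    def cpu(self):
        # Mirrors torch.Tensor.cpu(): return a version living on the CPU.
        return FakeTensor(self.data, "cpu")


def saveable_state_dict_sketch(state_dict):
    # 1. Copy first, so the caller's state_dict is never modified.
    copied = copy.deepcopy(state_dict)
    # 2. Move every value that knows how to go to the CPU; keep the rest.
    for key, value in list(copied.items()):
        if hasattr(value, "cpu"):
            copied[key] = value.cpu()
    # 3. borch's real function also unobserves variables; omitted here.
    return copied


sd = {"weight": FakeTensor([1.0], "cuda:0"), "step": 7}
out = saveable_state_dict_sketch(sd)
print(out["weight"].device, sd["weight"].device)  # cpu cuda:0
```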
-
borch.utils.state_dict.state_dict_to_device_(state_dict, device)
Sends all values in a state_dict to the given device, modifying the state_dict in place. When it encounters values that contain parameters, it also moves those parameters to the device. Note that if those parameters have sub-parameters of their own, the function does not recurse into them, so the sub-parameters are left where they are.
- Parameters
state_dict – a state_dict from a borch.nn module
device – the device to move the values to
- Returns
the modified input state_dict
- Return type
state_dict
Examples
>>> import torch
>>> from borch import nn
>>> from borch.utils.state_dict import state_dict_to_device_
>>> net = nn.Sequential(nn.Conv2d(3, 10, 3), nn.Conv2d(10, 11, 3))
>>> state_dict = net.state_dict()
>>> state_dict = state_dict_to_device_(state_dict, torch.device("cpu"))
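The one-level traversal in state_dict_to_device_ can be illustrated with the pure-Python sketch below. It is not borch's implementation and assumes two things the reference does not spell out: every value has a to(device) method (as torch tensors do), and a value "contains parameters" when it exposes a params dict. Value, params and state_dict_to_device_sketch are hypothetical names. The example shows that parameters one level down are moved while their own sub-parameters are not.

```python
class Value:
    """Hypothetical tensor-like value with an optional dict of parameters."""

    def __init__(self, device, params=None):
        self.device = device
        self.params = params if params is not None else {}

    def to(self, device):
        # This stand-in moves itself and returns self; torch's Tensor.to
        # returns a new tensor instead, which is why results are reassigned.
        self.device = device
        return self


def state_dict_to_device_sketch(state_dict, device):
    for key, value in list(state_dict.items()):
        moved = value.to(device)
        # One level down: move the parameters this value holds...
        for name, param in list(moved.params.items()):
            moved.params[name] = param.to(device)
            # ...but do not recurse into param.params (the documented limit).
        state_dict[key] = moved
    return state_dict  # same object, modified in place


inner = Value("cuda:0")
param = Value("cuda:0", params={"inner": inner})
sd = {"layer": Value("cuda:0", params={"p": param})}
result = state_dict_to_device_sketch(sd, "cpu")
print(sd["layer"].device, param.device, inner.device)  # cpu cpu cuda:0
```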