utils.state_dict

Functionality to operate on state_dicts

borch.utils.state_dict.add_state_dict_to_state(module, state)

Save the current state of a module

borch.utils.state_dict.copy_state_dict(state_dict)

Takes a copy of a state_dict by serializing and deserializing it using pickle.

Parameters

state_dict – state_dict to copy

Returns

copy of state_dict

Return type

state_dict
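The pickle round trip described above can be sketched with the standard library alone. This is a minimal illustration, not borch's source. The name copy_state_dict_sketch is hypothetical, and a plain OrderedDict of lists stands in for a real state_dict of tensors.

```python
import pickle
from collections import OrderedDict


def copy_state_dict_sketch(state_dict):
    """Hypothetical helper: copy a state_dict via a pickle round trip."""
    # dumps serializes the whole mapping, nested values included;
    # loads rebuilds fresh objects, so the copy shares no mutable state.
    return pickle.loads(pickle.dumps(state_dict))


original = OrderedDict(weight=[1.0, 2.0], bias=[0.5])
copied = copy_state_dict_sketch(original)
copied["weight"][0] = 99.0  # mutate the copy only
print(original["weight"])  # [1.0, 2.0]
```

Unlike a shallow dict.copy(), the round trip duplicates nested values as well, so edits to the copy never leak back into the original. The requirement is that every value in the state_dict is picklable; copy.deepcopy would give a comparable independent copy.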

borch.utils.state_dict.sample_state(module, state, idx=None)

Restore one of the saved states
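The docstrings for add_state_dict_to_state and sample_state do not say how state is structured. One plausible reading, assumed here and not confirmed by the source, is that state is a list of saved state_dicts and that idx=None means "pick a saved state at random". The sketch below uses hypothetical names (ToyModule, add_state_sketch, sample_state_sketch) and a toy module exposing state_dict()/load_state_dict() in the style of PyTorch modules.

```python
import copy
import random


class ToyModule:
    """Hypothetical stand-in for a module with state_dict()/load_state_dict()."""

    def __init__(self, weight):
        self.weight = weight

    def state_dict(self):
        return {"weight": self.weight}

    def load_state_dict(self, state_dict):
        self.weight = state_dict["weight"]


def add_state_sketch(module, state):
    # Assumption: `state` is a list. Store a deep copy so later changes
    # to the module do not alter the saved snapshot.
    state.append(copy.deepcopy(module.state_dict()))


def sample_state_sketch(module, state, idx=None):
    # Assumption: idx=None selects a saved snapshot uniformly at random.
    if idx is None:
        idx = random.randrange(len(state))
    module.load_state_dict(copy.deepcopy(state[idx]))
    return idx


saved = []
net = ToyModule(weight=[1.0])
add_state_sketch(net, saved)
net.weight = [2.0]
add_state_sketch(net, saved)
sample_state_sketch(net, saved, idx=0)
print(net.weight)  # [1.0]
```

Copying on both save and restore keeps each snapshot independent of the live module, so restoring one state and then training further cannot corrupt the saved list.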

borch.utils.state_dict.saveable_state_dict(state_dict)

Copy a state_dict, move the copy to the CPU, and unobserve all variables in it

Parameters

state_dict – the state_dict to save

Returns

the modified copy of state_dict

Return type

state_dict

borch.utils.state_dict.state_dict_to_device_(state_dict, device)

Send all values in a state_dict to the given device, in place. When it encounters values that contain parameters, it loops through those parameters and moves them to the device as well. Note that if those parameters have further sub-parameters, it does not recurse into them.

Parameters

state_dict – a state_dict from a borch.nn module

device – the device to move the values to

Returns

the modified input state_dict

Return type

state_dict

Examples

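>>> import torch
>>> from borch import nn
>>> from borch.utils.state_dict import saveable_state_dict, state_dict_to_device_
>>> net = nn.Sequential(nn.Conv2d(3, 10, 3), nn.Conv2d(10, 11, 3))
>>> state_dict = state_dict_to_device_(net.state_dict(), torch.device("cpu"))
>>> new_state_dict = saveable_state_dict(state_dict)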
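The one-level rule in state_dict_to_device_ (move every value, move the parameters of values that have them, but go no deeper) can be illustrated without PyTorch. In this sketch FakeTensor, ParamHolder and state_dict_to_device_sketch are hypothetical: FakeTensor.to(device) loosely mimics torch.Tensor.to, and ParamHolder.parameters stands in for whatever parameter container borch's values actually expose.

```python
class FakeTensor:
    """Hypothetical tensor stand-in that only records its device."""

    def __init__(self, device="cpu", sub_parameters=None):
        self.device = device
        # Nested parameters; the one-level rule never visits these.
        self.sub_parameters = sub_parameters if sub_parameters is not None else []

    def to(self, device):
        self.device = device
        return self


class ParamHolder:
    """Hypothetical value that carries its own list of parameters."""

    def __init__(self, parameters):
        self.parameters = parameters
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self


def state_dict_to_device_sketch(state_dict, device):
    for key, value in list(state_dict.items()):
        # Move the value itself and store the result back in place.
        state_dict[key] = value.to(device)
        params = getattr(state_dict[key], "parameters", None)
        if params is not None:
            # One level only: each parameter is moved, but its own
            # sub_parameters stay where they are.
            for i, param in enumerate(params):
                params[i] = param.to(device)
    return state_dict


inner = FakeTensor("cpu")
param = FakeTensor("cpu", sub_parameters=[inner])
sd = {"weight": FakeTensor("cpu"), "rv": ParamHolder([param])}
state_dict_to_device_sketch(sd, "cuda")
print(sd["weight"].device, param.device, inner.device)  # cuda cuda cpu
```

Returning the same dict mirrors the trailing-underscore convention for in-place operations; the untouched inner object shows what "not recursing" into sub-parameters means in practice.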