load_ANN
Loads an object saved with torch.save() or tf.train.Saver() from a file.
load_ANN(net, path)
load_ANN(net, path, **kwargs)
load_ANN(net, path)
Loads the parameters of a trained network from a file on disk. net is an instance of the user-defined network model. path specifies the location of the file, as a relative or absolute path.
load_ANN(net, path, **kwargs)
Additional options that control how the network is loaded are passed through kwargs. For example, source indicates the framework in which the network was built, PyTorch or TensorFlow, and map_location specifies how to remap storage locations in PyTorch.
>>> import torch
# the class `CNN` is user-defined
>>> from ANN import CNN
>>> from OpenHA.core.utils import load_ANN

>>> path = 'path_of_the_target_file.pth'
# instance of this class
>>> net = CNN()
# keyword argument
>>> map_location = torch.device('cpu')
# load
>>> load_ANN(net, path, source='pytorch', map_location=map_location)
net: An instance of the user-defined network model.
path: The location of the file, given as a relative or absolute path.
source: The framework in which the network was built, specified as pytorch (default) or tensorflow.
map_location: Specifies how to remap storage locations in PyTorch; see [1].
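To make the roles of source and map_location concrete, the following is a minimal sketch of how a loader with this interface might dispatch on the framework. It is not the OpenHA implementation: the name load_ann_sketch is hypothetical, and the sketch assumes the file holds a state dictionary written with torch.save(net.state_dict(), path). The TensorFlow branch is deliberately left unimplemented.

```python
def load_ann_sketch(net, path, source="pytorch", **kwargs):
    """Hypothetical stand-in for load_ANN, for illustration only."""
    framework = source.lower()

    if framework == "pytorch":
        # Imported lazily so that the other branches do not require PyTorch.
        import torch

        # map_location is forwarded to torch.load; None keeps the saved devices.
        map_location = kwargs.get("map_location")

        # Assumption: the file contains a state_dict, not a pickled model object.
        state_dict = torch.load(path, map_location=map_location)
        net.load_state_dict(state_dict)
        return net

    if framework == "tensorflow":
        # Restoring a tf.train.Saver checkpoint needs a TF1-style session,
        # which is outside the scope of this sketch.
        raise NotImplementedError("TensorFlow loading is not covered in this sketch")

    raise ValueError(f"unsupported source: {source!r}")
```

Under these assumptions, load_ann_sketch(net, path, source='pytorch', map_location=torch.device('cpu')) would mirror the call in the example above, placing all loaded tensors on the CPU even if they were saved from a GPU.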
[1] https://pytorch.org/docs/stable/generated/torch.load.html