Documentation for nfn.common
nfn.common.NetworkSpec
dataclass
Specifies the shape of each weight and bias. Permutable dimensions can be given as -1, since their exact size does not need to be specified.
get_io()
Returns the input and output dimensions of the network.
get_num_params()
Returns the number of parameters in the network.
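The following is a minimal sketch of the idea behind NetworkSpec, not the library's implementation. ToyNetworkSpec and its fields weight_shapes and bias_shapes are hypothetical names, shapes are plain tuples, and each weight is assumed to be stored as (out, in). Refusing to count parameters while a dimension is -1 is a choice made by this sketch; the source does not say how get_num_params treats -1.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class ToyNetworkSpec:
    """Hypothetical stand-in for NetworkSpec: one shape tuple per weight and bias.

    A -1 marks a permutable (hidden) dimension whose exact size is left open.
    """
    weight_shapes: list  # assumed layout: [(out_0, in_0), (out_1, in_1), ...]
    bias_shapes: list    # assumed layout: [(out_0,), (out_1,), ...]

    def get_io(self):
        # Input size is the column count of the first weight;
        # output size is the row count of the last weight.
        return self.weight_shapes[0][1], self.weight_shapes[-1][0]

    def get_num_params(self):
        shapes = self.weight_shapes + self.bias_shapes
        # Assumption of this sketch: counting is only defined once every dim is known.
        if any(d == -1 for s in shapes for d in s):
            raise ValueError("cannot count parameters while some dims are -1")
        return sum(prod(s) for s in shapes)


# A 2 -> 3 -> 1 MLP with every dimension specified.
full = ToyNetworkSpec([(3, 2), (1, 3)], [(3,), (1,)])
print(full.get_io())          # (2, 1)
print(full.get_num_params())  # 6 + 3 + 3 + 1 = 13

# The same MLP with the hidden width left permutable.
loose = ToyNetworkSpec([(-1, 2), (1, -1)], [(-1,), (1,)])
print(loose.get_io())         # (2, 1)
```

Note that get_io still works on the loose spec, because the network's input and output sizes are never permutable.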
nfn.common.WeightSpaceFeatures
Bases: collections.abc.Sequence
__init__(weights, biases)
The input for an NF-Layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weights | list or tuple | List or tuple of tensors. The first two dimensions of each tensor must be B and C, respectively. The tensors can hold any quantity in the weight space, such as weights, gradients, activations, or sparsity masks. | required |
biases | list or tuple | List or tuple of tensors with the same length as weights. | required |
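A minimal sketch of how a Sequence over layers could behave, assuming (consistent with from_zipped) that each item is a (weight, bias) pair. ToyWSFeatures is a hypothetical stand-in in which shape tuples replace tensors; the (B, C, out, in) weight layout beyond the documented leading B and C dimensions is also an assumption.

```python
from collections.abc import Sequence


class ToyWSFeatures(Sequence):
    """Hypothetical stand-in for WeightSpaceFeatures.

    Each entry of weights/biases would be a tensor whose first two dims
    are (B, C); here plain shape tuples stand in for tensors.
    """

    def __init__(self, weights, biases):
        if len(weights) != len(biases):
            raise ValueError("weights and biases must have the same length")
        self.weights = tuple(weights)
        self.biases = tuple(biases)

    def __len__(self):
        # One item per layer.
        return len(self.weights)

    def __getitem__(self, idx):
        # Indexing yields the (weight, bias) pair for one layer.
        return self.weights[idx], self.biases[idx]


# B=4 networks, C=1 channel, a 2 -> 3 -> 1 MLP:
# weights have shape (B, C, out, in), biases (B, C, out).
feats = ToyWSFeatures([(4, 1, 3, 2), (4, 1, 1, 3)], [(4, 1, 3), (4, 1, 1)])
print(len(feats))   # 2
for w, b in feats:  # Sequence derives __iter__ from __getitem__
    print(w, b)
```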
detach()
Returns a copy with detached tensors.
from_zipped(weight_and_biases)
classmethod
Converts a list of (weight, bias) pairs into a WeightSpaceFeatures object.
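The unzipping step behind from_zipped can be sketched with the built-in zip. The function below is a hypothetical helper that returns the two parallel tuples rather than a WeightSpaceFeatures object, and strings stand in for tensors.

```python
def from_zipped(weight_and_biases):
    """Hypothetical sketch: unzip (weight, bias) pairs into two parallel tuples,
    which a WeightSpaceFeatures-like constructor could then accept."""
    pairs = list(weight_and_biases)
    if not pairs:
        # zip(*[]) yields nothing, so handle the empty case explicitly.
        return (), ()
    weights, biases = zip(*pairs)
    return weights, biases


layers = [("W0", "b0"), ("W1", "b1"), ("W2", "b2")]
weights, biases = from_zipped(layers)
print(weights)  # ('W0', 'W1', 'W2')
print(biases)   # ('b0', 'b1', 'b2')
```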
map(func)
Applies func to each weight and bias tensor.
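A minimal sketch of the map pattern, assuming it returns new collections and leaves the inputs untouched. map_wsfeat is a hypothetical free function operating on plain lists, with numbers standing in for tensors.

```python
def map_wsfeat(func, weights, biases):
    """Hypothetical sketch of map: apply func to every weight and bias,
    returning new lists and leaving the inputs untouched."""
    return [func(w) for w in weights], [func(b) for b in biases]


# Plain numbers stand in for tensors; in practice func might be a
# tensor operation such as scaling or a nonlinearity.
weights, biases = map_wsfeat(lambda t: 2 * t, [1.0, 3.0], [0.5, -1.0])
print(weights)  # [2.0, 6.0]
print(biases)   # [1.0, -2.0]
```

Under this design, detach() and to(device) could plausibly be expressed as map with a per-tensor function such as lambda t: t.detach() in PyTorch; the source does not say whether the library does this.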
to(device)
Moves all tensors to device.
nfn.common.state_dict_to_tensors(state_dict)
Converts a state dict into two equal-length lists, one of weights and the other of biases (or None if no bias is present).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | OrderedDict | State dict to convert. Assumes that keys are ordered according to … | required |
Returns:
Type | Description |
---|---|
list[Tensor] | List of weight tensors. Length is equal to the number of layers. |
list[Tensor] | List of bias tensors (None for layers without a bias). Length is equal to the number of layers. |
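The sketch below illustrates the weight/bias split under one possible key convention, chosen here as an assumption: keys of the form <layer>.weight and <layer>.bias, with each bias directly after its weight. split_state_dict is a hypothetical name, and strings stand in for tensors.

```python
from collections import OrderedDict


def split_state_dict(state_dict):
    """Hypothetical sketch: group an ordered state dict into per-layer weights
    and biases, assuming '<layer>.weight' / '<layer>.bias' keys with each
    bias directly after its weight."""
    weights, biases, prefixes = [], [], []
    for key, value in state_dict.items():
        if key.endswith(".weight"):
            weights.append(value)
            biases.append(None)  # stays None unless a matching bias follows
            prefixes.append(key[: -len(".weight")])
        elif key.endswith(".bias"):
            prefix = key[: -len(".bias")]
            if not prefixes or prefixes[-1] != prefix:
                raise ValueError(f"bias {key!r} does not follow its weight")
            biases[-1] = value
        # Other entries (e.g. buffers) are ignored in this sketch.
    return weights, biases


sd = OrderedDict([
    ("fc1.weight", "W1"), ("fc1.bias", "b1"),
    ("fc2.weight", "W2"),                     # layer without a bias
    ("fc3.weight", "W3"), ("fc3.bias", "b3"),
])
weights, biases = split_state_dict(sd)
print(weights)  # ['W1', 'W2', 'W3']
print(biases)   # ['b1', None, 'b3']
```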
nfn.common.params_to_state_dicts(keys, wsfeat)
Converts a WeightSpaceFeatures object into a list of corresponding state dicts, one for each batch element.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
keys | list[str] | Iterable of key names to use for the state dicts. Assumes that key names are ordered according to … | required |
wsfeat | WeightSpaceFeatures | Weight space features to convert. | required |
Returns:
Type | Description |
---|---|
list[OrderedDict] | List of state dicts, one for each batch element. |
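A minimal sketch of the per-batch-element conversion. to_state_dicts is a hypothetical helper, and it makes assumptions the source does not state: every layer has a bias, keys alternate <layer>.weight and <layer>.bias in layer order, and each tensor is a nested list indexed [batch][channel] with a single channel.

```python
from collections import OrderedDict


def to_state_dicts(keys, weights, biases):
    """Hypothetical sketch of params_to_state_dicts under the assumptions above."""
    batch_size = len(weights[0])
    params = []  # flattened as w0, b0, w1, b1, ... to line up with keys
    for w, b in zip(weights, biases):
        params.extend([w, b])
    if len(params) != len(keys):
        raise ValueError("number of keys must equal number of tensors")
    state_dicts = []
    for i in range(batch_size):
        sd = OrderedDict()
        for key, p in zip(keys, params):
            if len(p[i]) != 1:
                raise ValueError("sketch assumes a single channel (C == 1)")
            sd[key] = p[i][0]  # drop the batch and channel dims
        state_dicts.append(sd)
    return state_dicts


keys = ["fc1.weight", "fc1.bias"]
# B=2 networks, C=1 channel, one layer; the innermost values stand in for tensors.
weights = [[["W_net0"], ["W_net1"]]]
biases = [[["b_net0"], ["b_net1"]]]
sds = to_state_dicts(keys, weights, biases)
print(sds[1]["fc1.weight"])  # W_net1
```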
nfn.common.network_spec_from_wsfeat(wsfeat, set_all_dims=False)
Converts a WeightSpaceFeatures object into a NetworkSpec object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
wsfeat | WeightSpaceFeatures | Weight space features to convert. | required |
set_all_dims | bool | If … | False |
Returns:
Type | Description |
---|---|
NetworkSpec | Output network specification. |
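The sketch below assumes a meaning for set_all_dims consistent with NetworkSpec's use of -1: with set_all_dims=False, hidden (permutable) sizes become -1 and only the network's input and output sizes are kept; with True, every size is kept. spec_from_shapes is hypothetical and works on shape tuples rather than tensors, assuming (B, C, out, in) weight features and (B, C, out) bias features.

```python
def spec_from_shapes(weight_shapes, bias_shapes, set_all_dims=False):
    """Hypothetical sketch of network_spec_from_wsfeat working on shapes only.

    B and C (the first two dims) are dropped. Returns (weight_spec, bias_spec)
    as lists of shape tuples, with -1 marking permutable sizes.
    """
    n = len(weight_shapes)
    w_spec, b_spec = [], []
    for i, (ws, bs) in enumerate(zip(weight_shapes, bias_shapes)):
        out_dim, in_dim = ws[2], ws[3]
        if bs[2] != out_dim:
            raise ValueError(f"layer {i}: bias size does not match weight rows")
        if not set_all_dims:
            if i != 0:
                in_dim = -1   # input to layer i is a hidden layer
            if i != n - 1:
                out_dim = -1  # output of layer i is a hidden layer
        w_spec.append((out_dim, in_dim))
        b_spec.append((out_dim,))
    return w_spec, b_spec


# B=4 networks, C=1 channel, a 2 -> 3 -> 1 MLP.
w_shapes = [(4, 1, 3, 2), (4, 1, 1, 3)]
b_shapes = [(4, 1, 3), (4, 1, 1)]
print(spec_from_shapes(w_shapes, b_shapes))
# ([(-1, 2), (1, -1)], [(-1,), (1,)])
print(spec_from_shapes(w_shapes, b_shapes, set_all_dims=True))
# ([(3, 2), (1, 3)], [(3,), (1,)])
```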