
Documentation for nfn.common

nfn.common.NetworkSpec dataclass

Specifies the shape of each weight and bias. Permutable dimensions (such as hidden-layer sizes) can be given as -1, since their exact size does not need to be specified.

get_io()

Returns the input and output dimensions of the network.

get_num_params()

Returns the number of parameters in the network.
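The following minimal sketch illustrates the -1 convention with a hypothetical ToyNetworkSpec class, which is not the library's NetworkSpec. It assumes PyTorch-style (n_out, n_in) weight shapes. Raising an error from get_num_params when a dimension is -1 is a choice made for this illustration, since a parameter count needs every size.

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class ToyNetworkSpec:
    """Hypothetical stand-in for NetworkSpec: one shape per weight and bias.

    A dimension of -1 marks a permutable (hidden) dimension whose exact size
    is left unspecified.
    """

    weight_shapes: Tuple[Shape, ...]  # assumed (n_out, n_in), as in torch.nn.Linear
    bias_shapes: Tuple[Optional[Shape], ...]  # None for a layer without a bias

    def get_io(self) -> Tuple[int, int]:
        # Input size is n_in of the first layer; output size is n_out of the
        # last layer. Neither is permutable, so both are always known.
        return self.weight_shapes[0][-1], self.weight_shapes[-1][0]

    def get_num_params(self) -> int:
        shapes = list(self.weight_shapes) + [s for s in self.bias_shapes if s is not None]
        if any(d == -1 for s in shapes for d in s):
            raise ValueError("all dimensions must be specified to count parameters")
        return sum(prod(s) for s in shapes)


# A 2-16-3 MLP, first with the hidden width unspecified, then fully specified.
partial = ToyNetworkSpec(weight_shapes=((-1, 2), (3, -1)), bias_shapes=((-1,), (3,)))
full = ToyNetworkSpec(weight_shapes=((16, 2), (3, 16)), bias_shapes=((16,), (3,)))
print(partial.get_io())       # (2, 3)
print(full.get_num_params())  # 16*2 + 3*16 + 16 + 3 = 99
```

The input and output sizes stay fixed in both specs because only hidden neurons can be permuted without changing the network's function.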

nfn.common.WeightSpaceFeatures

Bases: collections.abc.Sequence

__init__(weights, biases)

The input for an NF-Layer.

Parameters:

weights (list or tuple, required): List or tuple of tensors. The first two dimensions of each tensor must be B (batch size) and C (number of channels), respectively. The tensors can hold any quantity in the weight space, such as weights, gradients, activations, or sparsity masks.

biases (list or tuple, required): List or tuple of tensors with the same length as weights. The first two dimensions of each tensor must be B and C, respectively.
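A small sketch of the shape requirement for an MLP with layer widths n_0, ..., n_L. The helper expected_wsfeat_shapes is hypothetical, and the trailing (n_out, n_in) and (n_out,) dimensions are an assumption; only the leading B and C come from the parameter descriptions.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def expected_wsfeat_shapes(
    layer_sizes: List[int], batch: int, channels: int
) -> Tuple[List[Shape], List[Shape]]:
    """Hypothetical helper: tensor shapes for an MLP's weight-space features.

    layer_sizes = [n_0, n_1, ..., n_L]. Each weight is assumed to have shape
    (B, C, n_i, n_{i-1}) and each bias (B, C, n_i).
    """
    weights = [
        (batch, channels, n_out, n_in)
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]
    biases = [(batch, channels, n_out) for n_out in layer_sizes[1:]]
    return weights, biases


w_shapes, b_shapes = expected_wsfeat_shapes([2, 16, 3], batch=8, channels=1)
print(w_shapes)  # [(8, 1, 16, 2), (8, 1, 3, 16)]
print(b_shapes)  # [(8, 1, 16), (8, 1, 3)]
```

Both lists have one entry per layer, which matches the requirement that weights and biases have the same length.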

detach()

Returns a copy with detached tensors.

from_zipped(weight_and_biases) classmethod

Converts a list of (weight, bias) pairs into a WeightSpaceFeatures object.

map(func)

Applies func to each weight and bias tensor.

to(device)

Moves all tensors to device.
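The Sequence interface and the from_zipped and map helpers can be sketched with plain Python numbers standing in for tensors. ToyWSFeat is hypothetical. Its indexing behavior, where each item is the (weight, bias) pair of one layer, is inferred from from_zipped rather than taken from the library. detach and to are omitted because they need real tensors.

```python
from collections.abc import Sequence
from typing import Any, Callable, Iterable, Tuple


class ToyWSFeat(Sequence):
    """Hypothetical stand-in for WeightSpaceFeatures holding arbitrary values."""

    def __init__(self, weights, biases):
        if len(weights) != len(biases):
            raise ValueError("weights and biases must have the same length")
        self.weights = tuple(weights)
        self.biases = tuple(biases)

    # Assumption: indexing yields the (weight, bias) pair of one layer.
    def __getitem__(self, idx):
        return self.weights[idx], self.biases[idx]

    def __len__(self):
        return len(self.weights)

    @classmethod
    def from_zipped(cls, weight_and_biases: Iterable[Tuple[Any, Any]]):
        # Unzip [(w1, b1), (w2, b2), ...] into (w1, w2, ...) and (b1, b2, ...).
        weights, biases = zip(*weight_and_biases)
        return cls(weights, biases)

    def map(self, func: Callable[[Any], Any]):
        # Builds a new object; the original is left unchanged.
        return type(self)([func(w) for w in self.weights], [func(b) for b in self.biases])


feat = ToyWSFeat.from_zipped([(1.0, 0.5), (2.0, -0.5)])
doubled = feat.map(lambda t: 2 * t)
print(len(doubled), doubled[1])  # 2 (4.0, -1.0)
print(list(doubled))             # [(2.0, 1.0), (4.0, -1.0)]
```

Because the class only defines __getitem__ and __len__, iteration, len, and membership tests come from the Sequence mixin.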

nfn.common.state_dict_to_tensors(state_dict)

Converts a state dict into two equal-length lists, one of weights and one of biases. An entry in the bias list is None if that layer has no bias.

Parameters:

state_dict (OrderedDict, required): State dict to convert. Assumes that keys are ordered according to [weight1, bias1, ..., weightL, biasL].

Returns:

list[Tensor]: List of weight tensors. Length is equal to the number of layers.

list[Tensor]: List of bias tensors (None if no bias).
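A minimal sketch of the splitting logic, assuming the documented key order and that a layer without a bias simply has no bias key. split_state_dict is a hypothetical name, strings stand in for tensors, and rejecting keys that end in neither weight nor bias is a choice of this sketch (real state dicts may also contain buffers).

```python
from collections import OrderedDict
from typing import Any, List, Mapping, Optional, Tuple


def split_state_dict(state_dict: Mapping[str, Any]) -> Tuple[List[Any], List[Optional[Any]]]:
    """Hypothetical sketch: split an ordered state dict into weights and biases.

    Assumes keys follow [weight1, bias1, ..., weightL, biasL], with the bias
    key absent for layers that have no bias.
    """
    weights: List[Any] = []
    biases: List[Optional[Any]] = []
    for key, value in state_dict.items():
        if key.endswith("weight"):
            weights.append(value)
            biases.append(None)  # placeholder until this layer's bias is seen
        elif key.endswith("bias"):
            if not weights or biases[-1] is not None:
                raise ValueError(f"bias key {key!r} does not follow a weight key")
            biases[-1] = value
        else:
            raise ValueError(f"unexpected key {key!r}")
    return weights, biases


sd = OrderedDict([("0.weight", "W0"), ("0.bias", "b0"), ("2.weight", "W2")])
weights, biases = split_state_dict(sd)
print(weights, biases)  # ['W0', 'W2'] ['b0', None]
```

Appending a None placeholder for every weight keeps both lists the same length even when a layer has no bias.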

nfn.common.params_to_state_dicts(keys, wsfeat)

Converts a WeightSpaceFeatures object into a list of corresponding state dicts, one for each batch element.

Parameters:

keys (list[str], required): Iterable of key names to use for the state dicts. Assumes that key names are ordered according to [weight1, bias1, ..., weightL, biasL].

wsfeat (WeightSpaceFeatures, required): Weight space features to convert.

Returns:

list[OrderedDict]: List of state dicts, one for each batch element.
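The inverse direction can be sketched the same way. to_state_dicts is hypothetical: a plain list over batch elements stands in for the B dimension, the C dimension is omitted, and keys are assumed to leave out bias names for layers without a bias.

```python
from collections import OrderedDict
from typing import Any, List, Optional, Sequence


def to_state_dicts(
    keys: Sequence[str],
    weights: List[List[Any]],
    biases: List[Optional[List[Any]]],
) -> List["OrderedDict[str, Any]"]:
    """Hypothetical sketch: build one state dict per batch element.

    weights[i][b] is layer i's weight for batch element b; biases[i] is None
    for a layer without a bias.
    """
    batch_size = len(weights[0])
    # Flatten into [weight1, bias1, ..., weightL, biasL] order, skipping missing biases.
    params = []
    for w, bias in zip(weights, biases):
        params.append(w)
        if bias is not None:
            params.append(bias)
    if len(params) != len(keys):
        raise ValueError("number of keys does not match number of parameters")
    state_dicts = [OrderedDict() for _ in range(batch_size)]
    for key, batched in zip(keys, params):
        for b_idx in range(batch_size):
            state_dicts[b_idx][key] = batched[b_idx]
    return state_dicts


sds = to_state_dicts(
    ["0.weight", "0.bias", "2.weight"],
    weights=[["W0_a", "W0_b"], ["W2_a", "W2_b"]],
    biases=[["b0_a", "b0_b"], None],
)
print(list(sds[1].items()))  # [('0.weight', 'W0_b'), ('0.bias', 'b0_b'), ('2.weight', 'W2_b')]
```

Each batch element gets its own dict, so the result can be loaded into a separate network per element.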

nfn.common.network_spec_from_wsfeat(wsfeat, set_all_dims=False)

Converts a WeightSpaceFeatures object into a NetworkSpec object.

Parameters:

wsfeat (WeightSpaceFeatures, required): Weight space features to convert.

set_all_dims (bool, default False): If True, all dimensions in the output NetworkSpec are specified (i.e., the output NetworkSpec does not contain -1).

Returns:

NetworkSpec: Output network specification.
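Below is a rough sketch of how set_all_dims might affect the output. It assumes (B, C, n_out, n_in) weight shapes and treats hidden-layer sizes as the permutable dimensions. spec_from_shapes is hypothetical and returns plain shape lists rather than a NetworkSpec.

```python
from typing import List, Optional, Tuple

Shape = Tuple[int, ...]


def spec_from_shapes(
    weight_shapes: List[Shape],
    bias_shapes: List[Optional[Shape]],
    set_all_dims: bool = False,
) -> Tuple[List[Shape], List[Optional[Shape]]]:
    """Hypothetical sketch of deriving a network spec from feature shapes.

    The leading B and C dimensions are dropped. Hidden-layer sizes become -1
    unless set_all_dims is True; input and output sizes are always kept.
    """
    last = len(weight_shapes) - 1

    def mask(size: int, permutable: bool) -> int:
        return -1 if permutable and not set_all_dims else size

    w_spec: List[Shape] = []
    for i, shape in enumerate(weight_shapes):
        n_out, n_in = shape[2:]
        # n_out is hidden unless this is the last layer; n_in unless the first.
        w_spec.append((mask(n_out, i < last), mask(n_in, i > 0)))
    b_spec: List[Optional[Shape]] = [
        None if shape is None else (mask(shape[2], i < last),)
        for i, shape in enumerate(bias_shapes)
    ]
    return w_spec, b_spec


shapes_w = [(8, 1, 16, 2), (8, 1, 3, 16)]
shapes_b = [(8, 1, 16), (8, 1, 3)]
print(spec_from_shapes(shapes_w, shapes_b))                     # ([(-1, 2), (3, -1)], [(-1,), (3,)])
print(spec_from_shapes(shapes_w, shapes_b, set_all_dims=True))  # ([(16, 2), (3, 16)], [(16,), (3,)])
```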