
Documentation for nfn.layers

nfn.layers.NPPool

Bases: nn.Module

__init__(network_spec, agg='mean')

Parameters:

Name Type Description Default
network_spec NetworkSpec

Network specification.

required
agg str

Type of pooling to perform. One of "mean", "max", or "sum". Default: "mean".

'mean'

forward(wsfeat)

Applies the pooling operation to the input weight space features across any axis that has permutation symmetry. The pooling operation is invariant to $\mathcal{S}$, the neuron permutation (NP) group. See Equation 5 for a complete description.

Parameters:

Name Type Description Default
wsfeat WeightSpaceFeatures

Input weight space features.

required

Returns:

Type Description
torch.Tensor

Output of the pooling operation, with shape $(B, C, N)$, where $B$ is the batch size, $C$ is the number of channels, and $N$ is the number of outputs of the global pooling layer.
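The NumPy sketch below illustrates the idea behind NP-invariant pooling for a plain MLP. It is not the nfn implementation. It assumes each weight is an array of shape (B, C, n_out, n_in) and each bias has shape (B, C, n_out), it ignores convolutional filter dimensions, and the helper name `np_pool` is hypothetical. Because every neuron axis is permutable under $\mathcal{S}$, each weight and bias is pooled over all of its neuron axes, which in this sketch gives $N = 2L$ outputs for an $L$-layer network.

```python
import numpy as np

# Aggregations selectable through the `agg` argument.
AGGS = {"mean": np.mean, "max": np.max, "sum": np.sum}


def np_pool(weights, biases, agg="mean"):
    """Hypothetical sketch of NP-invariant pooling for an MLP.

    weights[i]: array of shape (B, C, n_out_i, n_in_i)
    biases[i]:  array of shape (B, C, n_out_i)
    Returns an array of shape (B, C, 2 * num_layers).
    """
    fn = AGGS[agg]
    outs = []
    for w, b in zip(weights, biases):
        # Under NP every neuron axis is permutable, so pool over all of them.
        outs.append(fn(w, axis=(-2, -1)))  # (B, C)
        outs.append(fn(b, axis=-1))        # (B, C)
    return np.stack(outs, axis=-1)


# Example: a 3-layer MLP with layer widths 3 -> 5 -> 4 -> 2.
rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]
B, C = 2, 6
weights = [rng.normal(size=(B, C, sizes[i + 1], sizes[i])) for i in range(3)]
biases = [rng.normal(size=(B, C, sizes[i + 1])) for i in range(3)]
out = np_pool(weights, biases)
print(out.shape)  # (2, 6, 6)
```

Permuting the neurons of any layer in the inputs leaves `out` unchanged, which is the invariance this layer is designed to provide.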

get_num_outs(network_spec) staticmethod

Returns the number of outputs of the global pooling layer.

nfn.layers.HNPPool

Bases: nn.Module

__init__(network_spec, agg='mean')

Parameters:

Name Type Description Default
network_spec NetworkSpec

Network specification.

required
agg str

Type of pooling to perform. One of "mean", "max", or "sum". Default: "mean".

'mean'

forward(wsfeat)

Applies a pooling operation to the input weight space features across any axis that has permutation symmetry. The pooling operation is invariant to $\tilde{\mathcal{S}}$, the hidden neuron permutation (HNP) group. See Equation 20 for a complete description.

Parameters:

Name Type Description Default
wsfeat WeightSpaceFeatures

Input weight space features.

required

Returns:

Type Description
torch.Tensor

Output tensor with shape $(B, C, N)$, where $B$ is the batch size, $C$ is the number of channels, and $N$ is the number of outputs of the global pooling layer.
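Under $\tilde{\mathcal{S}}$ only hidden neurons are permutable, so the input and output neuron axes must survive pooling. The NumPy sketch below shows one way to do that for a plain MLP, assuming weights of shape (B, C, n_out, n_in), biases of shape (B, C, n_out), and no filter dimensions. The names `hnp_pool` and `hnp_num_outs` are hypothetical, and nfn may order or group its outputs differently. In this sketch the first weight keeps its $n_0$ input columns, the last weight and last bias keep their $n_L$ output rows, and every other tensor is pooled to a single value per channel.

```python
import numpy as np

AGGS = {"mean": np.mean, "max": np.max, "sum": np.sum}


def hnp_pool(weights, biases, agg="mean"):
    """Hypothetical sketch of HNP-invariant pooling for an MLP.

    weights[i]: (B, C, n_out_i, n_in_i); biases[i]: (B, C, n_out_i).
    Only axes indexed by hidden neurons are pooled.
    """
    fn = AGGS[agg]
    num_layers = len(weights)
    outs = []
    for i, (w, b) in enumerate(zip(weights, biases)):
        is_first, is_last = i == 0, i == num_layers - 1
        axes = []
        if not is_last:
            axes.append(-2)  # rows index hidden neurons unless this is the output layer
        if not is_first:
            axes.append(-1)  # columns index hidden neurons unless this is the input layer
        pooled_w = fn(w, axis=tuple(axes)) if axes else w
        outs.append(pooled_w.reshape(*w.shape[:2], -1))
        # Output-layer biases are indexed by output neurons, which are not permutable.
        pooled_b = b if is_last else fn(b, axis=-1, keepdims=True)
        outs.append(pooled_b.reshape(*b.shape[:2], -1))
    return np.concatenate(outs, axis=-1)


def hnp_num_outs(sizes):
    """Output count N of hnp_pool for layer widths sizes = [n_0, ..., n_L]."""
    num_layers = len(sizes) - 1
    if num_layers == 1:
        # No hidden neurons, so nothing is pooled.
        return sizes[0] * sizes[1] + sizes[1]
    return sizes[0] + 2 * sizes[-1] + 2 * num_layers - 3


rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]
B, C = 2, 6
weights = [rng.normal(size=(B, C, sizes[i + 1], sizes[i])) for i in range(3)]
biases = [rng.normal(size=(B, C, sizes[i + 1])) for i in range(3)]
out = hnp_pool(weights, biases)
print(out.shape, hnp_num_outs(sizes))  # (2, 6, 10) 10
```

Permuting a hidden layer's neurons leaves `out` unchanged, while permuting the input columns of the first weight reorders the first $n_0$ entries, reflecting that input neurons carry no symmetry under HNP.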

get_num_outs(network_spec) staticmethod

Returns the number of outputs of the global pooling layer.

nfn.layers.Pointwise

Bases: nn.Module

__init__(network_spec, in_channels, out_channels)

Parameters:

Name Type Description Default
network_spec NetworkSpec

Network specification.

required
in_channels int

Number of input channels of weight space features.

required
out_channels int

Number of output channels of weight space features.

required

forward(wsfeat)

Applies a linear NF-Layer to input weight space features. The layer assumes full row and column exchangeability of the weight space features in each layer and ignores interactions between different weights and biases: only the last term of Equation 3 is used in constructing this layer.

Parameters:

Name Type Description Default
wsfeat WeightSpaceFeatures

Input weight space features, where each weight and bias has $C_{in}$ channels.

required

Returns:

Type Description
WeightSpaceFeatures

Output weight space features, where each weight and bias has $C_{out}$ channels.
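As a rough illustration, a pointwise layer can be thought of as a learned channel-mixing matrix applied independently at every weight and bias entry. The NumPy sketch below assumes one $(C_{out}, C_{in})$ matrix per weight tensor and per bias tensor and omits any additive bias term. The names `pointwise`, `w_params`, and `b_params` are hypothetical, not nfn's parameters.

```python
import numpy as np


def pointwise(weights, biases, w_params, b_params):
    """Hypothetical pointwise layer: mixes channels, never mixes entries.

    weights[i]: (B, C_in, n_out_i, n_in_i); w_params[i]: (C_out, C_in)
    biases[i]:  (B, C_in, n_out_i);         b_params[i]: (C_out, C_in)
    """
    new_w = [np.einsum("dc,bcij->bdij", P, w) for P, w in zip(w_params, weights)]
    new_b = [np.einsum("dc,bci->bdi", P, b) for P, b in zip(b_params, biases)]
    return new_w, new_b


rng = np.random.default_rng(0)
sizes = [3, 5, 2]  # layer widths n_0 -> n_1 -> n_2
B, C_in, C_out = 2, 6, 4
weights = [rng.normal(size=(B, C_in, sizes[i + 1], sizes[i])) for i in range(2)]
biases = [rng.normal(size=(B, C_in, sizes[i + 1])) for i in range(2)]
w_params = [rng.normal(size=(C_out, C_in)) for _ in range(2)]
b_params = [rng.normal(size=(C_out, C_in)) for _ in range(2)]
new_w, new_b = pointwise(weights, biases, w_params, b_params)
print(new_w[0].shape, new_b[1].shape)  # (2, 4, 5, 3) (2, 4, 2)
```

Because each entry is transformed independently, permuting any row or column of the inputs permutes the outputs in exactly the same way, so the layer is equivariant by construction.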

nfn.layers.NPLinear

Bases: nn.Module

__init__(network_spec, in_channels, out_channels, io_embed=False)

Parameters:

Name Type Description Default
network_spec NetworkSpec

Network specification.

required
in_channels int

Number of input channels of weight space features.

required
out_channels int

Number of output channels of weight space features.

required

forward(wsfeat)

Applies a linear NF-Layer to input weight space features. The layer is equivariant to S\mathcal{S}, the neuron permutation (NP) group. See Equation 3 for a complete description.

Parameters:

Name Type Description Default
wsfeat WeightSpaceFeatures

Input weight space features, where each weight and bias has $C_{in}$ channels.

required

Returns:

Type Description
WeightSpaceFeatures

Output weight space features, where each weight and bias has $C_{out}$ channels.
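Equation 3 combines several terms, including ones that couple neighboring layers and biases. The NumPy sketch below is deliberately simplified to a single weight matrix and keeps only the terms that treat rows and columns as exchangeable: the entry itself, its row mean, its column mean, and the global mean, each followed by channel mixing. This is the exchangeable matrix layer of Hartford et al. (2018) rather than the full NP layer, and `mix`, `exchangeable_matrix_layer`, and the parameter keys are hypothetical names.

```python
import numpy as np


def mix(P, x):
    """Apply a (C_out, C_in) channel-mixing matrix to x of shape (B, C_in, r, c)."""
    return np.einsum("dc,bcij->bdij", P, x)


def exchangeable_matrix_layer(w, params):
    """w: (B, C_in, n_out, n_in); params: four (C_out, C_in) matrices."""
    row_mean = w.mean(axis=-1, keepdims=True)        # (B, C_in, n_out, 1)
    col_mean = w.mean(axis=-2, keepdims=True)        # (B, C_in, 1, n_in)
    all_mean = w.mean(axis=(-2, -1), keepdims=True)  # (B, C_in, 1, 1)
    # Broadcasting spreads each pooled term back over the full matrix.
    return (mix(params["self"], w) + mix(params["row"], row_mean)
            + mix(params["col"], col_mean) + mix(params["all"], all_mean))


rng = np.random.default_rng(0)
B, C_in, C_out, n_out, n_in = 2, 3, 4, 5, 6
w = rng.normal(size=(B, C_in, n_out, n_in))
params = {k: rng.normal(size=(C_out, C_in)) for k in ("self", "row", "col", "all")}
out = exchangeable_matrix_layer(w, params)

# Permuting rows and columns of the input permutes the output identically.
p, q = rng.permutation(n_out), rng.permutation(n_in)
permuted = exchangeable_matrix_layer(w[:, :, p][:, :, :, q], params)
print(np.allclose(permuted, out[:, :, p][:, :, :, q]))  # True
```

Setting the `row`, `col`, and `all` matrices to zero leaves only the per-entry term, which roughly corresponds to what `Pointwise` keeps.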

nfn.layers.HNPLinear

Bases: nn.Module

__init__(network_spec, in_channels, out_channels)

Parameters:

Name Type Description Default
network_spec NetworkSpec

Network specification.

required
in_channels int

Number of input channels of weight space features.

required
out_channels int

Number of output channels of weight space features.

required

forward(wsfeat)

Applies a linear NF-Layer to input weight space features. The layer is equivariant to S~\mathcal{\tilde{S}}, the hidden neuron permutation (HNP) group. See Appendix C for a complete description.

Parameters:

Name Type Description Default
wsfeat WeightSpaceFeatures

Input weight space features, where each weight and bias has $C_{in}$ channels.

required

Returns:

Type Description
WeightSpaceFeatures

Output weight space features, where each weight and bias has $C_{out}$ channels.
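The difference between $\mathcal{S}$ and $\tilde{\mathcal{S}}$ comes down to which neurons may be permuted. The NumPy sketch below shows the hidden-neuron action that $\tilde{\mathcal{S}}$ describes: permuting a hidden layer's neurons (the rows of its incoming weight and bias and the columns of the outgoing weight) leaves the MLP's function unchanged, whereas input and output neurons have no such freedom. The helpers `mlp` and `permute_hidden` are hypothetical and operate on unbatched, single-channel weights.

```python
import numpy as np


def mlp(x, weights, biases):
    """x: (n_0,); weights[i]: (n_{i+1}, n_i); biases[i]: (n_{i+1},)."""
    h = x
    for i, (W, b) in enumerate(zip(weights, biases)):
        h = W @ h + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)  # ReLU on hidden layers only
    return h


def permute_hidden(weights, biases, perms):
    """Apply an HNP group element: perms[k] permutes the neurons of hidden layer k + 1."""
    W = [w.copy() for w in weights]
    b = [v.copy() for v in biases]
    for k, p in enumerate(perms):
        # Hidden layer k + 1 is the output side of W[k] and the input side of W[k + 1].
        W[k] = W[k][p, :]
        b[k] = b[k][p]
        W[k + 1] = W[k + 1][:, p]
    return W, b


rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]
weights = [rng.normal(size=(sizes[i + 1], sizes[i])) for i in range(3)]
biases = [rng.normal(size=sizes[i + 1]) for i in range(3)]
perms = [rng.permutation(n) for n in sizes[1:-1]]  # one permutation per hidden layer
pw, pb = permute_hidden(weights, biases, perms)
x = rng.normal(size=sizes[0])
print(np.allclose(mlp(x, weights, biases), mlp(x, pw, pb)))  # True
```

An HNP-equivariant layer maps permuted weights `(pw, pb)` to correspondingly permuted outputs, so it respects exactly this functional symmetry without assuming any symmetry of the input or output neurons.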