Neural Functional Networks (NFNs)
Paper (arXiv)
This is a library of PyTorch layers for building permutation equivariant neural functional networks (NFNs). NFNs are equivariant deep learning architectures for processing weight space features, such as the weights or gradients of another neural network. We refer to the layers of an NFN as NF-Layers.
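To make the symmetry concrete, here is a minimal sketch in plain PyTorch (it does not use this library) showing why permutation equivariance is a natural requirement: reordering the hidden units of an MLP, and permuting the adjacent weights consistently, leaves the function it computes unchanged. The network sizes and variable names are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A small MLP: 4 inputs -> 8 hidden units -> 3 outputs.
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# Build a second network whose hidden units are a reordering of the first.
permuted = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
perm = torch.randperm(8)
with torch.no_grad():
    permuted[0].weight.copy_(mlp[0].weight[perm])     # permute rows of W1
    permuted[0].bias.copy_(mlp[0].bias[perm])         # permute b1 the same way
    permuted[2].weight.copy_(mlp[2].weight[:, perm])  # permute columns of W2
    permuted[2].bias.copy_(mlp[2].bias)               # output bias is unchanged

# Both weight settings compute the same function on any input.
x = torch.randn(5, 4)
same_function = torch.allclose(mlp(x), permuted(x), atol=1e-6)
```

Because both weight settings represent the same function, a model that processes weights should treat them consistently, which is the property NF-Layers are designed to provide.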
Installation
Simple installation from PyPI:
If you want to edit source or run examples, clone the repository locally. Then run the following commands:
Usage
Loading weights as input
NF-Layers operate on WeightSpaceFeatures. The current NF-Layers are compatible with the weight spaces of simple feedforward MLPs and 2D (image) CNNs. For weight spaces of CNN classifiers we assume there is some global pooling layer (e.g., nn.AdaptiveAvgPool2d(1)) between the convolution and FC layers. Supporting 1D or 3D CNNs should be possible but is not currently implemented.
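As an illustration of the CNN layout described above, the following sketch builds a small classifier with a global pooling layer between the convolution and FC layers. The channel counts, kernel sizes, and input resolution are arbitrary assumptions, not requirements of the library.

```python
import torch
from torch import nn

cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # global pooling: (B, 32, H, W) -> (B, 32, 1, 1)
    nn.Flatten(),             # (B, 32, 1, 1) -> (B, 32)
    nn.Linear(32, 10),
)

logits = cnn(torch.randn(2, 3, 28, 28))

# Parameter names and shapes, in state_dict order.
param_shapes = {name: tuple(p.shape) for name, p in cnn.state_dict().items()}
```

With global pooling, the first FC layer has one input per convolutional channel (here 32), so each input column of the FC weight matches exactly one channel of the last convolution.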
To construct WeightSpaceFeatures from the weights of a PyTorch model, we provide the helper function state_dict_to_tensors():
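from torch.utils.data import default_collate
# WeightSpaceFeatures is assumed to be importable from nfn.common, like the other helpers used here
from nfn.common import state_dict_to_tensors, WeightSpaceFeatures

models = [...]  # batch of PyTorch models
state_dicts = [m.state_dict() for m in models]
wts_and_bs = [state_dict_to_tensors(sd) for sd in state_dicts]
# Collate batch. Can be done automatically by DataLoader.
wts_and_bs = default_collate(wts_and_bs)
wsfeat = WeightSpaceFeatures(*wts_and_bs)
out = nfn(wsfeat)  # nfn is an NFN model (see "Building NFNs"); it can now ingest WeightSpaceFeatures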
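The collation step can be sketched in isolation with plain PyTorch. The example below assumes each per-model sample is a (weights, biases) pair of tuples of tensors; fake_sample is a hypothetical stand-in and the exact structure returned by state_dict_to_tensors() may differ (for example, it may include extra dimensions).

```python
import torch
from torch.utils.data import default_collate

def fake_sample():
    # Hypothetical stand-in for one model's (weights, biases) pair, for a 4 -> 8 -> 3 MLP.
    weights = (torch.randn(8, 4), torch.randn(3, 8))
    biases = (torch.randn(8), torch.randn(3))
    return weights, biases

# A "batch" of 5 models with identical architecture.
batch = [fake_sample() for _ in range(5)]

# default_collate recurses into the nested tuples and stacks matching tensors
# along a new leading batch dimension.
weights, biases = default_collate(batch)
weight_shapes = [tuple(w.shape) for w in weights]
bias_shapes = [tuple(b.shape) for b in biases]
```

Each layer's tensor gains a leading batch dimension of size 5, which is why all models in a batch must share the same architecture.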
state_dict_to_tensors() assumes that the state_dict is an ordered dictionary with keys in the order [weight1, bias1, ..., weightL, biasL]. This is the default behavior if the state_dict comes from a feedforward network that is an nn.Sequential model.
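The key ordering can be checked before conversion. The sketch below shows the keys produced by an nn.Sequential MLP and a small helper, alternates_weight_bias, which is a hypothetical name introduced here for illustration and is not part of the library.

```python
from torch import nn

mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
keys = list(mlp.state_dict().keys())

def alternates_weight_bias(state_dict):
    """Return True if keys follow [weight1, bias1, ..., weightL, biasL] (hypothetical helper)."""
    names = list(state_dict.keys())
    if len(names) % 2 != 0:
        return False
    for w, b in zip(names[0::2], names[1::2]):
        w_prefix, _, w_kind = w.rpartition(".")
        b_prefix, _, b_kind = b.rpartition(".")
        # Each pair must be the weight and bias of the same layer.
        if w_kind != "weight" or b_kind != "bias" or w_prefix != b_prefix:
            return False
    return True

ok_sequential = alternates_weight_bias(mlp.state_dict())

# A layer without a bias breaks the alternating pattern.
no_bias = nn.Sequential(nn.Linear(4, 8, bias=False), nn.ReLU(), nn.Linear(8, 3))
ok_no_bias = alternates_weight_bias(no_bias.state_dict())
```

Models that store parameters in a different order, or that omit biases, would need their state_dict reordered or padded before conversion.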
Building NFNs
The NF-Layers are found in nfn.layers. The main data you need to build an NFN is a network_spec, which specifies the structure of the weight space you plan to process. If you already have a WeightSpaceFeatures object as above, you can use network_spec_from_wsfeat.
from torch import nn
from nfn import layers
from nfn.common import network_spec_from_wsfeat

network_spec = network_spec_from_wsfeat(wsfeat)
nfn_channels = 32

# io_embed: encode the input and output dimensions of the weight space feature
nfn = nn.Sequential(
    layers.NPLinear(network_spec, 1, nfn_channels, io_embed=True),
    layers.TupleOp(nn.ReLU()),
    layers.NPLinear(network_spec, nfn_channels, nfn_channels, io_embed=True),
    layers.TupleOp(nn.ReLU()),
    layers.HNPPool(network_spec),  # pooling layer, for invariance
    nn.Flatten(start_dim=-2),
    nn.Linear(nfn_channels * layers.HNPPool.get_num_outs(network_spec), 1)
)