Source code for lagom.networks.module

import torch
import torch.nn as nn
from torch.nn.utils import vector_to_parameters
from torch.nn.utils import parameters_to_vector


class Module(nn.Module):
    r"""Wrap PyTorch nn.Module to provide more helper functions.

    Every keyword argument passed to the constructor is stored on the instance
    as an attribute of the same name.

    Args:
        **kwargs: arbitrary attributes to attach to the module.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Assignment is routed through nn.Module's attribute handling, so the
        # parent constructor must have run before any attribute is set.
        for key, val in kwargs.items():
            setattr(self, key, val)

    @property
    def num_params(self):
        r"""Returns the total number of parameters in the neural network."""
        return sum(param.numel() for param in self.parameters())

    @property
    def num_trainable_params(self):
        r"""Returns the total number of trainable parameters in the neural network.

        A parameter counts as trainable when its ``requires_grad`` flag is set.
        """
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    @property
    def num_untrainable_params(self):
        r"""Returns the total number of untrainable parameters in the neural network.

        A parameter counts as untrainable when its ``requires_grad`` flag is not set.
        """
        return sum(param.numel() for param in self.parameters() if not param.requires_grad)

    def to_vec(self):
        r"""Returns the network parameters as a single flattened vector.

        Returns:
            Tensor: the result of ``parameters_to_vector`` applied to
            ``self.parameters()``.
        """
        return parameters_to_vector(parameters=self.parameters())

    def from_vec(self, x):
        r"""Set the network parameters from a single flattened vector.

        Args:
            x (Tensor): A single flattened vector of the network parameters with consistent size.
        """
        vector_to_parameters(vec=x, parameters=self.parameters())

    def save(self, f):
        r"""Save the network parameters to a file.

        It complies with the `recommended approach for saving a model in PyTorch documentation`_.

        .. note::

            It uses the highest pickle protocol to serialize the network parameters.

        Args:
            f (str): file path.

        .. _recommended approach for saving a model in PyTorch documentation:
            https://pytorch.org/docs/master/notes/serialization.html#best-practices
        """
        import pickle

        # Only the state dict is written, not the module object itself.
        torch.save(obj=self.state_dict(), f=f, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, f, map_location=None):
        r"""Load the network parameters from a file.

        It complies with the `recommended approach for saving a model in PyTorch documentation`_.

        .. warning::

            ``torch.load`` deserializes with pickle; only load files that come
            from a trusted source.

        Args:
            f (str): file path.
            map_location (optional): forwarded unchanged to ``torch.load`` to remap
                where the stored tensors are placed while deserializing (e.g. ``'cpu'``
                for a file written on another device). Default: ``None``, which keeps
                ``torch.load``'s own default behavior.

        .. _recommended approach for saving a model in PyTorch documentation:
            https://pytorch.org/docs/master/notes/serialization.html#best-practices
        """
        # NOTE(review): newer PyTorch releases restrict unpickling in torch.load
        # by default; confirm files written by `save` (highest pickle protocol)
        # are still accepted on the PyTorch versions this project supports.
        state_dict = torch.load(f, map_location=map_location)
        self.load_state_dict(state_dict)