Source code for lagom.networks.init

import torch.nn as nn

def ortho_init(module, nonlinearity=None, weight_scale=1.0, constant_bias=0.0):
    r"""Applies orthogonal initialization for the parameters of a given module.

    Weights are initialized in-place with :func:`nn.init.orthogonal_` and
    biases are filled in-place with a constant value.

    Args:
        module (nn.Module): A module to apply orthogonal initialization over
            its parameters.
        nonlinearity (str, optional): Nonlinearity followed by forward pass of
            the module. When nonlinearity is not ``None``, the gain will be
            calculated with :func:`nn.init.calculate_gain` and
            :attr:`weight_scale` will be ignored. Default: ``None``
        weight_scale (float, optional): Scaling factor to initialize the
            weight. Ignored when :attr:`nonlinearity` is not ``None``.
            Default: 1.0
        constant_bias (float, optional): Constant value to initialize the
            bias. Default: 0.0

    .. note::

        Currently, the only supported :attr:`module` are elementary neural
        network layers, e.g. nn.Linear, nn.Conv2d, nn.LSTM. The submodules
        are not supported.

    .. note::

        A module without a bias (e.g. one created with ``bias=False``) only
        has its weight initialized; the bias step is skipped.

    Example::

        >>> a = nn.Linear(2, 3)
        >>> ortho_init(a)

    """
    if nonlinearity is not None:
        gain = nn.init.calculate_gain(nonlinearity)
    else:
        gain = weight_scale

    if isinstance(module, (nn.RNNBase, nn.RNNCellBase)):
        # Recurrent modules hold several parameters whose names contain
        # 'weight_' or 'bias_', so dispatch on the parameter name.
        for name, param in module.named_parameters():
            if 'weight_' in name:
                nn.init.orthogonal_(param, gain=gain)
            elif 'bias_' in name:
                nn.init.constant_(param, constant_bias)
    else:  # other modules with single .weight and .bias
        nn.init.orthogonal_(module.weight, gain=gain)
        # The bias is None for layers built with bias=False; there is
        # nothing to fill in that case, so skip it instead of crashing.
        bias = getattr(module, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, constant_bias)