Source code for lagom.networks.ln_rnn

from typing import List, Tuple

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
import torch.jit as jit


class LayerNormLSTMCell(jit.ScriptModule):
    r"""An LSTM cell with layer normalization applied to the input, hidden and cell activations."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.randn(4*hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4*hidden_size, hidden_size))
        # The layernorms provide learnable biases
        self.layernorm_i = nn.LayerNorm(4*hidden_size)
        self.layernorm_h = nn.LayerNorm(4*hidden_size)
        self.layernorm_c = nn.LayerNorm(hidden_size)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        cy = self.layernorm_c(forgetgate*cx + ingate*cellgate)
        hy = outgate*torch.tanh(cy)
        return hy, (hy, cy)
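
# A minimal usage sketch of LayerNormLSTMCell (illustrative, not part of the
# original module): the cell processes a single time step for a whole batch and
# returns both the output and the updated (h, c) state. All sizes below are
# assumptions chosen for demonstration.
def _demo_cell():
    cell = LayerNormLSTMCell(input_size=3, hidden_size=8)
    x = torch.randn(5, 3)  # one time step for a batch of 5
    state = (torch.zeros(5, 8), torch.zeros(5, 8))  # (h, c), each of shape (batch, hidden)
    out, (h, c) = cell(x, state)
    assert out.shape == (5, 8)
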
class LSTMLayer(jit.ScriptModule):
    r"""Applies a single recurrent cell over a full sequence, one time step at a time."""
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.export
    def forward_packed(self, input, state):
        # type: (PackedSequence, Tuple[Tensor, Tensor]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]
        data, batch_sizes, sorted_indices, unsorted_indices = input
        state = (state[0].index_select(0, sorted_indices),
                 state[1].index_select(0, sorted_indices))
        outputs = []
        for batch_size, x in zip(batch_sizes, data.split(batch_sizes.numpy().tolist(), dim=0)):
            assert batch_size == x.shape[0]
            # Sequences are sorted by decreasing length, so only the first
            # `batch_size` states belong to sequences still being processed.
            state = (state[0][:batch_size, ...], state[1][:batch_size, ...])
            out, state = self.cell(x, state)
            outputs += [out]
        outputs = PackedSequence(torch.cat(outputs, 0), batch_sizes, sorted_indices, unsorted_indices)
        return outputs, state

    @jit.export
    def forward_tensor(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        inputs = input.unbind(0)
        outputs = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state

    @jit.ignore
    def forward(self, input, state):
        # Dispatch on the input type: PackedSequence or a plain (T, B, D) tensor.
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, state)
        else:
            return self.forward_tensor(input, state)
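
# An illustrative sketch (not part of the original module) of running LSTMLayer
# over a full sequence tensor. The layer expects time-major input of shape
# (seq_len, batch, input_size) and a single (h, c) state tuple; all sizes here
# are assumptions for demonstration.
def _demo_layer():
    layer = LSTMLayer(LayerNormLSTMCell, 3, 8)  # cell class followed by its constructor args
    x = torch.randn(7, 5, 3)  # seq_len=7, batch=5, input_size=3
    state = (torch.zeros(5, 8), torch.zeros(5, 8))
    output, (h, c) = layer(x, state)
    assert output.shape == (7, 5, 8)  # one hidden output per time step
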
class StackedLSTM(jit.ScriptModule):
    r"""Stacks multiple LSTM layers, feeding each layer's output sequence to the next."""
    __constants__ = ['layers']  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = nn.ModuleList([layer(*first_layer_args)]
                                    + [layer(*other_layer_args) for _ in range(num_layers - 1)])

    @jit.export
    def forward_packed(self, input, states):
        # type: (PackedSequence, List[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, List[Tuple[Tensor, Tensor]]]
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

    @jit.export
    def forward_tensor(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

    @jit.ignore
    def forward(self, input, states):
        # Dispatch on the input type: PackedSequence or a plain (T, B, D) tensor.
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, states)
        else:
            return self.forward_tensor(input, states)
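
# Illustrative helper (not part of the original module): unlike nn.LSTM, which
# takes a single stacked (num_layers, batch, hidden) state tensor, StackedLSTM
# expects a Python list with one (h, c) tuple per layer, each of shape
# (batch, hidden). This sketch builds such a zero-initialized list.
def _zero_states(num_layers, batch_size, hidden_size):
    return [(torch.zeros(batch_size, hidden_size),
             torch.zeros(batch_size, hidden_size))
            for _ in range(num_layers)]
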
def make_lnlstm(input_size, hidden_size, num_layers=1):
    r"""Returns a `num_layers`-layer LSTM built from `LayerNormLSTMCell`s."""
    return StackedLSTM(num_layers, LSTMLayer,
                       first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
                       other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
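
# End-to-end usage sketch (illustrative, not part of the original module):
# build a 2-layer layer-normalized LSTM with make_lnlstm and run it on both a
# plain time-major tensor and a PackedSequence. All sizes and lengths are
# assumptions for demonstration.
def _demo_stacked():
    from torch.nn.utils.rnn import pack_padded_sequence

    lstm = make_lnlstm(input_size=3, hidden_size=8, num_layers=2)
    states = [(torch.zeros(5, 8), torch.zeros(5, 8)) for _ in range(2)]  # one (h, c) per layer

    # Plain tensor input of shape (seq_len, batch, input_size)
    x = torch.randn(7, 5, 3)
    output, out_states = lstm(x, states)
    assert output.shape == (7, 5, 8)

    # PackedSequence input for variable-length sequences; enforce_sorted=False
    # makes pack_padded_sequence populate the sorted/unsorted index tensors
    # that forward_packed relies on.
    lengths = [7, 6, 5, 3, 2]
    packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
    packed_output, packed_states = lstm(packed, states)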