I am trying to implement a simple LSTMCell without the “fancy kwargs” implemented by default in the tf.keras.layers.LSTMCell class, following a schematic model like this. It doesn’t really have a direct purpose; I would just like to practice implementing a more complex RNNCell than the one described here in the Examples section. My code is the following:
from keras import Input
from keras.layers import Layer, RNN
from keras.models import Model
import keras.backend as K


class CustomLSTMCell(Layer):
    """The question's original attempt at a hand-rolled LSTM cell.

    This version does NOT work as written — it raises the IndexError
    described below when used inside a `RNN` layer.
    """

    def __init__(self, units, **kwargs):
        # NOTE(review): `state_size` is a single int here; per the answer
        # below, the RNN wrapper then supplies only one state tensor to
        # `call`, yet `call` indexes two (`states[0]` and `states[1]`).
        self.state_size = units
        super(CustomLSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # One (units, units + input_dim) weight matrix plus a (units,)
        # bias per gate, all uniformly initialized.
        # `input_shape[-1]` is the feature dimension of the cell input.
        self.forget_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]), initializer='uniform', name='forget_w')
        self.forget_b = self.add_weight(shape=(self.state_size,), initializer='uniform', name='forget_b')
        self.input_w1 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]), initializer='uniform', name='input_w1')
        self.input_b1 = self.add_weight(shape=(self.state_size,), initializer='uniform', name='input_b1')
        self.input_w2 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]), initializer='uniform', name='input_w2')
        self.input_b2 = self.add_weight(shape=(self.state_size,), initializer='uniform', name='input_b2')
        self.output_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]), initializer='uniform', name='output_w')
        self.output_b = self.add_weight(shape=(self.state_size,), initializer='uniform', name='output_b')
        self.built = True

    def merge_with_state(self, inputs):
        # Concatenate the hidden state with the current input along the
        # last axis; every gate operates on this merged vector.
        self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

    def forget_gate(self):
        # Sigmoid gate that scales down the cell state element-wise.
        forget = K.dot(self.forget_w, self.stateH) + self.forget_b
        forget = K.sigmoid(forget)
        self.stateC = self.stateC * forget

    def input_gate(self):
        # Candidate values to write into the cell state...
        candidate = K.dot(self.input_w1, self.stateH) + self.input_b1
        candidate = K.tanh(candidate)
        # ...scaled by how much of each to admit.
        # NOTE(review): tanh here; the Edit at the bottom of the answer
        # says this should have been a sigmoid.
        amount = K.dot(self.input_w2, self.stateH) + self.input_b2
        amount = K.tanh(amount)
        self.stateC = self.stateC + amount * candidate

    def output_gate(self):
        # Produce the new hidden state from the merged vector and the
        # (already updated) cell state.
        self.stateH = K.dot(self.output_w, self.stateH) + self.output_b
        self.stateH = K.sigmoid(self.stateH)
        self.stateH = self.stateH * K.tanh(self.stateC)

    def call(self, inputs, states):
        # `states` is expected to hold [hidden, cell]; the IndexError
        # described below is raised on the `states[1]` access.
        self.stateH = states[0]
        self.stateC = states[1]
        self.merge_with_state(inputs)
        self.forget_gate()
        self.input_gate()
        self.output_gate()
        # RNN cell contract: return (output, list_of_new_states).
        return self.stateH, [self.stateH, self.stateC]


# Testing: one batch with a length-3 sequence of 3-dim inputs.
inp = Input(shape=(None, 3))
lstm = RNN(CustomLSTMCell(10))(inp)
model = Model(inputs=inp, outputs=lstm)
inp_value = [[[[1,2,3], [2,3,4], [3,4,5]]]]
pred = model.predict(inp_value)
print(pred)
However, when I tried to test it an exception was raised with the following message:
IndexError: tuple index out of range
in the call
function at the line where I set the value for self.stateC
. Here, I thought that initially the states
argument of the call
function is a tensor and not a list of tensors, so this is why I get an error. So I added a self.already_called = False
line to the class's __init__
and the following segment to the call
function:
if not self.already_called: self.stateH = K.ones(self.state_size) self.stateC = K.ones(self.state_size) self.already_called = True else: self.stateH = states[0] self.stateC = states[1]
hoping that it will eliminate the problem. This resulted in another error at the merge_with_state
function:
ValueError: Shape must be rank 1 but is rank 2 for 'rnn_1/concat' (op: 'ConcatV2') with input shapes: [10], [?,3], [].
which I genuinely do not get, since the RNN layer should only “show” the CustomLSTMCell tensors with shape (3) and not (None, 3), since axis 0 is the axis it should iterate along. At this point I was convinced that I am doing something really wrong and should ask the community for help. Basically my question is: what is wrong with my code and if “almost everything”, then how should I implement an LSTMCell from scratch?
Advertisement
Answer
Ok, so it seems that I managed to fix the problem. It turns out that it is always useful to read the documentation, in this case the docs for the RNN class. First, the already_called
attribute is unnecessary, because the problem lies in the first line of the __init__
function: the state_size
attribute should be a list of integers and not only one integer, like this: self.state_size = [units, units]
(since we need two states for an LSTM of size units
and not one). When I corrected it I got a different error: the tensors are not compatible in dimension in the forget_gate
for addition. This happened because the RNN sees the whole batch at once and not each element in the batch separately (thus the None
shape at axis 0). The correction for it is to add an extra dimension to each tensor of size 1 at axis 0 like this:
self.forget_w = self.add_weight(shape=(1, self.state_size, self.state_size + input_shape[-1]), initializer='uniform', name='forget_w')
and instead of dot products I had to use the K.batch_dot
function. So the whole, working code is the following:
from keras import Input
from keras.layers import Layer, RNN
from keras.models import Model
import keras.backend as K


class CustomLSTMCell(Layer):
    """A minimal hand-rolled LSTM cell for use inside `keras.layers.RNN`.

    The working version: `state_size` is a list of two sizes (hidden and
    cell state), and every weight carries a leading broadcast dimension of
    1 so `K.batch_dot` can be applied across the whole batch at once.
    """

    def __init__(self, units, **kwargs):
        # Two states of size `units` each: [hidden, cell]. The RNN
        # wrapper reads this attribute to build the initial states.
        self.state_size = [units, units]
        super(CustomLSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # Per gate: a (1, units, units + input_dim) weight and a
        # (1, units) bias. The leading 1 is a broadcast dimension over
        # the batch, allowing K.batch_dot with the batched state below.
        self.forget_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]), initializer='uniform', name='forget_w')
        self.forget_b = self.add_weight(shape=(1, self.state_size[0]), initializer='uniform', name='forget_b')
        self.input_w1 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]), initializer='uniform', name='input_w1')
        self.input_b1 = self.add_weight(shape=(1, self.state_size[0]), initializer='uniform', name='input_b1')
        self.input_w2 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]), initializer='uniform', name='input_w2')
        self.input_b2 = self.add_weight(shape=(1, self.state_size[0],), initializer='uniform', name='input_b2')
        self.output_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]), initializer='uniform', name='output_w')
        self.output_b = self.add_weight(shape=(1, self.state_size[0],), initializer='uniform', name='output_b')
        self.built = True

    def merge_with_state(self, inputs):
        # Concatenate hidden state and current input along the feature
        # axis; every gate operates on this merged vector.
        self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

    def forget_gate(self):
        # Sigmoid gate scaling the previous cell state element-wise.
        forget = K.batch_dot(self.forget_w, self.stateH) + self.forget_b
        forget = K.sigmoid(forget)
        self.stateC = self.stateC * forget

    def input_gate(self):
        # Candidate values to add to the cell state (tanh)...
        candidate = K.batch_dot(self.input_w1, self.stateH) + self.input_b1
        candidate = K.tanh(candidate)
        # ...gated by a sigmoid "how much to admit" factor (fixed per
        # the Edit below — originally a tanh by mistake).
        amount = K.batch_dot(self.input_w2, self.stateH) + self.input_b2
        amount = K.sigmoid(amount)
        self.stateC = self.stateC + amount * candidate

    def output_gate(self):
        # New hidden state: sigmoid-gated projection of the merged
        # vector, modulated by tanh of the updated cell state.
        self.stateH = K.batch_dot(self.output_w, self.stateH) + self.output_b
        self.stateH = K.sigmoid(self.stateH)
        self.stateH = self.stateH * K.tanh(self.stateC)

    def call(self, inputs, states):
        # `states` holds [hidden, cell] as declared via `state_size`.
        self.stateH = states[0]
        self.stateC = states[1]
        self.merge_with_state(inputs)
        self.forget_gate()
        self.input_gate()
        self.output_gate()
        # RNN cell contract: return (output, list_of_new_states).
        return self.stateH, [self.stateH, self.stateC]


# Test drive: one batch containing a length-3 sequence of 3-dim inputs.
inp = Input(shape=(None, 3))
lstm = RNN(CustomLSTMCell(10))(inp)
model = Model(inputs=inp, outputs=lstm)
inp_value = [[[[1,2,3], [2,3,4], [3,4,5]]]]
pred = model.predict(inp_value)
print(pred)
Edit: In the question I made a mistake with respect to the model linked and used a tanh function in the input_gate
for amount
instead of a sigmoid. Here I edited it in the code, so it is correct now.