
Implementing a minimal LSTMCell in Keras using RNN and Layer classes

I am trying to implement a simple LSTMCell without the “fancy kwargs” that the tf.keras.layers.LSTMCell class implements by default, following a schematic model like this. It doesn’t really have a practical purpose; I would just like to practice implementing a more complex RNN cell than the one described here in the Examples section. The cell follows the standard LSTM update equations, where [h_{t-1}, x_t] denotes the concatenation of the previous hidden state with the current input:
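
    f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)          (forget gate)
    i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)          (input gate)
    \tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)   (candidate values)
    c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t       (cell state update)
    o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)          (output gate)
    h_t = o_t \odot \tanh(c_t)                            (hidden state update)

My code is the following: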

from keras import Input
from keras.layers import Layer, RNN
from keras.models import Model
import keras.backend as K

class CustomLSTMCell(Layer):

    def __init__(self, units, **kwargs):
        self.state_size = units
        super(CustomLSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):

        self.forget_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='forget_w')
        self.forget_b = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='forget_b')

        self.input_w1 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='input_w1')
        self.input_b1 = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='input_b1')
        self.input_w2 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='input_w2')
        self.input_b2 = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='input_b2')

        self.output_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='output_w')
        self.output_b = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='output_b')

        self.built = True

    def merge_with_state(self, inputs):
        self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

    def forget_gate(self):
        forget = K.dot(self.forget_w, self.stateH) + self.forget_b
        forget = K.sigmoid(forget)
        self.stateC = self.stateC * forget

    def input_gate(self):
        candidate = K.dot(self.input_w1, self.stateH) + self.input_b1
        candidate = K.tanh(candidate)

        amount = K.dot(self.input_w2, self.stateH) + self.input_b2
        amount = K.tanh(amount)

        self.stateC = self.stateC + amount * candidate

    def output_gate(self):
        self.stateH = K.dot(self.output_w, self.stateH) + self.output_b
        self.stateH = K.sigmoid(self.stateH)

        self.stateH = self.stateH * K.tanh(self.stateC)

    def call(self, inputs, states):

        self.stateH = states[0]
        self.stateC = states[1]

        self.merge_with_state(inputs)
        self.forget_gate()
        self.input_gate()
        self.output_gate()

        return self.stateH, [self.stateH, self.stateC]

# Testing
inp = Input(shape=(None, 3))
lstm = RNN(CustomLSTMCell(10))(inp)

model = Model(inputs=inp, outputs=lstm)
inp_value = [[[1, 2, 3], [2, 3, 4], [3, 4, 5]]]  # shape (1, 3, 3): one sequence of three 3-dimensional timesteps
pred = model.predict(inp_value)
print(pred)

However, when I tried to test it, an exception was raised with the following message:

IndexError: tuple index out of range

in the call function, at the line where I set the value of self.stateC. My first thought was that the states argument of the call function is initially a single tensor rather than a list of tensors, and that this is why I get the error. So I added a self.already_called = False line to the class’s __init__ and the following segment to the call function:

    if not self.already_called:
        self.stateH = K.ones(self.state_size)
        self.stateC = K.ones(self.state_size)
        self.already_called = True
    else:
        self.stateH = states[0]
        self.stateC = states[1]

hoping that it would eliminate the problem. This resulted in another error, this time in the merge_with_state function:

 ValueError: Shape must be rank 1 but is rank 2 for 'rnn_1/concat' (op: 'ConcatV2') with input shapes: [10], [?,3], [].

which I genuinely do not understand, since the RNN layer should only “show” the CustomLSTMCell tensors with shape (3,) and not (None, 3), because axis 0 is the axis it should iterate along. At this point I was convinced that I was doing something really wrong and should ask the community for help. Basically my question is: what is wrong with my code, and if the answer is “almost everything”, then how should I implement an LSTM cell from scratch?


Answer

Ok, so it seems that I managed to fix the problem. It turns out that it is always useful to read the documentation, in this case the docs for the RNN class. First, the already_called attribute is unnecessary, because the real problem lies in the first line of the __init__ function: the state_size attribute should be a list of integers, not a single integer, like this: self.state_size = [units, units] (an LSTM of size units needs two states, not one). After I corrected this, I got a different error: the tensors in forget_gate were not dimensionally compatible for addition. This happens because the RNN layer sees the whole batch at once rather than each batch element separately (hence the None shape at axis 0). The fix is to add an extra dimension of size 1 at axis 0 to each weight tensor, like this:

 self.forget_w = self.add_weight(shape=(1, self.state_size, self.state_size + input_shape[-1]),
                                initializer='uniform',
                                name='forget_w')

and to use the K.batch_dot function instead of plain dot products. So the whole working code is the following:

 from keras import Input
 from keras.layers import Layer, RNN
 from keras.models import Model
 import keras.backend as K

 class CustomLSTMCell(Layer):

     def __init__(self, units, **kwargs):
         self.state_size = [units, units]
         super(CustomLSTMCell, self).__init__(**kwargs)

     def build(self, input_shape):
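         # Each kernel gets a leading axis of size 1 (an explicit batch dimension for K.batch_dot);
         # its trailing axis spans the concatenation [h, x], hence state_size[0] + input_shape[-1].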

         self.forget_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='forget_w')
         self.forget_b = self.add_weight(shape=(1, self.state_size[0]),
                                         initializer='uniform',
                                         name='forget_b')

         self.input_w1 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='input_w1')
         self.input_b1 = self.add_weight(shape=(1, self.state_size[0]),
                                         initializer='uniform',
                                         name='input_b1')
         self.input_w2 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='input_w2')
         self.input_b2 = self.add_weight(shape=(1, self.state_size[0],),
                                         initializer='uniform',
                                         name='input_b2')

         self.output_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='output_w')
         self.output_b = self.add_weight(shape=(1, self.state_size[0],),
                                         initializer='uniform',
                                         name='output_b')

         self.built = True

     def merge_with_state(self, inputs):
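         # Concatenate the previous hidden state with the current input: [h, x]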
         self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

     def forget_gate(self):        
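         # forget = sigmoid(W_f·[h, x] + b_f); scales the previous cell state element-wise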
         forget = K.batch_dot(self.forget_w, self.stateH) + self.forget_b
         forget = K.sigmoid(forget)
         self.stateC = self.stateC * forget

     def input_gate(self):
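         # candidate = tanh(W_c·[h, x] + b_c), gated by amount = sigmoid(W_i·[h, x] + b_i),
         # is added to the cell state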
         candidate = K.batch_dot(self.input_w1, self.stateH) + self.input_b1
         candidate = K.tanh(candidate)

         amount = K.batch_dot(self.input_w2, self.stateH) + self.input_b2
         amount = K.sigmoid(amount)

         self.stateC = self.stateC + amount * candidate

     def output_gate(self):
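         # out = sigmoid(W_o·[h, x] + b_o); the new hidden state is out * tanh(c)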
         self.stateH = K.batch_dot(self.output_w, self.stateH) + self.output_b
         self.stateH = K.sigmoid(self.stateH)

         self.stateH = self.stateH * K.tanh(self.stateC)

     def call(self, inputs, states):
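         # states holds one tensor per entry of state_size: [h, c]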

         self.stateH = states[0]
         self.stateC = states[1]

         self.merge_with_state(inputs)
         self.forget_gate()
         self.input_gate()
         self.output_gate()

         return self.stateH, [self.stateH, self.stateC]

 inp = Input(shape=(None, 3))
 lstm = RNN(CustomLSTMCell(10))(inp)

 model = Model(inputs=inp, outputs=lstm)
 inp_value = [[[1, 2, 3], [2, 3, 4], [3, 4, 5]]]  # shape (1, 3, 3): one sequence of three 3-dimensional timesteps
 pred = model.predict(inp_value)
 print(pred)
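
As a quick sanity check, the model can also be fed a random batch: since the RNN layer returns only the last hidden state, the output shape should be (batch_size, units), which for the test model above is (1, 10). A minimal sketch, assuming the model object built above:

 import numpy as np

 # Random input: a batch of 1 sequence with 5 timesteps of 3 features each
 x = np.random.random((1, 5, 3)).astype('float32')
 y = model.predict(x)
 print(y.shape)  # expected: (1, 10), the final hidden state of the sequence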

Edit: In the question I made a mistake with respect to the linked model and used a tanh instead of a sigmoid for amount in input_gate. The code here is edited accordingly, so it is correct now.
