I wrote a custom tree RNN cell (`TreeRNNCell`) that can handle several inputs when they are passed as a tuple:
```python
...
treeCell3_1 = TreeRNNCell(units=encodingBitLength, num_children=2)
RNNC = layers.RNN(treeCell3_1, return_state=True, return_sequences=True)
h_c_batch, h, c = RNNC(inputs=(h_batch2_1, c_batch2_1, h_batch2_2, c_batch2_2))
```
This works fine, but now I want to wrap it in a submodel so that I can condense these 4 lines into 2 and keep a better overview (the tree gets big, so it's worth it).
```python
class TreeCellModel(tf.keras.Model):
    def __init__(self, units, num_children):
        super().__init__()
        self.units = units
        self.num_children = num_children
        self.treeCell = TreeRNNCell(units=units, num_children=num_children)
        self.treeRNN = layers.RNN(self.treeCell, return_state=True, return_sequences=True)
        # create sublayers once in __init__ instead of on every call
        self.addCellStates = AddCellStatesLayer(units=units)

    def call(self, inputs, **kwargs):
        h_c_batch, h, c = self.treeRNN(inputs=(inputs))
        h_batch, c_batch = self.addCellStates(h_c_batch)
        return h_batch, c_batch


treeCell2_1 = TreeCellModel(units=encodingBitLength, num_children=2)
h_batch2_1, c_batch2_1 = treeCell2_1(inputs=(h_batch1_1, c_batch1_1, h_batch1_2, c_batch1_2))
```
But now I get this error:

```
ValueError: Layer rnn expects 1 input(s), but it received 4 input tensors. Inputs received: [<tf.Tensor 'h_batch1_1' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'c_batch1_1' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'h_batch1_2' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'c_batch1_2' shape=(1, 5, 19) dtype=float32>]
```
I already looked up the error, and normally it is fixed by wrapping the inputs in a tuple. But that's what I'm already doing. I also double-checked by printing the type of `inputs`, and it is a tuple.
Help please.
Answer
`RNN` expects one input, so you must give it one input. The error is raised by the `RNN` layer's own input check, before your cell sees the data, so the implementation of your cell will probably not matter.
You can change your code to join the 4 tensors together and separate them inside your cell. This is possible because all your tensors have the same shape.
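You could use a `Lambda` layer that stacks them along a new last axis:

```python
from tensorflow import keras
from tensorflow.keras import layers

# input1 ... input4 are your four tensors, each of shape (batch, time, 19)
joined_inputs = layers.Lambda(lambda x: keras.backend.stack(x, axis=-1))([input1, input2, input3, input4])
```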
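Stacking on `axis=-1` adds a trailing axis of size 4. Since `np.stack` uses the same axis convention as `tf.stack` and `keras.backend.stack`, the shapes can be checked with a small NumPy sketch (random data, shapes taken from the error message in the question). It shows that the joined tensor is `(1, 5, 19, 4)` and that, because `RNN` iterates over the time axis, each step the cell receives is `(1, 19, 4)`, with no time axis left.

```python
import numpy as np

batch, time, features = 1, 5, 19  # shapes from the error message
rng = np.random.default_rng(0)
h1, c1, h2, c2 = (
    rng.standard_normal((batch, time, features)).astype(np.float32) for _ in range(4)
)

# Same call pattern as keras.backend.stack(x, axis=-1) / tf.stack(x, axis=-1)
joined = np.stack([h1, c1, h2, c2], axis=-1)
print(joined.shape)  # (1, 5, 19, 4)

# layers.RNN drops the time axis before calling the cell, so at step t
# the cell sees joined[:, t], shaped (batch, features, 4)
step_t = joined[:, 0]
print(step_t.shape)  # (1, 19, 4)
```

So inside the cell, the four original tensors sit on the last axis of a 3-D tensor.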
Then your cell should be able to separate the inputs:
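```python
def call(self, inputTensor, .....):
    # Inside the cell, inputTensor is a single time step: (batch, 19, 4).
    # The time axis is already gone, so index the LAST axis with [..., i];
    # [:, :, :, i] would need a 4-D tensor and fails on this 3-D one.
    inputs = [inputTensor[..., i] for i in range(4)]
    ....
```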
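Putting the two pieces together, here is a minimal sketch of the question's submodel that does the stacking itself and hands a single tensor to `layers.RNN`. It assumes TensorFlow 2.x with its bundled Keras. `StackedChildrenCell` is a hypothetical toy stand-in for `TreeRNNCell` (whose code is not shown), its update rule is a placeholder, and the question's `AddCellStatesLayer` step is left out.

```python
import tensorflow as tf
from tensorflow.keras import layers


class StackedChildrenCell(layers.Layer):
    """Hypothetical toy cell standing in for TreeRNNCell (not the real one)."""

    def __init__(self, units, num_inputs=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_inputs = num_inputs
        self.state_size = [units, units]  # h and c, as in LSTMCell
        self.output_size = units

    def build(self, input_shape):
        # layers.RNN passes the per-step shape: (batch, features, num_inputs)
        in_dim = input_shape[-2] * input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel", shape=(in_dim, 2 * self.units), initializer="glorot_uniform"
        )

    def call(self, inputs, states):
        _, c_prev = states  # the toy update below only uses the previous c
        # Undo the stacking: num_inputs tensors of shape (batch, features)
        children = [inputs[..., i] for i in range(self.num_inputs)]
        z = tf.matmul(tf.concat(children, axis=-1), self.kernel)
        c = c_prev + z[:, self.units:]  # placeholder update, not a real tree-LSTM
        h = tf.tanh(z[:, : self.units] + c)
        return h, [h, c]


class TreeCellModel(tf.keras.Model):
    """Sketch of the question's submodel that stacks its inputs itself."""

    def __init__(self, units, num_children=2):
        super().__init__()
        # each child contributes an (h, c) pair, hence 2 * num_children inputs
        self.treeRNN = layers.RNN(
            StackedChildrenCell(units, num_inputs=2 * num_children),
            return_state=True,
            return_sequences=True,
        )

    def call(self, inputs):
        # tuple of (batch, time, features) tensors -> ONE (batch, time, features, n) tensor
        joined = tf.stack(list(inputs), axis=-1)
        h_c_batch, h, c = self.treeRNN(joined)
        return h_c_batch, h, c


batch, time, features = 1, 5, 19
h1, c1, h2, c2 = (tf.random.normal((batch, time, features)) for _ in range(4))
model = TreeCellModel(units=8, num_children=2)
h_c_batch, h, c = model((h1, c1, h2, c2))
```

Because the stacking happens inside `call`, the caller still writes a single line per tree node, which was the goal in the question; the real `TreeRNNCell` logic would replace the toy update in `StackedChildrenCell.call`.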