
How to use multiple inputs in a TensorFlow 2.x Keras custom layer?

I'm trying to use multiple inputs in a custom layer in TensorFlow Keras. The operation could be anything; right now it is defined as multiplying an image by a mask. I've searched SO, and the only answer I could find was for TF 1.x, so it didn't help.

from tensorflow.keras import layers

class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass

    def call(self, inputs):
        # magic happens here and multiplications occur
        # (Z is a placeholder: it should be the image multiplied by the mask)
        return Z


Answer

EDIT: Since TensorFlow v2.3/2.4, the contract is to pass all inputs to the call method as a single list. For keras (not tf.keras), I think the answer below still applies.

Multiple inputs are handled in the call method of your layer class. There are two alternatives:

  • List input: the inputs parameter is a list containing all the inputs. The advantage is that the list can have a variable length. You can index the list, or unpack it into separate variables with assignment:

      def call(self, inputs):
          # Index the list directly
          Z = inputs[0] * inputs[1]

          # Alternative: unpack the list into named tensors
          input1, input2 = inputs
          Z = input1 * input2

          return Z

  • Multiple input parameters in the call method: this works, but the number of inputs is fixed when the layer is defined:

      def call(self, input1, input2):
          Z = input1 * input2
    
          return Z
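A minimal end-to-end sketch of the list approach, applied to the mask-times-image case from the question. It assumes TensorFlow 2.x with tf.keras; the class name Mul (capitalized from the question's mul), the input shapes, and the [image, mask] ordering are illustrative choices, not something Keras requires. The single-channel mask broadcasts across the image's three channels.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class Mul(layers.Layer):
    """Multiplies an image by a mask; both tensors arrive together in one list."""

    def call(self, inputs):
        # inputs is a list of two tensors, assumed to be ordered [image, mask]
        image, mask = inputs
        # mask has 1 channel, so it broadcasts over the image's channels
        return image * mask


# Two separate Keras inputs feeding a single custom layer
image_in = tf.keras.Input(shape=(4, 4, 3), name="image")
mask_in = tf.keras.Input(shape=(4, 4, 1), name="mask")
# The layer is called with ONE argument: a list holding both tensors
out = Mul()([image_in, mask_in])
model = tf.keras.Model(inputs=[image_in, mask_in], outputs=out)

image = np.ones((2, 4, 4, 3), dtype="float32")
mask = np.zeros((2, 4, 4, 1), dtype="float32")
mask[:, :2] = 1.0  # keep only the top two rows of each image

result = model([tf.constant(image), tf.constant(mask)]).numpy()
```

Because both tensors travel in the single inputs argument, the model is wired up with one layer call and the model's own inputs are also given as a list.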
    

Which method you choose depends on whether you need a fixed or a variable number of inputs. Each method also changes how the layer has to be called: either with a single list of tensors, or with the tensors passed one by one as separate arguments.
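To make the difference in calling convention concrete, here is a small sketch with two hypothetical layers, MulList and MulPair (names chosen for illustration only). MulList multiplies however many tensors its list contains, while MulPair is fixed at exactly two arguments. Both are run eagerly on constant tensors, assuming TensorFlow 2.x.

```python
import tensorflow as tf
from tensorflow.keras import layers


class MulList(layers.Layer):
    # Variable-length list input: multiplies every tensor in the list together
    def call(self, inputs):
        z = inputs[0]
        for t in inputs[1:]:
            z = z * t
        return z


class MulPair(layers.Layer):
    # Fixed arity: exactly two tensors, passed as separate arguments
    def call(self, input1, input2):
        return input1 * input2


a = tf.constant([1.0, 2.0, 3.0])
b = tf.constant([2.0, 2.0, 2.0])
c = tf.constant([0.5, 0.5, 0.5])

# List style: one argument, any number of tensors inside it
from_list = MulList()([a, b, c])  # [1., 2., 3.]
# Separate-argument style: the call signature fixes the count at two
from_pair = MulPair()(a, b)  # [2., 4., 6.]
```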

You can also use *args in the second method to get a call method that accepts a variable number of separate arguments, but overall Keras' own layers that take multiple inputs (like Concatenate and Add) are implemented with a single list of inputs.
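For comparison, a short sketch of the built-in merge layers mentioned above, each called with a single list of tensors (assuming TensorFlow 2.x). Multiply in particular may already cover the mask-times-image case without a custom layer, as long as the two shapes are compatible.

```python
import tensorflow as tf
from tensorflow.keras import layers

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0, 4.0]])

# Every built-in merge layer takes one argument: a list of tensors
added = layers.Add()([a, b])           # element-wise sum: [[4., 6.]]
joined = layers.Concatenate()([a, b])  # joined along the last axis: [[1., 2., 3., 4.]]
product = layers.Multiply()([a, b])    # element-wise product: [[3., 8.]]
```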

User contributions licensed under: CC BY-SA