How to set the input of a Keras subclassed model in TensorFlow?

Solution 1:

You can call the subclassed model on a symbolic input and wrap the result in a functional model, something like this:

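import tensorflow as tf

model_ = SubModel()  # your tf.keras.Model subclass
inputs = tf.keras.Input(shape=(100,))  # symbolic input; note the capital "I"
outputs = model_(inputs)  # calling the model on the input fixes its input shape
model = tf.keras.Model(inputs=inputs, outputs=outputs)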
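For completeness, here is a minimal runnable sketch of the same idea. The SubModel class below is a hypothetical one-layer subclass invented for illustration (the question does not show the real one), and the layer size of 8 is an arbitrary assumption. Note that the symbolic input constructor is tf.keras.Input, with a capital "I".

```python
import tensorflow as tf


class SubModel(tf.keras.Model):
    """Hypothetical subclassed model, used only for illustration."""

    def __init__(self, units=8):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs)


sub_model = SubModel()

# Symbolic placeholder: 100 features per sample, batch size left open.
inputs = tf.keras.Input(shape=(100,))

# Calling the subclassed model on the placeholder traces call() once.
outputs = sub_model(inputs)

# Functional wrapper with a fixed, known input shape.
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Sanity check on a dummy batch of 2 samples.
batch_output = model(tf.zeros((2, 100)))
```

The wrapper shares its weights with sub_model, so training model also trains the subclassed model.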

Solution 2:

I ended up giving up on keras.Model subclassing. It was too tricky and I was getting errors about input shape.

I wanted to be able to call .fit() directly on objects of my custom class. An easy way I found is to implement the special method __getattr__ (see the official Python documentation), which Python calls only when normal attribute lookup fails, and use it to forward the lookup to the wrapped Keras model. The class I use:

from tensorflow.keras import Input, layers, Model

class SubModel():
    def __init__(self):
        self.model = self.get_model()

    def get_model(self):
        # here we use the usual Keras functional API
        x = Input(shape=(24, 24, 3))
        y = layers.Conv2D(28, 3, strides=1)(x)
        return Model(inputs=[x], outputs=[y])

    def __getattr__(self, name):
        """
        Called only when normal attribute lookup on SubModel fails.
        Forwards the lookup to self.model, so any attribute or method of
        keras.Model can be used transparently on a SubModel object.
        """
        return getattr(self.model, name)


if __name__ == '__main__':
    submodel = SubModel()
    # compile() is forwarded the same way and must be called before fit()
    submodel.fit(data, labels, ...)  # forwarded to submodel.model.fit() via __getattr__
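The delegation trick itself does not depend on Keras, so it can be tried without TensorFlow. The sketch below uses a hypothetical stand-in class, FakeModel, in place of a real keras.Model. It also adds a small guard that is not in the answer above: if self.model is ever missing (for example, when get_model() raises), the lookup of self.model inside __getattr__ would otherwise call __getattr__ again and recurse forever.

```python
class FakeModel:
    """Hypothetical stand-in for keras.Model so the sketch runs without TensorFlow."""

    def fit(self, data, labels):
        return f"fit called with {len(data)} samples"


class SubModel:
    def __init__(self):
        self.model = FakeModel()

    def describe(self):
        # Defined on SubModel itself, so __getattr__ is never consulted for it.
        return "SubModel wrapper"

    def __getattr__(self, name):
        # Python calls this only when normal attribute lookup fails.
        if name == "model":
            # Guard: self.model was never set, so stop here instead of recursing.
            raise AttributeError(name)
        return getattr(self.model, name)


submodel = SubModel()
fit_result = submodel.fit([1, 2, 3], [0, 1, 0])  # forwarded to FakeModel.fit
own_result = submodel.describe()                 # resolved on SubModel directly
has_missing = hasattr(submodel, "missing")       # neither class defines it
```

One limitation of this design: the wrapper is not a keras.Model, so isinstance(submodel, Model) is False and any code that checks for a real Model instance will not accept it. In that case, pass submodel.model instead.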