Unfolding Keras model summary for two joint sequential models

Solution 1:

Since TF 2.7, model.summary() accepts an expand_nested=True argument that also lists the layers inside nested models (issue, pr). Because you're on an older version, TF 2.4, you can use the following workaround instead:

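def summary_plus(layer, i=0):
    # Only models (e.g. Sequential, Functional) expose a `layers` attribute.
    if hasattr(layer, 'layers'):
        # Skip the top-level call (i == 0); print every nested model's summary.
        if i != 0:
            layer.summary()
        # Recurse into each child; i is non-zero for every recursive call.
        for l in layer.layers:
            i += 1
            summary_plus(l, i=i)

summary_plus(model)  # prints the summary of each model nested inside `model`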
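The same recursion can be written so that the outer model's summary is printed as well and each nested summary is labelled with its name and depth, which makes long outputs easier to scan. This is a minimal sketch: `summary_nested` is a hypothetical helper name, and it keeps the workaround's assumption that only models (Sequential, Functional) have a public `layers` attribute. It relies only on `summary()` and `name`, which Keras models provide.

```python
def summary_nested(model, _depth=0):
    """Print the summary of `model`, then of every model nested inside it.

    Hypothetical helper. Assumes, like summary_plus, that only models
    (Sequential, Functional) expose a public `layers` attribute.
    """
    indent = "  " * _depth
    # Header so each summary table can be traced back to its model and depth.
    print(f"{indent}== {model.name} (depth {_depth}) ==")
    model.summary()
    for sub_layer in model.layers:
        # Recurse only into sub-layers that are themselves models.
        if hasattr(sub_layer, "layers"):
            summary_nested(sub_layer, _depth + 1)
```

Call it as summary_nested(model). Unlike summary_plus, it also prints the outer model's own summary before descending into the nested ones.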

Solution 2:

I would suggest setting the expand_nested parameter of model.summary() to True, which expands the nested models as described in the docs (the parameter does not exist in older TF versions). The output is not the prettiest, but it does the job:

print(model.summary(expand_nested=True))
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sequential_3 (Sequential)   (None, 2048)              14714688  
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| conv2d_13 (Conv2D)        (None, 64, 64, 64)        1792      |
|                                                               |
| conv2d_14 (Conv2D)        (None, 64, 64, 64)        36928     |
|                                                               |
| max_pooling2d_2 (MaxPooling  (None, 32, 32, 64)     0         |
| 2D)                                                           |
|                                                               |
| conv2d_15 (Conv2D)        (None, 32, 32, 128)       73856     |
|                                                               |
| conv2d_16 (Conv2D)        (None, 32, 32, 128)       147584    |
|                                                               |
| max_pooling2d_3 (MaxPooling  (None, 16, 16, 128)    0         |
| 2D)                                                           |
|                                                               |
| conv2d_17 (Conv2D)        (None, 16, 16, 256)       295168    |
|                                                               |
| conv2d_18 (Conv2D)        (None, 16, 16, 256)       590080    |
|                                                               |
| conv2d_19 (Conv2D)        (None, 16, 16, 256)       590080    |
|                                                               |
| block3_pool (MaxPooling2D)  (None, 8, 8, 256)       0         |
|                                                               |
| conv2d_20 (Conv2D)        (None, 8, 8, 512)         1180160   |
|                                                               |
| conv2d_21 (Conv2D)        (None, 8, 8, 512)         2359808   |
|                                                               |
| conv2d_22 (Conv2D)        (None, 8, 8, 512)         2359808   |
|                                                               |
| block4_pool (MaxPooling2D)  (None, 4, 4, 512)       0         |
|                                                               |
| conv2d_23 (Conv2D)        (None, 4, 4, 512)         2359808   |
|                                                               |
| conv2d_24 (Conv2D)        (None, 4, 4, 512)         2359808   |
|                                                               |
| conv2d_25 (Conv2D)        (None, 4, 4, 512)         2359808   |
|                                                               |
| block5_pool (MaxPooling2D)  (None, 2, 2, 512)       0         |
|                                                               |
| flatten (Flatten)         (None, 2048)              0         |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
 sequential_4 (Sequential)   (None, 64, 64, 3)         34715075  
|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|
| dense_1 (Dense)           (None, 16384)             33570816  |
|                                                               |
| batch_normalization_4 (Batc  (None, 16384)          65536     |
| hNormalization)                                               |
|                                                               |
| activation_5 (Activation)  (None, 16384)            0         |
|                                                               |
| reshape_1 (Reshape)       (None, 8, 8, 256)         0         |
|                                                               |
| dropout_1 (Dropout)       (None, 8, 8, 256)         0         |
|                                                               |
| up_sampling2d_3 (UpSampling  (None, 16, 16, 256)    0         |
| 2D)                                                           |
|                                                               |
| conv2d_transpose_4 (Conv2DT  (None, 16, 16, 128)    819328    |
| ranspose)                                                     |
|                                                               |
| batch_normalization_5 (Batc  (None, 16, 16, 128)    512       |
| hNormalization)                                               |
|                                                               |
| activation_6 (Activation)  (None, 16, 16, 128)      0         |
|                                                               |
| up_sampling2d_4 (UpSampling  (None, 32, 32, 128)    0         |
| 2D)                                                           |
|                                                               |
| conv2d_transpose_5 (Conv2DT  (None, 32, 32, 64)     204864    |
| ranspose)                                                     |
|                                                               |
| batch_normalization_6 (Batc  (None, 32, 32, 64)     256       |
| hNormalization)                                               |
|                                                               |
| activation_7 (Activation)  (None, 32, 32, 64)       0         |
|                                                               |
| conv2d_transpose_6 (Conv2DT  (None, 32, 32, 32)     51232     |
| ranspose)                                                     |
|                                                               |
| batch_normalization_7 (Batc  (None, 32, 32, 32)     128       |
| hNormalization)                                               |
|                                                               |
| activation_8 (Activation)  (None, 32, 32, 32)       0         |
|                                                               |
| up_sampling2d_5 (UpSampling  (None, 64, 64, 32)     0         |
| 2D)                                                           |
|                                                               |
| conv2d_transpose_7 (Conv2DT  (None, 64, 64, 3)      2403      |
| ranspose)                                                     |
|                                                               |
| activation_9 (Activation)  (None, 64, 64, 3)        0         |
¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
=================================================================
Total params: 49,429,763
Trainable params: 49,396,547
Non-trainable params: 33,216
_________________________________________________________________
None

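For older TF versions, just run model.layers[0].summary() and model.layers[1].summary() to summarize the two nested models separately. summary() prints the table itself and returns None, which is also why wrapping it in print() adds the trailing None seen above.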
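The two routes can be combined into one helper that uses expand_nested when the installed version supports it and otherwise summarizes each nested model separately. This is a minimal sketch: `summary_with_fallback` is a hypothetical name, and it assumes an older version's summary() simply lacks an expand_nested parameter. It inspects the method signature rather than catching TypeError, so that unrelated errors raised inside summary() are not silently swallowed.

```python
import inspect


def summary_with_fallback(model):
    """Hypothetical helper: expanded summary on TF >= 2.7, manual fallback otherwise."""
    params = inspect.signature(model.summary).parameters
    if "expand_nested" in params:
        # Newer versions: let Keras draw the nested layers itself.
        model.summary(expand_nested=True)
    else:
        # Older versions (e.g. TF 2.4): print the outer summary, then each
        # directly nested model's summary, like model.layers[0].summary().
        model.summary()
        for sub_layer in model.layers:
            # Assumes only models expose a public `layers` attribute.
            if hasattr(sub_layer, "layers"):
                sub_layer.summary()
```

The fallback only descends one level, matching the model.layers[0] / model.layers[1] approach; for deeper nesting, the recursive workaround from Solution 1 is the better fit.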