Collecting features from network.forward() in TensorFlow

Before answering your question, I want to make sure we read your original PyTorch code the same way. Here is your workflow:

class LeNet(nn.Module):
     def forward:
         # a bunch of layers
         return output

     def forward_features:
         # same as forward function 
         return [each layer output] 

Next, you use the torch_get_function method to retrieve every layer's output from the forward_features function that is defined in your model. forward_features returns a list of 4 outputs per batch; torch_get_function concatenates each of them across the batches, and in the end you pick only the first feature (activs[0]).

def torch_get_function(network, loader):
    features = []
    for batch_idx, (inputs, targets) in enumerate(loader):

        print('0', network.forward_features(inputs)[0].shape)
        print('1', network.forward_features(inputs)[1].shape)
        print('2', network.forward_features(inputs)[2].shape)
        print('3', network.forward_features(inputs)[3].shape)
        print()

        features.append([f... for f in network.forward_features(inputs)])
    return [np.concatenate(list(zip(*features))[i]) for i in range(len(features[0]))]

for epoch in epochs:
    dataset = torchvision.datasets.MNIST...
    dataset = torch.utils.data.Subset(dataset, list(range(0, 1000)))
    functloader = torch.utils.data.DataLoader(...)

    # for x , y in functloader:
    #     print('a ', x.shape, y.shape) 
    # a  torch.Size([100, 1, 28, 28]) torch.Size([100])
        
    activs = torch_get_function(net, functloader)
    print(activs[0].shape)
    break

That's why, when I ran your code, I got

# These are the 4 output that returned by forward_features(inputs)
0 torch.Size([100, 10, 12, 12])
1 torch.Size([100, 320])
2 torch.Size([100, 50])
3 torch.Size([100, 10])

...

# In the return statement of torch_get_function, each feature is concatenated
# across batches; activs[0] is the first of those features.
(1000, 10, 12, 12)

So, the input size of your model is (batch_size, 1, 28, 28) and the final output is like (1000, 10, 12, 12).


Let's do the same in tensorflow, step by step.

import numpy as np 
from tqdm import tqdm 

import tensorflow as tf 
from tensorflow import keras 
from tensorflow.keras.layers import (Conv2D, Dropout, MaxPooling2D, 
                                     Dense, Flatten)

# Load MNIST; only the test split is used below. Pixel values are divided by
# 255 and a trailing channel axis is added, giving shape (N, 28, 28, 1).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
dataset = dataset.shuffle(buffer_size=1024).batch(100)

# NOTE: this is only loosely like torch.utils.data.Subset. take() runs after
# batch(100), so it keeps up to 1000 *batches* of 100 samples -- which here is
# the entire test set, not 1000 samples. To keep only 1000 samples, use
# take(10) at this point or call take(1000) before batch().
dataset = dataset.take(1000)
dataset
<TakeDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.uint8)>

Let's now build the model. To keep it familiar to you, I'm writing it with the model-subclassing API.

class LeNet(keras.Model):
    """LeNet-style CNN that can also expose its intermediate activations.

    Called with ``training=True`` the model returns only the final class
    probabilities. Otherwise it returns the list ``[x1, x2, x3, x4]`` holding
    the first conv-block output, the flattened second conv-block output, the
    first dense output and the class probabilities, in that order.

    Args:
        num_classes: Number of units in the final dense layer.
        input_size: Unused; kept so existing callers keep working.
    """

    def __init__(self, num_classes, input_size=28):
        super(LeNet, self).__init__()
        self.conv1 = Conv2D(10, (5, 5))
        self.conv2 = Conv2D(20, (5, 5))
        self.conv2_drop = Dropout(rate=0.5)
        self.fc1 = Dense(50)
        self.fc2 = Dense(num_classes)
        # Pooling and flattening hold no weights, so one instance of each is
        # created here instead of building new layers on every forward pass.
        self.pool = MaxPooling2D(2)
        self.flatten = Flatten()

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            training: Truthy to run in training mode and return only the
                final output; falsy/None to return all four outputs.

        Returns:
            ``x4`` when ``training`` is truthy, else ``[x1, x2, x3, x4]``.
        """
        x1 = tf.nn.relu(self.pool(self.conv1(inputs)))
        # Forward the training flag explicitly so the dropout mode always
        # matches the mode this model was called in.
        x2 = self.conv2_drop(self.conv2(x1), training=training)
        x2 = tf.nn.relu(self.pool(x2))
        x2 = self.flatten(x2)
        x3 = tf.nn.relu(self.fc1(x2))
        x4 = tf.nn.softmax(self.fc2(x3), axis=1)

        if training:
            # The original evaluated `x4` without returning it, so a
            # training-mode call silently produced None.
            return x4
        # Outside training (model(input) / model.predict()) expose every
        # intermediate feature alongside the final probabilities.
        return [x1, x2, x3, x4]

lenet = LeNet(10)
lenet.build(input_shape=(None, 28, 28, 1))

Get the desired features

features = []

# The loop variable is `batch_inputs` rather than `input` so the builtin
# `input()` is not shadowed; the labels are not needed here.
for batch_inputs, _ in tqdm(dataset):
    # lenet(...) returns four outputs when training=False; we only want the
    # first one (the output of the first conv block).
    features.append(lenet(batch_inputs, training=False)[0])

print(len(features))
# Stack the per-batch features along the batch axis.
features = np.concatenate(features, axis=0)
features.shape
(10000, 12, 12, 10)

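In TensorFlow, the channel axis is last by default, as opposed to PyTorch, where it comes right after the batch axis. In PyTorch you received (1000, 10, 12, 12), whereas TensorFlow gives you (10000, 12, 12, 10): the layout is channels-last, and there are 10,000 samples because the whole test set was used here instead of a 1,000-sample subset. You can change the layout, of course (how), for example with np.transpose(features, (0, 3, 1, 2)). Here is the working colab.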