3D - Variational Autoencoder implementation Error

I'm trying to implement a 3D variational autoencoder, following the instructions in the book Generative Deep Learning and also taking some things from here (link). Those examples are in 2D, so I have adapted them.

This is the code I'm using:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input, BatchNormalization, Conv3D, Dense, Flatten, Lambda, Reshape, UpSampling3D
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

## Functional API

# Each sample is a 32x32x32 volume with one channel; z_dim is the latent size.
input_shape = (32, 32, 32, 1)
z_dim = 3000

###########  Encoder  ############

enc_in = Input(shape=input_shape, name="encoder_input")

# Two conv + max-pool stages (32 -> 16 -> 8 per spatial axis).
# "relu" is the documented Keras activation identifier; the layer-class
# spelling "ReLU" is not among the documented activation names.
enc_conv1 = tf.keras.layers.Conv3D(filters=8, kernel_size=5, strides=(1, 1, 1), padding="same", activation="relu", kernel_initializer="he_uniform", name="conv_1")(enc_in)
max1 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="valid", name="max_pool_1")(enc_conv1)

enc_conv2 = tf.keras.layers.Conv3D(filters=16, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv_2")(max1)
max2 = tf.keras.layers.MaxPool3D(pool_size=2, name="max_pool_2")(enc_conv2)

# Flatten the pooled feature maps and project them to a 512-unit hidden layer.
enc_fc1 = Dense(units=512, kernel_initializer="he_uniform", activation="relu")(Flatten()(max2))

# Two linear heads of size z_dim fed by the same hidden layer.
mu = Dense(units=z_dim, kernel_initializer="he_uniform", activation=None, name="mu")(enc_fc1)

log_variance = Dense(units=z_dim, kernel_initializer="he_uniform", activation=None, name="log_variance")(enc_fc1)

def sampling(args):
    """Return mu + exp(0.5 * log_variance) * epsilon.

    `args` is a two-item sequence (mu, log_variance). epsilon is noise drawn
    from a normal distribution with mean 0 and stddev 1, shaped like mu.
    """
    mean, log_var = args
    noise = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
    std_dev = K.exp(0.5 * log_var)
    return mean + std_dev * noise

# Latent vector produced by the sampling function from mu and log_variance.
z = Lambda(sampling, output_shape=(z_dim,), name="z")([mu, log_variance])

# The encoder exposes both heads as well as the sampled latent vector.
encoder = Model(enc_in, [mu, log_variance, z], name="encoder")


##########  Decoder  ############

dec_in = Input(shape=(z_dim,), name="decoder_input")

# 512 units reshape exactly into an 8x8x8 single-channel volume (8 * 8 * 8 * 1 = 512).
# "relu" is the documented Keras activation identifier; the layer-class
# spelling "ReLU" is not among the documented activation names.
dec_fc1 = Dense(units=512, kernel_initializer="he_uniform", activation="relu")(dec_in)
dec_unflatten = Reshape(target_shape=(8, 8, 8, 1))(dec_fc1)

# Two conv + 2x upsampling stages (8 -> 16 -> 32 per spatial axis).
dec_conv1 = tf.keras.layers.Conv3D(filters=32, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="deconv_1")(dec_unflatten)
ups1 = tf.keras.layers.UpSampling3D(size=(2, 2, 2), name="ups_1")(dec_conv1)

dec_conv2 = tf.keras.layers.Conv3D(filters=16, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="deconv_2")(ups1)
ups2 = tf.keras.layers.UpSampling3D(size=(2, 2, 2), name="ups_2")(dec_conv2)

# Single-filter convolution that yields a one-channel output volume.
dec_conv4 = Conv3D(filters=1, kernel_size=(3, 3, 3), padding="same", activation="relu", kernel_initializer="he_uniform", name='decorder_output')(ups2)

decoder = Model(dec_in, dec_conv4, name="decoder")


# End-to-end model: encoder input -> sampled z -> decoder output.
model_input = enc_in
model_output = decoder(z)
v_autoencoder = Model(model_input, model_output)
######################

def vae_r_loss(T1_input, model_output):
    """Reconstruction loss: squared error averaged over axes 1, 2 and 3.

    The mean is multiplied by the module-level `r_loss_factor`.
    """
    squared_error = K.square(T1_input - model_output)
    mean_squared_error = K.mean(squared_error, axis=[1, 2, 3])
    return r_loss_factor * mean_squared_error


def vae_kl_loss(T1_input, model_output):
    """KL term: -0.5 * sum(1 + log_variance - mu^2 - exp(log_variance)) over axis 1.

    Both parameters are unused; the function reads the module-level `mu` and
    `log_variance` tensors created when the encoder was built.
    """
    per_unit = 1 + log_variance - K.square(mu) - K.exp(log_variance)
    return -0.5 * K.sum(per_unit, axis=1)


def vae_loss(T1_input, model_output):
    """Total loss: reconstruction term plus KL term for the same arguments."""
    reconstruction_term = vae_r_loss(T1_input, model_output)
    kl_term = vae_kl_loss(T1_input, model_output)
    total = reconstruction_term + kl_term
    return total

# Training hyperparameters. r_loss_factor is read by vae_r_loss to scale the
# reconstruction term.
learning_rate = 0.001
r_loss_factor = 500

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

v_autoencoder.compile(optimizer=optimizer, loss=vae_loss, metrics=[vae_r_loss, vae_kl_loss])
v_autoencoder.summary()

But when I try to train the model:

# Train the autoencoder to reproduce its input: T1_input is passed as both x and y.
history = v_autoencoder.fit(T1_input, T1_input, epochs=10, validation_split=0.2, shuffle=True)

I get this error:

TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.float32, name=None), name='Placeholder:0', description="created by layer 'tf.cast_4'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras Functional model construction only supports TF API calls that *do* support dispatching, such as `tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer `call` and calling that layer on this symbolic input/output.

The encoder and the decoder models seem to be disconnected, so the KerasTensors are not able to flow correctly. Also, the encoder returns three tensors [mu, log_variance, z], whereas the decoder only requires z as an input.

Replace these lines,

# Wiring from the question: the decoder is called directly on the encoder's z tensor.
model_input = enc_in
model_output = decoder(z)
v_autoencoder = Model(model_input, model_output)

with,

# Run the encoder on a fresh input, decode only z, and expose the encoder's
# other two outputs next to the reconstruction.
vae_input = Input(name="encoder_input", shape=input_shape)
mu_, variance_, z_ = encoder(vae_input)
vae_output = decoder(z_)
v_autoencoder = Model(inputs=vae_input, outputs=[vae_output, mu_, variance_])

So, basically, the model will now output mu_ and variance_, which are the outputs of the encoder and are required for vae_kl_loss. Earlier, you were using these KerasTensors directly in vae_kl_loss, and hence you got the error.

Remove the axis argument from vae_kl_loss,

def vae_kl_loss(mu, log_variance):
    """KL term: -0.5 * sum(1 + log_variance - mu^2 - exp(log_variance)).

    K.sum is called without an axis, so the result is one scalar summed over
    every element of the inputs.
    """
    per_unit = 1 + log_variance - K.square(mu) - K.exp(log_variance)
    return -0.5 * K.sum(per_unit)

Pass mu and variance to vae_kl_loss,

def vae_loss(T1_input, model_output):
    """Total loss: reconstruction term plus KL term.

    Reads index 0 of `model_output` as the reconstruction and indices 1 and 2
    as the two arguments of vae_kl_loss (mu and log-variance).

    NOTE(review): Keras appears to apply a compiled loss callable to each model
    output separately. If so, `model_output` is a single output tensor here and
    these indices select along its first axis instead of picking outputs.
    Confirm against the Model.compile documentation.
    """
    vae_output = model_output[0]
    mu = model_output[1]
    variance = model_output[2]
    return vae_r_loss(T1_input, vae_output) + vae_kl_loss(mu, variance)

Modified code:

import tensorflow as tf
from tensorflow.keras.layers import Input, BatchNormalization, Conv3D, Dense, Flatten, Lambda, Reshape, UpSampling3D
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

## Functional API

# Each sample is a 32x32x32 volume with one channel; z_dim is the latent size.
input_shape = (32, 32, 32, 1)
z_dim = 3000

###########  Encoder  ############

enc_in = Input(shape=input_shape, name="encoder_input")

# Two conv + max-pool stages (32 -> 16 -> 8 per spatial axis).
# "relu" is the documented Keras activation identifier; the layer-class
# spelling "ReLU" is not among the documented activation names.
enc_conv1 = tf.keras.layers.Conv3D(filters=8, kernel_size=5, strides=(1, 1, 1), padding="same", activation="relu", kernel_initializer="he_uniform", name="conv_1")(enc_in)
max1 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="valid", name="max_pool_1")(enc_conv1)

enc_conv2 = tf.keras.layers.Conv3D(filters=16, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv_2")(max1)
max2 = tf.keras.layers.MaxPool3D(pool_size=2, name="max_pool_2")(enc_conv2)

# Flatten the pooled feature maps and project them to a 512-unit hidden layer.
enc_fc1 = Dense(units=512, kernel_initializer="he_uniform", activation="relu")(Flatten()(max2))

# Two linear heads of size z_dim fed by the same hidden layer.
mu = Dense(units=z_dim, kernel_initializer="he_uniform", activation=None, name="mu")(enc_fc1)

log_variance = Dense(units=z_dim, kernel_initializer="he_uniform", activation=None, name="log_variance")(enc_fc1)

def sampling(args):
    """Return mu + exp(0.5 * log_variance) * epsilon.

    `args` is a two-item sequence (mu, log_variance). epsilon is noise drawn
    from a normal distribution with mean 0 and stddev 1, shaped like mu.
    """
    mean, log_var = args
    noise = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
    std_dev = K.exp(0.5 * log_var)
    return mean + std_dev * noise

# Latent vector produced by the sampling function from mu and log_variance.
z = Lambda(sampling, output_shape=(z_dim,), name="z")([mu, log_variance])

# The encoder exposes both heads as well as the sampled latent vector.
encoder = Model(enc_in, [mu, log_variance, z], name="encoder")


##########  Decoder  ############

dec_in = Input(shape=(z_dim,), name="decoder_input")

# 512 units reshape exactly into an 8x8x8 single-channel volume (8 * 8 * 8 * 1 = 512).
# "relu" is the documented Keras activation identifier; the layer-class
# spelling "ReLU" is not among the documented activation names.
dec_fc1 = Dense(units=512, kernel_initializer="he_uniform", activation="relu")(dec_in)
dec_unflatten = Reshape(target_shape=(8, 8, 8, 1))(dec_fc1)

# Two conv + 2x upsampling stages (8 -> 16 -> 32 per spatial axis).
dec_conv1 = tf.keras.layers.Conv3D(filters=32, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="deconv_1")(dec_unflatten)
ups1 = tf.keras.layers.UpSampling3D(size=(2, 2, 2), name="ups_1")(dec_conv1)

dec_conv2 = tf.keras.layers.Conv3D(filters=16, kernel_size=3, padding="same", activation="relu", kernel_initializer="he_uniform", name="deconv_2")(ups1)
ups2 = tf.keras.layers.UpSampling3D(size=(2, 2, 2), name="ups_2")(dec_conv2)

# Single-filter convolution that yields a one-channel output volume.
dec_conv4 = Conv3D(filters=1, kernel_size=(3, 3, 3), padding="same", activation="relu", kernel_initializer="he_uniform", name='decorder_output')(ups2)

decoder = Model(dec_in, dec_conv4, name="decoder")

# End-to-end VAE: run the encoder on a fresh input, decode only z, and expose
# the encoder's other two outputs next to the reconstruction.
vae_input = Input(shape=input_shape, name="encoder_input")
mu_, variance_, z_ = encoder(vae_input)
vae_output = decoder(z_)
v_autoencoder = Model(vae_input, [vae_output, mu_, variance_])
######################

def vae_r_loss(T1_input, model_output):
    """Reconstruction loss: squared error averaged over axes 1, 2 and 3.

    The mean is multiplied by the module-level `r_loss_factor`.
    """
    squared_error = K.square(T1_input - model_output)
    mean_squared_error = K.mean(squared_error, axis=[1, 2, 3])
    return r_loss_factor * mean_squared_error


def vae_kl_loss(mu, log_variance):
    """KL term: -0.5 * sum(1 + log_variance - mu^2 - exp(log_variance)).

    K.sum is called without an axis, so the result is one scalar summed over
    every element of the inputs.
    """
    per_unit = 1 + log_variance - K.square(mu) - K.exp(log_variance)
    return -0.5 * K.sum(per_unit)


def vae_loss(T1_input, model_output):
    """Total loss: reconstruction term plus KL term.

    Reads index 0 of `model_output` as the reconstruction and indices 1 and 2
    as the two arguments of vae_kl_loss (mu and log-variance).

    NOTE(review): Keras appears to apply a compiled loss callable to each model
    output separately. If so, `model_output` is a single output tensor here and
    these indices select along its first axis instead of picking outputs.
    Confirm against the Model.compile documentation.
    """
    vae_output = model_output[0]
    mu = model_output[1]
    variance = model_output[2]
    return vae_r_loss(T1_input, vae_output) + vae_kl_loss(mu, variance)

# Training hyperparameters. r_loss_factor is read by vae_r_loss to scale the
# reconstruction term.
learning_rate = 0.001
r_loss_factor = 500

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# NOTE(review): Keras metrics are called as (y_true, y_pred), while vae_kl_loss
# is declared as (mu, log_variance). Confirm that this metric reports the
# intended quantity.
v_autoencoder.compile(optimizer=optimizer, loss=vae_loss, metrics=[vae_r_loss, vae_kl_loss])
v_autoencoder.summary()