TensorFlow layer subclass input shape
Solution 1:
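I'm not sure exactly what you want to do, but I would recommend using only TensorFlow
operations (for example `tf.TensorArray` instead of Python lists or NumPy arrays), so that the layer also works on symbolic Keras inputs whose batch dimension is unknown (`None`). Here is an example: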
import tensorflow as tf
class IIR(tf.keras.layers.Layer):
    """Second-order IIR filter applied independently to every input row.

    For each row ``x`` (columns ``0 .. input_dim - 1``) the layer computes::

        y[0] = x[0]
        y[1] = x[1]
        y[i] = b0*x[i] + b1*x[i-1] + b2*x[i-2] - a1*y[i-1] - a2*y[i-2]   (i >= 2)

    The five coefficients are trainable scalars, each initialised uniformly
    in [-1, 1].

    Args:
        input_dim: Number of samples (columns) per row that the recurrence
            runs over; inputs are expected to have exactly this many columns.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_dim, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        # add_weight is the documented way to create layer weights; each
        # coefficient gets its own initializer instance so they are drawn
        # independently rather than from one reused initializer.
        self.b0 = self._coefficient("b0")
        self.b1 = self._coefficient("b1")
        self.b2 = self._coefficient("b2")
        self.a1 = self._coefficient("a1")
        self.a2 = self._coefficient("a2")

    def _coefficient(self, name):
        """Create one trainable scalar (shape ``(1,)``) drawn from U(-1, 1)."""
        return self.add_weight(
            name=name,
            shape=(1,),
            initializer=tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0),
            trainable=True,
        )

    def call(self, inputs):
        """Filter each row of ``inputs``.

        Args:
            inputs: Tensor whose axis 0 is the batch and axis 1 holds the
                ``input_dim`` samples (column ``i`` is sample ``i``).

        Returns:
            Tensor with the same shape as ``inputs`` holding the filtered rows.
        """
        # Time-major buffer: element i holds y[i] for the whole batch.
        outputs = tf.TensorArray(
            dtype=inputs.dtype, size=self.input_dim, clear_after_read=False
        )
        # The first (up to) two outputs pass through unchanged; they seed the
        # y[i-1] / y[i-2] terms of the recurrence.
        for i in range(min(2, self.input_dim)):
            outputs = outputs.write(i, inputs[:, i])
        # Python loop: unrolled once per sample when traced into a graph.
        for i in range(2, self.input_dim):
            y_i = (
                self.b0 * inputs[:, i]
                + self.b1 * inputs[:, i - 1]
                + self.b2 * inputs[:, i - 2]
                - self.a1 * outputs.read(i - 1)
                - self.a2 * outputs.read(i - 2)
            )
            outputs = outputs.write(i, y_i)
        # stack() is time-major, (input_dim, batch, ...). Swap the first two
        # axes to get back to batch-major; a plain reshape here would
        # interleave samples from different rows.
        perm = [1, 0] + list(range(2, inputs.shape.rank))
        result = tf.transpose(outputs.stack(), perm=perm)
        # Data is already in order; this reshape only restores the static
        # shape (e.g. (None, input_dim)) so downstream layers such as Dense
        # can see a defined last dimension.
        return tf.reshape(result, tf.shape(inputs))
# Quick check: a batch of 2 rows of 60 samples keeps its shape.
iir = IIR(input_dim=60)
tf.print(iir(tf.random.normal((2, 60))).shape)

# Use the layer inside a Sequential model; Input expects a shape tuple.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(60,)))
model.add(IIR(input_dim=60))
model.add(tf.keras.layers.Dense(8, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
# summary() prints the table itself and returns None, so it is not wrapped
# in print() (which would emit a stray "None").
model.summary()
TensorShape([2, 60])
Model: "sequential_21"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
iir_80 (IIR) (None, 60) 5
dense_20 (Dense) (None, 8) 488
dense_21 (Dense) (None, 1) 9
=================================================================
Total params: 502
Trainable params: 502
Non-trainable params: 0
_________________________________________________________________
None