Manipulating each batch individually in a tensorflow dataset

Consider the code below:

import tensorflow as tf
import numpy as np
 
# Toy time series: 12 time steps (rows) with 5 features each (columns).
simple_data_samples = np.array(
    [
        [1, 1, 1, 7, -1],
        [2, -2, 2, -2, -2],
        [3, 3, 3, -3, -3],
        [-4, 4, 4, -4, -4],
        [5, 5, 5, -5, -5],
        [6, 6, 6, -4, -1],
        [7, 7, 8, -7, -70],
        [8, 8, 8, -8, -8],
        [9, 4, 9, -9, -9],
        [10, 10, 10, -10, -10],
        [11, 5, 11, -11, -11],
        [12, 12, 12, -12, -12],
    ]
)


        
def single(ds):
    """Print every element obtained by iterating over ``ds``, one per call to ``print``."""
    for element in ds:
        print(element)
    
def timeseries_dataset_multistep_combined(features, label_slice, input_sequence_length, output_sequence_length, sequence_stride, batch_size):
    """Build a batched dataset of sliding windows over ``features``.

    Each window spans ``input_sequence_length + output_sequence_length``
    consecutive rows; consecutive windows start ``sequence_stride`` rows
    apart and are grouped into batches of ``batch_size`` without shuffling.

    ``label_slice`` is accepted but not used by this function.
    """
    window_length = input_sequence_length + output_sequence_length
    # No targets are passed (second argument is None), so the dataset
    # yields only the feature windows.
    return tf.keras.preprocessing.timeseries_dataset_from_array(
        features,
        None,
        sequence_length=window_length,
        sequence_stride=sequence_stride,
        batch_size=batch_size,
        shuffle=False,
    )

# Windows of 4 + 2 = 6 rows, starting every 2 rows, two windows per batch.
ds = timeseries_dataset_multistep_combined(
    features=simple_data_samples,
    label_slice=slice(None, None, None),
    input_sequence_length=4,
    output_sequence_length=2,
    sequence_stride=2,
    batch_size=2,
)

# Print each batch.
single(ds)

This code creates the following outputs of batches:

tf.Tensor(
[[[  1   1   1   7  -1]
  [  2  -2   2  -2  -2]
  [  3   3   3  -3  -3]
  [ -4   4   4  -4  -4]
  [  5   5   5  -5  -5]
  [  6   6   6  -4  -1]]

 [[  3   3   3  -3  -3]
  [ -4   4   4  -4  -4]
  [  5   5   5  -5  -5]
  [  6   6   6  -4  -1]
  [  7   7   8  -7 -70]
  [  8   8   8  -8  -8]]], shape=(2, 6, 5), dtype=int64)
tf.Tensor(
[[[  5   5   5  -5  -5]
  [  6   6   6  -4  -1]
  [  7   7   8  -7 -70]
  [  8   8   8  -8  -8]
  [  9   4   9  -9  -9]
  [ 10  10  10 -10 -10]]

 [[  7   7   8  -7 -70]
  [  8   8   8  -8  -8]
  [  9   4   9  -9  -9]
  [ 10  10  10 -10 -10]
  [ 11   5  11 -11 -11]
  [ 12  12  12 -12 -12]]], shape=(2, 6, 5), dtype=int64)

I want to manipulate each batch individually. For that purpose, I want to extract the max value from each batch individually. This can be done by the following code:

def timeseries_dataset_multistep_combined(features, label_slice, input_sequence_length, output_sequence_length, sequence_stride, batch_size):
    """Return, for every window in every batch, the max of its last feature.

    Windows span ``input_sequence_length + output_sequence_length`` rows of
    ``features``, start ``sequence_stride`` rows apart and are batched in
    groups of ``batch_size`` without shuffling. Each batch of windows is then
    replaced by one maximum per window, with the reduced axis kept.

    ``label_slice`` is accepted but not used by this function.
    """
    window_length = input_sequence_length + output_sequence_length
    windows = tf.keras.preprocessing.timeseries_dataset_from_array(
        features,
        None,
        sequence_length=window_length,
        sequence_stride=sequence_stride,
        batch_size=batch_size,
        shuffle=False,
    )

    def _last_feature_max(batch):
        # Keep only the last feature of every time step, then reduce over
        # the time axis; keepdims=True leaves a trailing axis of size 1.
        last_feature = batch[:, :, -1]
        return tf.reduce_max(last_feature, axis=1, keepdims=True)

    return windows.map(_last_feature_max)

# Same windowing parameters as before; the dataset now yields the per-window
# maxima instead of the windows themselves.
ds = timeseries_dataset_multistep_combined(
    features=simple_data_samples,
    label_slice=slice(None, None, None),
    input_sequence_length=4,
    output_sequence_length=2,
    sequence_stride=2,
    batch_size=2,
)

# Print the maxima of each batch.
single(ds)

Since the two batches contain four sequences in total (two per batch), I would expect four max values, one per sequence, as shown below:

tf.Tensor(
[[-1]
 [-1]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[-1]
 [-8]], shape=(2, 1), dtype=int64)

Now I would like to add each max value to its corresponding sequence within the batch. For instance, for the first sequence of the first batch I would add the first max value (-1) and expect the following output:

[[[  1   1   1   7  -1]
  [  2  -2   2  -2  -2]
  [  3   3   3  -3  -3]
  [ -4   4   4  -4  -4]
  [  5   5   5  -5  -5]
  [  6   6   6  -4  -1]] +(-1) ###first max value =

[[[  0   0   0   6  -2]
  [  1  -3   1  -3  -3]
  [  2   2   2  -4  -4]
  [ -5   3   3  -5  -5]
  [  4   4   4  -6  -6]
  [  5   5   5  -5  -2]] 

How would I code this?


You could use tf.repeat and try something like this:

def broad_cast_and_merge(ds1, ds2):
  """Add the values in ``ds2`` to ``ds1`` after tiling ``ds2`` to ``ds1``'s shape.

  Assumes ``ds1`` is a rank-3 batch of windows and ``ds2`` holds one value
  per window (shape ``(batch, 1)`` in the example) -- confirm for other inputs.
  """
  target_shape = tf.shape(ds1)
  elems_per_window = target_shape[1] * target_shape[2]
  # Without an axis, tf.repeat flattens ds2 and repeats each value
  # consecutively, so the reshape puts every value into its own window.
  tiled = tf.repeat(ds2, repeats=elems_per_window)
  return ds1 + tf.reshape(tiled, target_shape)

# The windowed batches, built with the same parameters as in the question:
# windows of 4 + 2 = 6 rows, starting every 2 rows, two windows per batch.
time_series_ds = tf.keras.preprocessing.timeseries_dataset_from_array(
    simple_data_samples,
    None,
    sequence_length=6,
    sequence_stride=2,
    batch_size=2,
    shuffle=False,
)
# One max per window, taken over the last feature, with the reduced axis kept.
max_ds = time_series_ds.map(
    lambda x: tf.reduce_max(x[:, :, -1], axis=1, keepdims=True)
)

# Pair every batch of windows with its batch of maxima and add them.
final_ds = tf.data.Dataset.zip((time_series_ds, max_ds)).map(broad_cast_and_merge)
single(final_ds)

Or with tf.broadcast_to:

def broad_cast_and_merge(ds1, ds2):
  """Add ``ds2`` to ``ds1`` after explicitly broadcasting it to ``ds1``'s shape.

  Assumes ``ds2`` holds one value per window (shape ``(batch, 1)`` in the
  example) -- confirm for other inputs.
  """
  # Append a trailing axis so ds2 has the same rank as ds1 in the example.
  per_window = tf.expand_dims(ds2, axis=-1)
  return ds1 + tf.broadcast_to(per_window, tf.shape(ds1))

# NOTE: time_series_ds and max_ds must already exist -- presumably the
# windowed batches and their per-window maxima; confirm against your setup.
paired_ds = tf.data.Dataset.zip((time_series_ds, max_ds))
final_ds = paired_ds.map(broad_cast_and_merge)
single(final_ds)

Both will give you the same results:

tf.Tensor(
[[[  0   0   0   6  -2]
  [  1  -3   1  -3  -3]
  [  2   2   2  -4  -4]
  [ -5   3   3  -5  -5]
  [  4   4   4  -6  -6]
  [  5   5   5  -5  -2]]

 [[  2   2   2  -4  -4]
  [ -5   3   3  -5  -5]
  [  4   4   4  -6  -6]
  [  5   5   5  -5  -2]
  [  6   6   7  -8 -71]
  [  7   7   7  -9  -9]]], shape=(2, 6, 5), dtype=int64)
tf.Tensor(
[[[  4   4   4  -6  -6]
  [  5   5   5  -5  -2]
  [  6   6   7  -8 -71]
  [  7   7   7  -9  -9]
  [  8   3   8 -10 -10]
  [  9   9   9 -11 -11]]

 [[ -1  -1   0 -15 -78]
  [  0   0   0 -16 -16]
  [  1  -4   1 -17 -17]
  [  2   2   2 -18 -18]
  [  3  -3   3 -19 -19]
  [  4   4   4 -20 -20]]], shape=(2, 6, 5), dtype=int64)

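As previously mentioned, you could also skip the explicit broadcast step and just use tf.expand_dims: returning `ds1 + tf.expand_dims(ds2, axis=-1)` is enough, because the addition broadcasts the (2, 1, 1) tensor to the shape of ds1 automatically.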


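You want broadcasting. Broadcasting aligns shapes starting from the right (the trailing dimensions), so a max tensor of shape (2,1) does not line up with a batch of shape (2,6,5). To broadcast along the leading (batch) dimension instead, reshape your max tensor to shape (2,1,1) so that the extra dimensions on the right are filled with ones, then simply add the two tensors with a + sign. Broadcasting will take care of the rest.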

For more info, check out https://www.tensorflow.org/api_docs/python/tf/broadcast_to