Cartesian product of streams in Java 8 as stream (using streams only)

Solution 1:

Passing the streams in your example is never better than passing Lists:

private static <T> Stream<T> cartesian(BinaryOperator<T> aggregator, List<T>... lists) {
    ...
}

And use it like this:

Stream<String> result = cartesian(
  (a, b) -> a + b, 
  Arrays.asList("A", "B"), 
  Arrays.asList("K", "L"), 
  Arrays.asList("X", "Y")
);

In both cases you create an implicit array from the varargs and use it as the data source, so the laziness is imaginary: your data is actually stored in the arrays.
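Here is a small illustration of why varargs-based laziness is imaginary. The argument expressions are evaluated and stored in an array before `Stream.of` returns, while a `Supplier` postpones that work until `get()` is called. The class name `EagerVarargs` and the helper `tag` are hypothetical and exist only for this sketch.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class EagerVarargs {
    // Hypothetical helper: records that an element was produced, then returns it.
    static String tag(AtomicInteger counter, String value) {
        counter.incrementAndGet();
        return value;
    }

    public static void main(String[] args) {
        AtomicInteger built = new AtomicInteger();
        // Both arguments are evaluated into the varargs array before any terminal operation.
        Stream<String> eager = Stream.of(tag(built, "A"), tag(built, "B"));
        System.out.println(built.get()); // 2

        AtomicInteger deferred = new AtomicInteger();
        // A supplier postpones building the array until get() is called.
        Supplier<Stream<String>> lazy = () -> Stream.of(tag(deferred, "A"), tag(deferred, "B"));
        System.out.println(deferred.get()); // 0

        List<String> fromEager = eager.collect(Collectors.toList());
        List<String> fromLazy = lazy.get().collect(Collectors.toList());
        System.out.println(fromEager + " " + fromLazy + " " + deferred.get()); // [A, B] [A, B] 2
    }
}
```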

In most cases the resulting Cartesian product stream is much longer than the inputs, so there is practically no reason to make the inputs lazy. For example, five lists of five elements each (25 elements in total) produce a resulting stream of 5^5 = 3125 elements. Storing 25 elements in memory is not a big problem, and in most practical cases they are already stored in memory anyway.
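The size argument can be checked directly. The sketch below gives one possible body for the List-based signature. It is a hypothetical implementation written for this illustration, not the elided original. It folds the lists left to right, and because a `List` can be streamed as often as needed, every element on the left simply opens a fresh stream over the next list. The class name `ListCartesian` is also made up, and `@SafeVarargs` is an addition.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.stream.Stream;

final class ListCartesian {
    // Hypothetical body for the List-based signature; @SafeVarargs silences the generic-varargs warning.
    @SafeVarargs
    static <T> Stream<T> cartesian(BinaryOperator<T> aggregator, List<T>... lists) {
        if (lists.length == 0) {
            return Stream.empty();
        }
        Stream<T> result = lists[0].stream();
        for (int i = 1; i < lists.length; i++) {
            List<T> next = lists[i]; // effectively final, so the lambda can capture it
            // re-stream `next` for every element produced so far
            result = result.flatMap(a -> next.stream().map(b -> aggregator.apply(a, b)));
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> five = Arrays.asList("a", "b", "c", "d", "e");
        // 5 inputs of 5 elements: 25 stored elements, 5^5 combinations
        long size = cartesian((a, b) -> a + b, five, five, five, five, five).count();
        System.out.println(size); // 3125
    }
}
```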

To generate the stream of Cartesian products you need to constantly "rewind" all the streams except the first one. To rewind, a stream must be able to retrieve the original data again and again, either by buffering it somehow (which you don't like) or by grabbing it again from the source (collection, array, file, network, random numbers, etc.) and performing all the intermediate operations again. If your source and intermediate operations are slow, the lazy solution may be much slower than the buffering solution. If your source cannot produce the same data again (for example, a random number generator that cannot reproduce the numbers it produced before), your solution will be incorrect.
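To make the cost of rewinding concrete, here is a sketch with two suppliers. The names `RewindDemo` and `pairs` are hypothetical. A counter in the inner pipeline shows that its intermediate operation runs once per outer element. A counter-based source stands in for a generator that cannot reproduce earlier values, so the result is no longer a Cartesian product.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class RewindDemo {
    // Pairs every element of `first` with a freshly obtained stream from `second`.
    static List<String> pairs(Supplier<Stream<String>> first, Supplier<Stream<String>> second) {
        return first.get()
            .flatMap(a -> second.get().map(b -> a + b))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Repeatable source whose intermediate operation counts its invocations.
        AtomicInteger mapCalls = new AtomicInteger();
        Supplier<Stream<String>> inner = () -> Stream.of("x", "y").map(s -> {
            mapCalls.incrementAndGet(); // stands in for an expensive step
            return s;
        });
        System.out.println(pairs(() -> Stream.of("A", "B", "C"), inner)); // [Ax, Ay, Bx, By, Cx, Cy]
        System.out.println(mapCalls.get()); // 6: the inner pipeline ran once per outer element

        // Non-repeatable source: every "rewind" yields new values.
        AtomicInteger next = new AtomicInteger();
        Supplier<Stream<String>> oneShot =
            () -> Stream.of("" + next.getAndIncrement(), "" + next.getAndIncrement());
        System.out.println(pairs(() -> Stream.of("A", "B"), oneShot)); // [A0, A1, B2, B3]
    }
}
```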

Nevertheless, a totally lazy solution is possible. Just use stream suppliers instead of streams:

private static <T> Stream<T> cartesian(BinaryOperator<T> aggregator,
                                       Supplier<Stream<T>>... streams) {
    return Arrays.stream(streams)
        .reduce((s1, s2) -> 
            () -> s1.get().flatMap(t1 -> s2.get().map(t2 -> aggregator.apply(t1, t2))))
        .orElse(Stream::empty).get();
}

The interesting part of this solution is that we create a stream of suppliers, reduce it to a single resulting supplier, and only then call it. Usage:

Stream<String> result = cartesian(
          (a, b) -> a + b, 
          () -> Stream.of("A", "B"), 
          () -> Stream.of("K", "L"), 
          () -> Stream.of("X", "Y")
        );
result.forEach(System.out::println);
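A self-contained version of the supplier approach, with a counter on each supplier, shows where the laziness is paid for. The wrapper `counted` and the class name are hypothetical. `@SafeVarargs` is an addition that suppresses the generic-varargs warning. The counts show the rewinding described earlier: the first supplier is called once, the second once per element of the first, and the third once per pair.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class LazyCartesian {
    // Same reduction as in the answer; @SafeVarargs (an addition) silences the generic-varargs warning.
    @SafeVarargs
    static <T> Stream<T> cartesian(BinaryOperator<T> aggregator, Supplier<Stream<T>>... streams) {
        return Arrays.stream(streams)
            .reduce((s1, s2) ->
                () -> s1.get().flatMap(t1 -> s2.get().map(t2 -> aggregator.apply(t1, t2))))
            .orElse(Stream::empty).get();
    }

    // Hypothetical helper: wraps a supplier so that every call ("rewind") is counted.
    static <T> Supplier<Stream<T>> counted(AtomicInteger calls, Supplier<Stream<T>> source) {
        return () -> {
            calls.incrementAndGet();
            return source.get();
        };
    }

    public static void main(String[] args) {
        AtomicInteger first = new AtomicInteger();
        AtomicInteger second = new AtomicInteger();
        AtomicInteger third = new AtomicInteger();
        Supplier<Stream<String>> ab = counted(first, () -> Stream.of("A", "B"));
        Supplier<Stream<String>> kl = counted(second, () -> Stream.of("K", "L"));
        Supplier<Stream<String>> xy = counted(third, () -> Stream.of("X", "Y"));

        List<String> result = cartesian((a, b) -> a + b, ab, kl, xy).collect(Collectors.toList());
        System.out.println(result); // [AKX, AKY, ALX, ALY, BKX, BKY, BLX, BLY]
        System.out.println(first + " " + second + " " + third); // 1 2 4
    }
}
```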

Solution 2:

The stream is consumed by the flatMap operation the first time it is used, so the second iteration fails with an IllegalStateException. You have to create a new stream every time you map your result, which means collecting the stream in advance so that you can get a new stream in every iteration.
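A minimal reproduction of the problem follows; the class and method names are hypothetical. Reusing one `Stream` object inside `flatMap` throws an `IllegalStateException` as soon as the second outer element tries to stream it again. Collecting it first lets each outer element open a fresh stream over the buffered list.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class ReuseDemo {
    // Reuses one Stream object inside flatMap: fails on the second outer element.
    static List<String> naive(Stream<String> outer, Stream<String> inner) {
        return outer.flatMap(a -> inner.map(b -> a + b)).collect(Collectors.toList());
    }

    // Collects the inner stream once, then opens a fresh stream per outer element.
    static List<String> collected(Stream<String> outer, Stream<String> inner) {
        List<String> buffer = inner.collect(Collectors.toList());
        return outer.flatMap(a -> buffer.stream().map(b -> a + b)).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        try {
            naive(Stream.of("A", "B"), Stream.of("K", "L"));
        } catch (IllegalStateException e) {
            System.out.println("naive failed: " + e.getMessage());
        }
        System.out.println(collected(Stream.of("A", "B"), Stream.of("K", "L"))); // [AK, AL, BK, BL]
    }
}
```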

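private static <T> Stream<T> cartesian(BiFunction<T, T, T> aggregator, Stream<T>... streams) {
    Stream<T> result = null;
    for (Stream<T> stream : streams) {
        if (result == null) {
            // the first stream is used as-is and stays lazy
            result = stream;
        } else {
            // buffer the next stream so a fresh stream can be created for every element of result
            Collection<T> s = stream.collect(Collectors.toList());
            result = result.flatMap(m -> s.stream().map(n -> aggregator.apply(m, n)));
        }
    }
    // no input streams: return an empty stream instead of null
    return result == null ? Stream.empty() : result;
}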

Or even shorter:

private static <T> Stream<T> cartesian(BiFunction<T, T, T> aggregator, Stream<T>... streams) {
    return Arrays.stream(streams).reduce((r, s) -> {
        List<T> collect = s.collect(Collectors.toList());
        return r.flatMap(m -> collect.stream().map(n -> aggregator.apply(m, n)));
    }).orElse(Stream.empty());
}
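Here is a runnable version of the reduce-based variant; the class name is hypothetical and `@SafeVarargs` is an addition. One consequence is worth seeing: `reduce` is a terminal operation that runs when `cartesian` is called, so every stream except the first is collected immediately, before you consume the result. A `peek` counter makes that visible.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class ReduceCartesian {
    @SafeVarargs
    static <T> Stream<T> cartesian(BiFunction<T, T, T> aggregator, Stream<T>... streams) {
        return Arrays.stream(streams).reduce((r, s) -> {
            // buffer the right-hand stream so it can be re-streamed for every left element
            List<T> collect = s.collect(Collectors.toList());
            return r.flatMap(m -> collect.stream().map(n -> aggregator.apply(m, n)));
        }).orElse(Stream.empty());
    }

    public static void main(String[] args) {
        AtomicInteger pulled = new AtomicInteger();
        Stream<String> result = cartesian((a, b) -> a + b,
            Stream.of("A", "B"),
            Stream.of("K", "L").peek(x -> pulled.incrementAndGet()));
        // The second stream was already drained by collect() inside reduce.
        System.out.println(pulled.get()); // 2
        System.out.println(result.collect(Collectors.toList())); // [AK, AL, BK, BL]
    }
}
```

The first stream is the only one that stays lazy; if that matters, put the largest or most expensive input first.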

Solution 3:

You can create a method that returns a stream of List<T> (the combinations themselves) without aggregating them. The algorithm is the same: at each step, collect the elements of the second stream to a list and then append them to the elements of the first stream.

The aggregator is applied outside the method.

@SuppressWarnings("unchecked")
public static <T> Stream<List<T>> cartesianProduct(Stream<T>... streams) {
    // incorrect incoming data
    if (streams == null) return Stream.empty();
    return Arrays.stream(streams)
            // non-null streams
            .filter(Objects::nonNull)
            // represent each stream element as a singleton List<T>
            .map(stream -> stream.map(Collections::singletonList))
            // summation of pairs of inner lists
            .reduce((stream1, stream2) -> {
                // list of lists from second stream
                List<List<T>> list2 = stream2.collect(Collectors.toList());
                // append to the first stream
                return stream1.flatMap(inner1 -> list2.stream()
                        // combinations of inner lists
                        .map(inner2 -> {
                            List<T> list = new ArrayList<>();
                            list.addAll(inner1);
                            list.addAll(inner2);
                            return list;
                        }));
            }).orElse(Stream.empty());
}
public static void main(String[] args) {
    Stream<String> stream1 = Stream.of("A", "B");
    Stream<String> stream2 = Stream.of("K", "L");
    Stream<String> stream3 = Stream.of("X", "Y");
    @SuppressWarnings("unchecked")
    Stream<List<String>> stream4 = cartesianProduct(stream1, stream2, stream3);
    // output
    stream4.map(list -> String.join("", list)).forEach(System.out::println);
}

String.join is a kind of aggregator in this case.

Output:

AKX
AKY
ALX
ALY
BKX
BKY
BLX
BLY
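Because `cartesianProduct` returns the tuples themselves, the aggregation step can be anything. Below is a sketch with integer streams and summation instead of `String.join`. The class name is hypothetical. The method body mirrors the one above, with `@SafeVarargs` in place of `@SuppressWarnings` and a copy constructor in place of two `addAll` calls.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class TupleAggregators {
    // Same algorithm as cartesianProduct above: each step buffers the right-hand stream.
    @SafeVarargs
    static <T> Stream<List<T>> cartesianProduct(Stream<T>... streams) {
        if (streams == null) return Stream.empty();
        return Arrays.stream(streams)
            .filter(Objects::nonNull)
            // wrap each element as a one-element tuple
            .map(stream -> stream.map(Collections::singletonList))
            .reduce((stream1, stream2) -> {
                List<List<T>> list2 = stream2.collect(Collectors.toList());
                // concatenate every left tuple with every buffered right tuple
                return stream1.flatMap(inner1 -> list2.stream().map(inner2 -> {
                    List<T> list = new ArrayList<>(inner1);
                    list.addAll(inner2);
                    return list;
                }));
            }).orElse(Stream.empty());
    }

    public static void main(String[] args) {
        // Aggregate each tuple by summing it instead of joining strings.
        List<Integer> sums = cartesianProduct(Stream.of(1, 2), Stream.of(10, 20), Stream.of(100))
            .map(tuple -> tuple.stream().mapToInt(Integer::intValue).sum())
            .collect(Collectors.toList());
        System.out.println(sums); // [111, 121, 112, 122]
    }
}
```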

See also: Stream of cartesian product of other streams, each element as a List?