Parallel flatMap always sequential
Suppose I have this code:
Collections.singletonList(10)
        .parallelStream() // .stream() - nothing changes
        .flatMap(x -> Stream.iterate(0, i -> i + 1)
                .limit(x)
                .parallel()
                .peek(m -> {
                    System.out.println(Thread.currentThread().getName());
                }))
        .collect(Collectors.toSet());
The output always shows the same thread name, so there is no benefit from parallel here: a single thread does all the work.
Inside flatMap there is this code:
result.sequential().forEach(downstream);
I understand forcing the sequential property if the "outer" stream is parallel (they could probably block: the "outer" stream would have to wait for "flatMap" to finish and vice versa, since the same common pool is used). But why always force it?
Is that one of those things that could change in a later version?
Solution 1:
There are two different aspects.
First, there is only a single pipeline, which is either sequential or parallel. The choice of sequential or parallel at the inner stream is irrelevant. Note that the downstream consumer you see in the cited code snippet represents the entire subsequent stream pipeline, so in your code, which ends with .collect(Collectors.toSet()), this consumer will eventually add the resulting elements to a single Set instance, which is not thread-safe. So processing the inner stream in parallel with that single consumer would break the entire operation.
Second, if the outer stream gets split, the cited code might get invoked concurrently with different consumers adding to different sets. Each of these calls would process a different element of the outer stream, mapping to a different inner stream instance. Since your outer stream consists of a single element only, it can’t be split.
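A small check of the single-element case, assuming current OpenJDK behavior (the class name SingleElementOuter is illustrative). It records the name of every thread that runs the inner peek; since the spliterator of Collections.singletonList cannot be split and the inner stream is forced sequential, the set ends up with exactly one entry.

```java
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SingleElementOuter {

    // Collects the names of every thread that processes an inner element.
    static Set<String> threadsUsed() {
        Set<String> threads = ConcurrentHashMap.newKeySet();
        Set<Integer> values = Collections.singletonList(1000)
                .parallelStream() // a single element: the source cannot be split
                .flatMap(x -> Stream.iterate(0, i -> i + 1)
                        .limit(x)
                        .parallel() // overridden: flatMap consumes the inner stream sequentially
                        .peek(m -> threads.add(Thread.currentThread().getName())))
                .collect(Collectors.toSet());
        if (values.size() != 1000) {
            throw new IllegalStateException("unexpected element count: " + values.size());
        }
        return threads;
    }

    public static void main(String[] args) {
        System.out.println(threadsUsed()); // a set holding a single thread name
    }
}
```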
The way this has been implemented is also the reason for the issue described in the question “Why filter() after flatMap() is ‘not completely’ lazy in Java streams?”, as forEach is called on the inner stream, which passes all elements to the downstream consumer. As demonstrated by this answer, an alternative implementation supporting laziness and substream splitting is possible. But this is a fundamentally different way of implementing it. The current design of the Stream implementation mostly works by consumer composition, so in the end, the source spliterator (and those split off from it) receives a Consumer representing the entire stream pipeline in either tryAdvance or forEachRemaining. In contrast, the solution of the linked answer does spliterator composition, producing a new Spliterator delegating to source spliterators. I suppose both approaches have advantages, and I’m not sure how much the OpenJDK implementation would lose when working the other way round.
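For comparison, a rough sketch of spliterator composition. FlatMapSpliterator is a hypothetical class and not the code from the linked answer: it hands out one inner element per tryAdvance call, so it stays lazy even with infinite inner streams, but it does not split inner streams (its trySplit is the batching default inherited from Spliterators.AbstractSpliterator).

```java
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class SpliteratorComposition {

    // Hypothetical lazy flatMap built by composing spliterators rather than consumers.
    // trySplit is the batching default of AbstractSpliterator; inner streams are not split.
    static final class FlatMapSpliterator<T, R> extends Spliterators.AbstractSpliterator<R> {
        private final Spliterator<T> outer;
        private final Function<? super T, ? extends Stream<? extends R>> mapper;
        private Stream<? extends R> innerStream;  // kept so it can be closed when exhausted
        private Spliterator<? extends R> inner;   // delegate for the current inner elements

        FlatMapSpliterator(Spliterator<T> outer,
                           Function<? super T, ? extends Stream<? extends R>> mapper) {
            // size unknown; only the ORDERED flag is carried over from the outer source
            super(Long.MAX_VALUE, outer.characteristics() & Spliterator.ORDERED);
            this.outer = outer;
            this.mapper = mapper;
        }

        @Override
        public boolean tryAdvance(Consumer<? super R> action) {
            while (true) {
                // delegate: hand out exactly one element from the current inner source
                if (inner != null && inner.tryAdvance(action)) {
                    return true;
                }
                closeInner();
                // the current inner source is exhausted: pull the next outer element
                boolean advanced = outer.tryAdvance(t -> {
                    innerStream = mapper.apply(t);
                    inner = innerStream == null ? null : innerStream.spliterator();
                });
                if (!advanced) {
                    return false; // outer source exhausted as well
                }
            }
        }

        private void closeInner() {
            if (innerStream != null) {
                innerStream.close();
            }
            innerStream = null;
            inner = null;
        }
    }

    // The first inner stream is infinite; this only terminates because elements are pulled lazily.
    static List<Integer> firstFive() {
        Spliterator<Integer> flat = new FlatMapSpliterator<>(
                Stream.of(1, 100).spliterator(),
                x -> Stream.iterate(x, i -> i + 1));
        return StreamSupport.stream(flat, false).limit(5).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(firstFive()); // [1, 2, 3, 4, 5]
    }
}
```

Because limit is short-circuiting, the pipeline drives this spliterator through repeated tryAdvance calls and stops after five elements. Supporting real splitting of inner streams would need considerably more machinery than this sketch shows.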
Solution 2:
This is for anyone like me who has a dire need to parallelize flatMap and needs a practical solution, not only history and theory.
The simplest solution I came up with is to do the flattening by hand, basically by replacing flatMap with map + reduce(Stream::concat).
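Reduced to its core, the replacement fits in a small generic helper; concatFlatMap is a hypothetical name. According to its documentation, Stream.concat returns a stream that is ordered if both inputs are ordered and parallel if either input is parallel, which is why the parallel flag of the inner streams survives here, unlike with flatMap.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ConcatFlatten {

    // Hypothetical helper: flatMap replaced by map + reduce(Stream::concat).
    static <T, R> Stream<R> concatFlatMap(Stream<T> source, Function<? super T, Stream<R>> mapper) {
        return source.map(mapper)              // Stream<Stream<R>>; inner elements not consumed yet
                .reduce(Stream::concat)       // combine the inner streams pairwise
                .orElseGet(Stream::empty);    // empty source gives an empty result
    }

    static List<Integer> viaConcat() {
        return concatFlatMap(Stream.of(0, 100, 200).parallel(),
                start -> Stream.iterate(start, i -> i + 1).limit(3).parallel())
                .collect(Collectors.toList());
    }

    static List<Integer> viaFlatMap() {
        return Stream.of(0, 100, 200).parallel()
                .flatMap(start -> Stream.iterate(start, i -> i + 1).limit(3))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // both keep encounter order: [0, 1, 2, 100, 101, 102, 200, 201, 202]
        System.out.println(viaConcat());
        System.out.println(viaFlatMap());
    }
}
```

Two trade-offs are worth keeping in mind. reduce is a terminal operation, so every inner Stream object is created up front, giving up flatMap's laziness. And the Stream.concat documentation warns that accessing an element of a deeply concatenated stream can result in deep call chains or even StackOverflowError, which matters when the outer stream has many elements.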
Here's an example to demonstrate how to do this:
import static java.lang.Thread.currentThread;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test; // assumption: JUnit 5, since the test methods are package-private

class ParallelFlatMapTest {

    @Test
    void testParallelStream_NOT_WORKING() throws InterruptedException, ExecutionException {
        // run the pipeline inside a dedicated pool with 10 worker threads
        new ForkJoinPool(10).submit(() -> {
            Stream.iterate(0, i -> i + 1).limit(2)
                    .parallel()
                    // does not parallelize nested streams
                    .flatMap(i -> generateRangeParallel(i, 100))
                    .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                    .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
        }).get();
        System.out.println("done");
    }

    @Test
    void testParallelStream_WORKING() throws InterruptedException, ExecutionException {
        new ForkJoinPool(10).submit(() -> {
            Stream.iterate(0, i -> i + 1).limit(2)
                    .parallel()
                    // concatenation of nested streams instead of flatMap, parallelizes ALL the items
                    .map(i -> generateRangeParallel(i, 100))
                    .reduce(Stream::concat).orElse(Stream.empty())
                    .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                    .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
        }).get();
        System.out.println("done");
    }

    // a parallel stream of num consecutive integers, starting at start
    Stream<Integer> generateRangeParallel(int start, int num) {
        return Stream.iterate(start, i -> i + 1).limit(num).parallel();
    }

    // run this method with the produced output to see how the work was distributed
    void countThreads(String strOut) {
        // group the output lines by their first token (the thread name) and count them
        var res = Arrays.stream(strOut.split("\n"))
                .map(line -> line.split("\\s+"))
                .collect(Collectors.groupingBy(s -> s[0], Collectors.counting()));
        System.out.println(res);
        System.out.println("threads : " + res.keySet().size());
        System.out.println("work : " + res.values());
    }
}
Stats from a run on my machine:
NOT_WORKING case stats:
{ForkJoinPool-1-worker-23=100, ForkJoinPool-1-worker-5=300}
threads : 2
work : [100, 300]
WORKING case stats:
{ForkJoinPool-1-worker-9=16, ForkJoinPool-1-worker-23=20, ForkJoinPool-1-worker-21=36, ForkJoinPool-1-worker-31=17, ForkJoinPool-1-worker-27=177, ForkJoinPool-1-worker-13=17, ForkJoinPool-1-worker-5=21, ForkJoinPool-1-worker-19=8, ForkJoinPool-1-worker-17=21, ForkJoinPool-1-worker-3=67}
threads : 10
work : [16, 20, 36, 17, 177, 17, 21, 8, 21, 67]
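As an alternative to parsing printed output with countThreads, here is a sketch that counts elements per worker thread directly in a ConcurrentHashMap. InProcessThreadCount is a hypothetical name, and the pipeline mirrors the WORKING case. The total is always 200, while the per-thread split will vary from run to run.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Stream;

public class InProcessThreadCount {

    // Counts how many elements each worker thread handled, mirroring the WORKING case.
    static Map<String, Long> countWork() {
        Map<String, Long> work = new ConcurrentHashMap<>();
        ForkJoinPool pool = new ForkJoinPool(10);
        try {
            pool.submit(() -> {
                Stream.iterate(0, i -> i + 1).limit(2)
                        .parallel()
                        .map(i -> Stream.iterate(i, j -> j + 1).limit(100).parallel())
                        .reduce(Stream::concat).orElse(Stream.empty())
                        // merge is atomic on ConcurrentHashMap, so concurrent updates are safe
                        .forEach(i -> work.merge(Thread.currentThread().getName(), 1L, Long::sum));
            }).join(); // join() waits like get() but throws unchecked exceptions
        } finally {
            pool.shutdown();
        }
        return work;
    }

    public static void main(String[] args) {
        Map<String, Long> work = countWork();
        System.out.println(work);
        System.out.println("threads : " + work.size());
        System.out.println("total : " + work.values().stream().mapToLong(Long::longValue).sum());
    }
}
```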