dplyr - filter by group size

What is the best way to filter a data.frame so that only groups of a given size (say, 5) remain?

So my data looks as follows:

n <- 1e5
x <- rnorm(n)
# Category sizes range from 1 to 5. Drawing one size per row (n draws)
# guarantees the repeated ids cover all n rows; with only n/3 draws the total
# can fall short of n, and the [seq_len(n)] subset then pads `cat` with NA.
cat <- rep(seq_len(n), sample(1:5, n, replace = TRUE))[seq_len(n)]

dat <- data.frame(x = x, cat = cat)

The dplyr approach I came up with was:

dat <- dat %>% group_by(cat)

out1 <- filter(dat, n() == 5L)
#    user  system elapsed 
#   1.157   0.218   1.497

But this is very slow... Is there a better way in dplyr?

So far, my workaround looks as follows:

  # group_indices() gives each row's group id in row order, so this does not
  # depend on the rows being sorted by `cat` (rep() over group_size() does).
  all_ind <- group_indices(dat)
  take_only <- which(group_size(dat) == 5L)
  out2 <- dat[all_ind %in% take_only, ]
#    user  system elapsed 
#   0.026   0.008   0.036
all.equal(out1, out2) # TRUE

But this doesn't feel very dplyr-like...

Solution 1:

You can do it more concisely with n():

dat %>%
  group_by(cat) %>%
  filter(n() == 5)

Solution 2:

Here's another dplyr approach you can try:

dat %>%
  semi_join(
    dat %>% count(cat) %>% filter(n == 5),
    by = "cat"
  )


Here's another approach, based on the OP's original one with a small modification:

n <- 1e5
x <- rnorm(n)
# Category sizes range from 1 to 5. Drawing one size per row (n draws)
# guarantees the repeated ids cover all n rows; with only n/3 draws the total
# can fall short of n, and the [seq_len(n)] subset then pads `cat` with NA.
cat <- rep(seq_len(n), sample(1:5, n, replace = TRUE))[seq_len(n)]

dat <- data.frame(x = x, cat = cat)

# second, independent copy of the data, passed to sol_floo0_v2()
dat2 <- data.frame(x = x, cat = cat)

# Keep only the rows whose `cat` group has exactly 5 rows.
# group_indices() gives each row's group id in row order, so -- unlike
# rep(seq_len(n_groups(dat)), group_size(dat)) -- the result stays correct
# when the rows are not sorted by `cat`.
sol_floo0 <- function(dat) {
  dat <- group_by(dat, cat)
  take_only <- which(group_size(dat) == 5L)
  dat[group_indices(dat) %in% take_only, ]
}

# Same filter as sol_floo0(): build a per-group logical "has 5 rows" flag and
# look it up with each row's group id, then subset `dat` itself (not the
# grouped copy). Indexing by group_indices() does not require sorted rows.
sol_floo0_v2 <- function(dat) {
  grouped <- group_by(dat, cat)
  is_size_5 <- group_size(grouped) == 5L
  dat[is_size_5[group_indices(grouped)], ]
}

microbenchmark::microbenchmark(times = 10,
  sol_floo0(dat),
  sol_floo0_v2(dat2)
)
#Unit: milliseconds
#               expr      min       lq     mean   median       uq      max neval cld
#     sol_floo0(dat) 43.72903 44.89957 45.71121 45.10773 46.59019 48.64595    10   b
# sol_floo0_v2(dat2) 29.83724 30.56719 32.92777 31.97169 34.10451 38.31037    10  a 
all.equal(sol_floo0(dat), sol_floo0_v2(dat2))
#[1] TRUE

Solution 3:

I know you asked for a dplyr solution, but if you combine it with some purrr you can do it in a single pipeline without defining any new functions (a little slower, though).


dat %>%
  group_by(cat) %>%
  nest() %>%
  # one nested data frame per group: keep the groups with exactly 5 rows
  filter(map_int(data, nrow) == 5L) %>%
  # expand back to the original rows (cat, x) instead of dropping them
  unnest(data) %>%
  ungroup()

Solution 4:

Comparing the answers' run times:

n <- 1e5
x <- rnorm(n)
# Category sizes range from 1 to 5. Drawing one size per row (n draws)
# guarantees the repeated ids cover all n rows; with only n/3 draws the total
# can fall short of n, and the [seq_len(n)] subset then pads `cat` with NA.
cat <- rep(seq_len(n), sample(1:5, n, replace = TRUE))[seq_len(n)]

dat <- data.frame(x = x, cat = cat)

# second copy for the data.table approaches: setDT() converts its argument to
# a data.table by reference, so using `dat` there would alter it for the others
dat2 <- data.frame(x = x, cat = cat)

# Keep only the rows whose `cat` group has exactly 5 rows.
# group_indices() gives each row's group id in row order, so -- unlike
# rep(seq_len(n_groups(dat)), group_size(dat)) -- the result stays correct
# when the rows are not sorted by `cat`.
sol_floo0 <- function(dat) {
  dat <- group_by(dat, cat)
  take_only <- which(group_size(dat) == 5L)
  dat[group_indices(dat) %in% take_only, ]
}

# Same filter as sol_floo0(): build a per-group logical "has 5 rows" flag and
# look it up with each row's group id, then subset `dat` itself (not the
# grouped copy). Indexing by group_indices() does not require sorted rows.
sol_floo0_v2 <- function(dat) {
  grouped <- group_by(dat, cat)
  is_size_5 <- group_size(grouped) == 5L
  dat[is_size_5[group_indices(grouped)], ]
}

# Count rows per `cat`, keep the categories counted exactly 5 times, and
# semi-join back so only the matching rows of `dat` remain.
sol_docendo_discimus <- function(dat) {
  dat <- group_by(dat, cat)
  semi_join(dat, count(dat, cat) %>% filter(n == 5), by = "cat")
}

# data.table: collect the row numbers (.I) of every `cat` group whose size
# (.N) is 5, then subset by those rows. setDT() converts `dat2` by reference.
sol_akrun <- function(dat2) {
  setDT(dat2)[dat2[, .I[.N == 5], by = cat]$V1]
}

# data.table: for each `cat` group, return its rows (.SD) only when the group
# size (.N) is 5; other groups contribute nothing.
sol_sotos <- function(dat2) {
  setDT(dat2)[, if (.N == 5) .SD, by = cat]
}

# Base R: rle() finds runs of equal `cat` values and keeps the categories whose
# run has length 5. Run lengths equal group sizes only when each category's
# rows are contiguous (as in the generated data); on unsorted input a
# category split across several runs would be measured run by run.
sol_chirayu_chamoli <- function(dat) {
  rle_ <- rle(dat$cat)
  dat[dat$cat %in% rle_$values[rle_$lengths == 5], ]
}

microbenchmark::microbenchmark(times = 20,
  sol_floo0(dat),
  sol_floo0_v2(dat),
  sol_docendo_discimus(dat),
  sol_akrun(dat2),
  sol_sotos(dat2),
  sol_chirayu_chamoli(dat)
)

Results in:

Unit: milliseconds
                      expr       min        lq      mean    median        uq       max neval  cld
            sol_floo0(dat)  58.00439  65.28063  93.54014  69.82658  82.79997 280.23114    20   cd
         sol_floo0_v2(dat)  42.27791  50.27953  72.51729  58.63931  67.62540 238.97413    20  bc 
 sol_docendo_discimus(dat) 100.54095 113.15476 126.74142 121.69013 132.62533 183.05818    20    d
           sol_akrun(dat2)  26.88369  34.01925  41.04378  37.07957  45.44784  63.95430    20 ab  
           sol_sotos(dat2)  16.10177  19.78403  24.04375  23.06900  28.05470  35.83611    20 a   
  sol_chirayu_chamoli(dat)  20.67951  24.18100  38.01172  27.61618  31.97834 230.51026    20 ab  

Solution 5:

I generalised the function written by docendo discimus so that it can be used alongside existing dplyr functions:

#' Filter groups by their size
#'
#' Keeps the rows of a (grouped) data frame whose group has between \code{min}
#' and \code{max} rows, inclusive.
#'
#' @inheritParams dplyr::filter
#' @param min minimal group size, use \code{min = NULL} to filter on maximal group size only
#' @param max maximal group size, use \code{max = NULL} to filter on minimal group size only
#' @return The rows of \code{.data} belonging to groups whose size lies in
#'   \code{[min, max]}.
#' @export
#' @source Stack Overflow answer by docendo discimus, \url{https://stackoverflow.com/a/43110620/4575331}
filter_group_size <- function(.data, min = NULL, max = min) {
  # Force the lazy default `max = min` before `min` can be reassigned below.
  force(max)
  if (is.null(min) && is.null(max)) {
    stop("`min` and `max` cannot both be NULL.", call. = FALSE)
  }
  # A NULL bound means "unbounded" on that side; without this, `g >= NULL`
  # yields logical(0) and the documented `min = NULL` case fails.
  if (is.null(min)) {
    min <- 0
  }
  if (is.null(max)) {
    max <- Inf
  }
  g <- dplyr::group_size(.data)
  keep_group <- g >= min & g <= max
  # Look up each row's group id, so rows need not be sorted by group.
  .data[keep_group[dplyr::group_indices(.data)], ]
}

Let's check it for a minimal group size of 5:

dat2 %>%
  group_by(cat) %>%
  filter_group_size(5, NULL) %>%
  summarise(n = n()) %>%
  arrange(desc(n))

# # A tibble: 6,634 x 2
#      cat     n
#    <int> <int>
#  1    NA    19
#  2     1     5
#  3     2     5
#  4     6     5
#  5    15     5
#  6    17     5
#  7    21     5
#  8    27     5
#  9    33     5
# 10    37     5
# # ... with 6,624 more rows

Great. Now check the OP's case, a group size of exactly 5:

dat2 %>%
  group_by(cat) %>%
  filter_group_size(5) %>%
  summarise(n = n()) %>%
  pull(n) %>%
  unique()
# [1] 5
