Reduce processing time for calculating coefficients
Solution 1:
There are too many issues in your code, so we need to start from scratch. In general, here are the major concerns:
- Don't do expensive operations so many times. Things like `pivot_*` and `*_join` are not cheap, since they change the structure of the entire dataset. Don't use them so freely as if they come with no cost.
- Do not repeat yourself. I saw `filter(Id == idd, Category == ...)` several times in your function. The rows that are filtered out won't come back. This is just a waste of computational power and makes your code unreadable.
- Think carefully before you code. It seems that you want the regression results for multiple `idd`, `date2` and `Category`. Then, should the function be designed to take only scalar inputs, so that we run it many times, each run involving several expensive data operations on a relatively large dataset, or should it be designed to take vector inputs, do fewer operations, and return everything at once? The answer to this question should be clear (a sketch contrasting the two designs follows this list).
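To make that last point concrete, here is a minimal sketch contrasting the two designs. The toy table and the `slow_scalar()`/`fast_vectorised()` helpers are made up purely for illustration; they are not your actual computation.

```r
library(data.table)

# toy table: one value per (id, grp) pair -- names invented for the illustration
dt <- data.table(id = rep(1:3, each = 2), grp = rep(c("a", "b"), 3), value = 1:6)

# scalar design: filter the whole table once per query
slow_scalar <- function(dt, one_id, one_grp) dt[id == one_id & grp == one_grp, sum(value)]
mapply(slow_scalar, one_id = c(1L, 2L, 3L), one_grp = c("a", "b", "a"),
       MoreArgs = list(dt = dt))

# vectorised design: join all queries at once and compute by group in a single pass
fast_vectorised <- function(dt, ids, grps) {
  queries <- data.table(id = ids, grp = grps)
  dt[queries, sum(value), by = .EACHI, on = .(id, grp)]$V1
}
fast_vectorised(dt, c(1L, 2L, 3L), c("a", "b", "a"))  # same three sums, one pass
```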
Now I will show you how I would approach this problem. The steps are:
- Find the relevant subset for each group of `idd`, `dmda` and `CategoryChosse` at once. We can use one or two joins to find the corresponding subset. Since we also need to calculate the median for each `Week` group, we also want to find the corresponding dates that are in the same `Week` group for each `dmda`.
- Pivot the data from wide to long, once and for all. Use a row id to preserve row relationships. Call the column containing those "DRMXX" names `day` and the column containing the values `value`.
- Find whether trailing zeros exist for each row id. Use `rev(cumsum(rev(x)) != 0)` instead of a long and inefficient pipeline (see the small sketch after this list).
- Compute the median-adjusted values by each group of "Id", "Category", ..., "day" and "Week". Doing things by group is natural and efficient in a long data format.
- Aggregate the `Week` group. This follows directly from your code, while we also filter out `day`s that are smaller than the difference between each `dmda` and the corresponding `date1` for each group.
- Run `lm` for each group of `Id`, `Category` and `dmda` identified.
- Use `data.table` for greater efficiency.
- (Optional) Use a different `median` function rewritten in C++, since the one in base R (`stats::median`) is a bit slow (`stats::median` is a generic method that handles various input types, but we only need it to take numerics here). The median function is adapted from here.
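As a small, self-contained illustration of the trailing-zero flag and of computing it by group in the long format (the toy vector and the two-row-id table below are made up):

```r
library(data.table)

# keep everything up to (and including) the last non-zero entry
x <- c(0, 2, 0, 5, 0, 0)
rev(cumsum(rev(x)) != 0)
#> [1]  TRUE  TRUE  TRUE  TRUE FALSE FALSE

# the same flag computed per row id in a long table
long <- data.table(rowid = rep(1:2, each = 3), value = c(1, 0, 0, 0, 3, 0))
long[, keep := rev(cumsum(rev(value)) != 0), by = rowid][]
```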
The full code demonstrating these steps is shown below:
```r
Rcpp::sourceCpp(code = '
#include <Rcpp.h>
// [[Rcpp::export]]
double mediancpp(Rcpp::NumericVector& x, const bool na_rm) {
  std::size_t m = x.size();
  if (m < 1) Rcpp::stop("zero length vector not allowed.");
  if (!na_rm) {
    for (Rcpp::NumericVector::iterator i = x.begin(); i != x.end(); ++i)
      if (Rcpp::NumericVector::is_na(*i)) return *i;
  } else {
    for (Rcpp::NumericVector::iterator i = x.begin(); i != x.begin() + m; )
      Rcpp::NumericVector::is_na(*i) ? std::iter_swap(i, x.begin() + --m) : (void)++i;
  }
  if (m < 1) return x[0];
  std::size_t n = m / 2;
  std::nth_element(x.begin(), x.begin() + n, x.begin() + m);
  return m % 2 ? x[n] : (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}
')
```
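A quick sanity check of the compiled function against `stats::median` might look like the sketch below (the toy vector is made up; note that `mediancpp` takes its argument by reference and may reorder it in place, a point discussed again after the benchmark):

```r
x <- c(3, 1, NA, 7, 5)
mediancpp(x, na_rm = TRUE)   # 4, the median of 3, 1, 7, 5
median(x, na.rm = TRUE)      # same value from base R
x                            # may no longer be in its original order: passed by reference
```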
```r
dt_return_intercept <- function(dt1, idd, dmda, category) {
  # type checks
  stopifnot(
    data.table::is.data.table(dt1),
    length(idd) == length(dmda),
    length(idd) == length(category)
  )
  dmda <- switch(
    class(dt1$date2),
    character = as.character(dmda), Date = as.Date(dmda, "%Y-%m-%d"),
    stop("non-conformable types between `dmda` and `dt1$date2`")
  )
  idd <- as(idd, class(dt1$Id))
  # find subsets
  DT <- data.table::setDT(list(Id = idd, date2 = dmda, Category = category, order = seq_along(idd)))
  DT <- dt1[
    dt1[DT, .(Id, Category, date2, Week, order), on = .NATURAL],
    on = .(Id, Category, Week), allow.cartesian = TRUE
  ]
  DT[, c("rowid", "date1", "date2", "i.date2") := c(
    list(seq_len(.N)), lapply(.SD, as.Date, "%Y-%m-%d")
  ), .SDcols = c("date1", "date2", "i.date2")]
  # pivot + type conversion
  DT <- data.table::melt(DT, measure.vars = patterns("DRM(\\d+)"), variable.name = "day")
  DT[, `:=`(day = as.integer(sub("^\\D+", "", day)), value = as.numeric(value))]
  # computations
  DT[, keep := rev(cumsum(rev(value)) != 0), by = "rowid"]
  DT[, value := value + mediancpp(DR1 - value, TRUE),
     by = c("Id", "Category", "i.date2", "date1", "day", "Week")]
  DT <- DT[date2 == i.date2 & keep & day > i.date2 - date1,
           .(value = sum(value), order = order[[1L]]),
           by = c("Id", "Category", "i.date2", "date1", "day")]
  DT[, .(out = coef(lm(value ~ I(day^2), .SD))[[1L]], order = order[[1L]]), # coef(...)[[1L]] gives you the intercept, not the coefficient of day^2. Are you sure this is what you want?
     by = c("Id", "Category", "i.date2")][order(order)]$out
}
```
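If the `melt()` call above looks unfamiliar, here is a minimal, standalone sketch of the wide-to-long pivot with `patterns()`. The column names are invented for the illustration and only mimic the "DRMXX" pattern:

```r
library(data.table)

wide <- data.table(Id = 1:2, DRM01 = c(10, 20), DRM02 = c(0, 5))
# gather the DRMXX columns into (day, value) pairs, keeping Id as the identifier
long <- melt(wide, measure.vars = patterns("DRM(\\d+)"), variable.name = "day")
long[, day := as.integer(sub("^\\D+", "", day))][]
```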
Benchmark
```r
params <- (params <- unique(df1[df1$date1 <= df1$date2, c(1L, 3L, 4L)]))[sample.int(nrow(params), 20L), ]
dt1 <- data.table::setDT(data.table::copy(df1)) # nothing but a data.table version of `df1`

microbenchmark::microbenchmark(
  mapply(function(x, y, z) return_coef(df1, x, y, z),
         params$Id, params$date2, params$Category),
  dt_return_intercept(dt1, params$Id, params$date2, params$Category),
  dt_return_intercept_base(dt1, params$Id, params$date2, params$Category), # uses stats::median instead of mediancpp
  times = 10L, check = "equal"
)
```
Results are as follows. No error is thrown with `check = "equal"`, which means all three functions return the same result. This function is about 136x faster than yours with `mediancpp`, and about 73x faster than yours with `stats::median`. To avoid copying, `mediancpp` takes its first argument by reference, so it needs to be used with caution. That behavior fits well in this case, since `DR1 - value` creates a temporary object that affects none of our variables.
```
Unit: milliseconds
expr min lq mean median uq max neval
mapply(function(x, y, z) return_coef(df1, x, y, z), params$Id, params$date2, params$Category) 11645.1729 11832.4373 11902.36716 11902.95195 11979.4154 12145.1154 10
dt_return_intercept(dt1, params$Id, params$date2, params$Category) 68.3173 72.4008 87.14596 75.24725 88.6007 167.2546 10
dt_return_intercept_base(dt1, params$Id, params$date2, params$Category) 153.9713 157.0826 163.18133 162.12175 167.2681 176.6866 10
```