Speed comparison

Neuroblastoma data

Consider the neuroblastoma data. There are 3418 labeled examples. If we consider subsets of increasing size, how long does it take to compute the AUM and its directional derivatives?

``````r
data(neuroblastomaProcessed, package="penaltyLearning")
library(data.table)
# Label errors, with one ID per labeled example built from profile.id and
# chromosome (assumed to match the row names of the feature matrix).
nb.err <- data.table(neuroblastomaProcessed$errors)
nb.err[, example := paste0(profile.id, ".", chromosome)]
nb.X <- neuroblastomaProcessed$feature.mat
# Subset sizes, roughly evenly spaced on a log10 scale, up to all examples.
(N.pred.vec <- as.integer(10^seq(1, log10(nrow(nb.X)), by=0.5)))
#> [1]   10   31  100  316 1000 3162
if(requireNamespace("atime")){
  # For each subset size N, compare penaltyLearning::ROChange and aum::aum
  # on the first N examples, with every predicted value equal to zero.
  aum.pL.list <- atime::atime(
    N=N.pred.vec,
    setup={
      N.pred.names <- rownames(nb.X)[1:N]
      N.diffs.dt <- aum::aum_diffs_penalty(nb.err, N.pred.names)
      pred.dt <- data.table(example=N.pred.names, pred.log.lambda=0)
    },
    penaltyLearning={
      roc.list <- penaltyLearning::ROChange(nb.err, pred.dt, "example")
    },
    aum={
      aum.list <- aum::aum(N.diffs.dt, pred.dt$pred.log.lambda)
    })
  plot(aum.pL.list)
}
#> Warning in grid.Call.graphics(C_polygon, x$x, x$y, index): semi-transparency is
#> not supported on this device: reported only once per page
``````

From the plot above, we can see that both packages have similar asymptotic time complexity; however, aum is faster by orders of magnitude.

R implementation

In this section we show a base R implementation of aum, and compare its speed with the C++ implementation in the aum package.

``````r
# Toy input: each row has a 0-based example index, a pred value, and the
# changes in false positive (fp_diff) and false negative (fn_diff) counts.
diffs.df <- data.frame(
  example=c(0,1,1,2,3),
  pred=c(0,0,1,0,0),
  fp_diff=c(1,1,1,0,0),
  fn_diff=c(0,0,0,-1,-1))
# One predicted value per example (examples 0, 1, 2, 3).
pred.log.lambda <- c(0,1,-1,0)
microbenchmark::microbenchmark("C++"={
  aum::aum(diffs.df, pred.log.lambda)
}, R={
  # Threshold of each row: its pred value minus the prediction for its
  # example (example indices are 0-based, hence the +1).
  thresh.vec <- with(diffs.df, pred-pred.log.lambda[example+1])
  s.vec <- order(thresh.vec)
  sort.diffs <- data.frame(diffs.df, thresh.vec)[s.vec,]
  # FP totals accumulate from the smallest threshold upward. FN totals
  # accumulate from the largest threshold downward, so for "fn" the rows are
  # traversed in reverse and the before/after labels are swapped on storage.
  for(fp.or.fn in c("fp","fn")){
    ord.fun <- if(fp.or.fn=="fp")identity else rev
    fwd.or.rev <- sort.diffs[ord.fun(seq_len(nrow(sort.diffs))),]
    fp.or.fn.diff <- fwd.or.rev[[paste0(fp.or.fn,"_diff")]]
    # TRUE on the last row of each run of tied thresholds, so that all
    # tied rows share a single cumulative total.
    last.in.run <- c(diff(fwd.or.rev$thresh.vec) != 0, TRUE)
    # fn_diff values are negated so that FN totals come out as counts.
    count.sign <- if(fp.or.fn=="fp") 1 else -1
    after.or.before <- count.sign*cumsum(fp.or.fn.diff)[last.in.run]
    # Copy each per-run total to every row that has that threshold.
    distribute <- function(values)with(fwd.or.rev, structure(
      values,
      names=thresh.vec[last.in.run]
    )[paste(thresh.vec)])
    out.df <- data.frame(
      before=distribute(c(0, after.or.before[-length(after.or.before)])),
      after=distribute(after.or.before))
    sort.diffs[
      paste0(fp.or.fn,"_",ord.fun(c("before","after")))
    ] <- as.list(out.df[ord.fun(seq_len(nrow(out.df))),])
  }
  # AUM: width of each interval between consecutive sorted thresholds,
  # times min(FP, FN) on that interval.
  AUM.vec <- with(sort.diffs, diff(thresh.vec)*pmin(fp_before,fn_before)[-1])
  list(
    aum=sum(AUM.vec),
    # For each side, s is +1 ("before") or -1 ("after"). Each row
    # contributes s times the change in min(FP, FN) when its own diffs are
    # added to (s=+1) or removed from (s=-1) the counts on that side.
    # Contributions are summed per example with aggregate().
    deriv_mat=sapply(c("after","before"),function(b.or.a){
      s <- if(b.or.a=="before")1 else -1
      f <- function(p.or.n,suffix=b.or.a){
        sort.diffs[[paste0("f",p.or.n,"_",suffix)]]
      }
      fp <- f("p")
      fn <- f("n")
      aggregate(
        s*(pmin(fp+s*f("p","diff"),fn+s*f("n","diff"))-pmin(fp, fn)),
        list(sort.diffs$example),
        sum)$x
    }))
}, times=10)
#> Unit: microseconds
#>  expr      min       lq      mean   median       uq      max neval cld
#>   C++   592.00   594.68   651.496   632.54   653.08   896.16    10  a
#>     R 42077.68 42186.36 46758.920 43509.78 47050.16 63950.32    10   b
``````

The C++ implementation is about 70 times faster than the R implementation (median 632.54 versus 43509.78 microseconds), i.e. nearly two orders of magnitude.

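Synthetic data

Next, we consider synthetic data with up to one million examples with alternating binary labels. We compare the time taken to sort with data.table and base R to the time taken by the aum sort interface.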

``````r
library(data.table)
max.N <- 1e6
# Problem sizes, roughly evenly spaced on a log10 scale, up to max.N.
(N.pred.vec <- as.integer(10^seq(1, log10(max.N), by=0.5)))
#>  [1]      10      31     100     316    1000    3162   10000   31622  100000
#> [10]  316227 1000000
# Alternating binary labels (0, 1, 0, 1, ...) and random predictions; each
# benchmark below uses only the first N of them.
max.y.vec <- rep(c(0,1), length.out=max.N)
max.diffs.dt <- aum::aum_diffs_binary(max.y.vec)
set.seed(1)
max.pred.vec <- rnorm(max.N)
if(requireNamespace("atime")){
  aum.sort.list <- atime::atime(
    N=N.pred.vec,
    setup={
      N.diffs.dt <- max.diffs.dt[1:N]
      # Named pred.vec (not N.pred.vec) so that it is not confused with the
      # vector of problem sizes defined above.
      pred.vec <- max.pred.vec[1:N]
    },
    dt_sort={
      N.diffs.dt[order(pred.vec)]
    },
    R_sort_quick={
      sort(pred.vec, method="quick")
    },
    aum_sort={
      # NOTE(review): aum_sort_interface is reached with ':::', presumably
      # because it is not exported by aum, so it may change without notice.
      aum.list <- aum:::aum_sort_interface(N.diffs.dt, pred.vec)
    })
  plot(aum.sort.list)
}
#> Warning: Transformation introduced infinite values in continuous y-axis
#> Transformation introduced infinite values in continuous y-axis
#> Transformation introduced infinite values in continuous y-axis
#> Warning in grid.Call.graphics(C_polygon, x$x, x$y, index): semi-transparency is
#> not supported on this device: reported only once per page
``````