Partial evaluation with foolbox

Thomas Mailund

2018-12-15

One approach to runtime optimisation is partial evaluation. It isn’t that complicated an idea: if you are computing the same thing again and again, you would be better off computing it once and remembering the result. This idea is used in many places, both in implementations and in the design of algorithms, where we call it things like dynamic programming or memoisation. When it comes to functions, though, it involves evaluating the parts of a function we are able to with the parameters we can fix, and keeping the rest of the function as code that must be evaluated at a later time.

Example

Consider a simple example where we have a function that scales one vector, y, using the mean and standard deviation of another vector x:

rescale <- function(x, y) {
    (y - mean(x)) / sd(x)
}

The function takes both vectors as input, but we can curry it to get a function of one argument, x, that returns another that also takes one argument, y, and then computes the same expression.

rescale_curry <- function(x) function(y) {
    (y - mean(x)) / sd(x)
}

This alone can seem a little pointless, but we see that mean(x) and sd(x) only depend on x, so if we need to use the function on many different y vectors, with the same x value, we could rewrite it so these two quantities are computed as soon as x is known, and such that we can reuse these computations for the different values of y.

rescale_curry_pe <- function(x) {
    mean_x <- mean(x)
    sd_x <- sd(x)
    function(y) (y - mean_x) / sd_x
}
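Before timing anything, it is worth checking that the three variants compute the same thing. This is a minimal sketch in base R; the seed and the vector sizes are arbitrary choices of mine, picked only to make the check repeatable.

```r
rescale <- function(x, y) (y - mean(x)) / sd(x)
rescale_curry <- function(x) function(y) (y - mean(x)) / sd(x)
rescale_curry_pe <- function(x) {
    mean_x <- mean(x)
    sd_x <- sd(x)
    function(y) (y - mean_x) / sd_x
}

set.seed(1)  # arbitrary seed, only so the check is repeatable
x <- rnorm(1000, 0, 1)
y <- rnorm(1000, 10, 100)

direct  <- rescale(x, y)
curried <- rescale_curry(x)(y)      # same expression, evaluated lazily
partial <- rescale_curry_pe(x)(y)   # mean and sd computed once, up front

stopifnot(
    isTRUE(all.equal(direct, curried)),
    isTRUE(all.equal(direct, partial))
)
```

The only difference between the variants is when mean(x) and sd(x) are computed, not what the result is.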

Currying alone doesn’t change the performance, but combining it with partial evaluation (computing what we can once we know x and before we know y) can have a substantial effect.

time_uncurry <- function(f) {
    n <- 500
    x <- rnorm(1000,  0,   1)
    y <- rnorm(1000, 10, 100)
    replicate(n, f(x,y))
}
time_curry <- function(f) {
    n <- 500
    x <- rnorm(1000,  0,   1)
    y <- rnorm(1000, 10, 100)
    g <- f(x)
    replicate(n, g(y))
}

microbenchmark::microbenchmark(
    time_uncurry(rescale),
    time_curry(rescale_curry),
    time_curry(rescale_curry_pe)
)
#> Unit: milliseconds
#>                          expr       min        lq      mean    median
#>         time_uncurry(rescale) 23.985258 25.192411 28.253877 27.955396
#>     time_curry(rescale_curry) 23.904110 25.672767 28.912103 27.632925
#>  time_curry(rescale_curry_pe)  3.594679  4.030582  6.582374  4.512094
#>         uq      max neval
#>  29.175007 52.64393   100
#>  29.895446 63.98742   100
#>   7.444408 47.51016   100

The actual performance difference depends, of course, on how many times we call the partially evaluated function and how complex it is.

Automatic partial evaluation

We can use foolbox to automatically partially evaluate a function by replacing a variable with a value. In the simplest possible form, we simply substitute a value for a variable:

library(foolbox)

subst_var_callback <- function(var, val) {
    function(expr, ...) {
        if (expr == var) rlang::expr(!!val)
        else expr
    }
}

pe <- function(fn, var, val) {
    var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
    force(val)
    fn %>% rewrite() %>% rewrite_with(
        rewrite_callbacks() %>%
            with_symbol_callback(subst_var_callback(var, val))
    ) %>% remove_formal_(var)
}

We can see this in action by trying it out on the rescale function:

test_x <- rnorm(3, 0, 1) %>% round(digits = 2)
rescale_pe <- rescale %>% pe(x, test_x)
rescale_pe
#> function (y) 
#> {
#>     (y - mean(c(0.82, 0.17, 0.39)))/sd(c(0.82, 0.17, 0.39))
#> }

test_y <- rnorm(5, 10, 100)
rescale(test_x, test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914
rescale_pe(test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914

We have inserted a value for x, but that is all. We have also, as it turns out, managed to make our program slower, because even though we run the partially evaluated function 500 times, the cost in runtime of transforming it outweighs whatever gain we get from the transformation.

microbenchmark::microbenchmark(
    time_uncurry(rescale), 
    time_curry(function(x) pe(rescale, x, x))
)
#> Unit: milliseconds
#>                                       expr     min       lq     mean
#>                      time_uncurry(rescale) 24.0924 24.66969 28.31623
#>  time_curry(function(x) pe(rescale, x, x)) 33.8725 35.51176 38.67932
#>    median       uq      max neval
#>  27.71651 28.69023 84.97467   100
#>  37.71341 39.40081 85.86342   100

This is perhaps not so surprising considering that we do not actually evaluate the mean(x) and sd(x) expressions we wanted to evaluate.
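To make it concrete what pure substitution gives us, here is a sketch of the same transformation in base R, without foolbox. The helper substitute_value is a name I made up for this illustration; it works around the fact that substitute() quotes its first argument.

```r
rescale <- function(x, y) {
    (y - mean(x)) / sd(x)
}

# substitute() does not evaluate its first argument, so we first build the
# call substitute(<expression>, bindings) and then evaluate that call.
# (Hypothetical helper for this sketch; not part of foolbox.)
substitute_value <- function(expr, bindings) {
    call <- substitute(substitute(e, bindings), list(e = expr))
    eval(call)
}

x_val <- c(0.82, 0.17, 0.39)

rescale_x <- rescale
body(rescale_x) <- substitute_value(body(rescale), list(x = x_val))
formals(rescale_x) <- formals(rescale_x)["y"]  # drop the x parameter

y <- c(1, 2, 3)
stopifnot(isTRUE(all.equal(rescale_x(y), rescale(x_val, y))))
```

The body of rescale_x still contains the calls to mean and sd, now applied to a constant vector, so every call to rescale_x repeats that work.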

We can fix this by trying to evaluate all calls if all their arguments are evaluated, i.e. if all their arguments are atomic.

eval_attempt <- function(expr, env) {
    tryCatch(rlang::expr(!!eval(expr, env)),
             error = function(e) expr)
}

call_callback <- function(expr, env, ...) {
    call_args_atomic <- rlang::call_args(expr) %>% 
        Map(rlang::is_atomic, .) %>% unlist
    if (any(!call_args_atomic)) expr
    else eval_attempt(expr, env)
}

pe <- function(fn, var, val) {
    var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
    force(val)
    fn %>% rewrite() %>% rewrite_with(
        rewrite_callbacks() %>%
            with_symbol_callback(subst_var_callback(var,val)) %>%
            with_call_callback(call_callback)
    ) %>% remove_formal_(var)
}

With this implementation we will actually evaluate the function calls in rescale when we substitute a value for x:

rescale_pe <- rescale %>% pe(x, test_x)
rescale_pe
#> function (y) 
#> {
#>     (y - 0.46)/0.330605505096331
#> }
rescale(test_x, test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914
rescale_pe(test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914

We can also see that this improves the performance: in the benchmark below, the median drops from roughly 27 ms to roughly 14 ms. That is still well short of the hand-written rescale_curry_pe, which had a median of about 4.5 ms, so for this example we are not gaining enough for it to be worth the effort. Still, the idea is the important part, and it does show that you can do this with foolbox.

microbenchmark::microbenchmark(
    time_uncurry(rescale), 
    time_curry(function(x) pe(rescale, x, x))
)
#> Unit: milliseconds
#>                                       expr      min       lq     mean
#>                      time_uncurry(rescale) 23.81961 24.44317 27.36736
#>  time_curry(function(x) pe(rescale, x, x)) 12.80204 13.37902 16.11078
#>    median       uq      max neval
#>  26.57972 28.88052 67.02671   100
#>  14.21857 17.32043 59.50037   100
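The bottom-up idea behind the call callback can be sketched in base R as a small recursive function. This is only an illustration under simplifying assumptions: fold_constants is a hypothetical name, it does not handle calls with missing arguments such as x[1, ], and it will happily evaluate calls with side effects.

```r
# Hypothetical helper: evaluate every call whose arguments are all values.
fold_constants <- function(expr, env = globalenv()) {
    if (!is.call(expr)) return(expr)
    # rewrite the arguments first, so inner calls are folded before outer ones
    args <- lapply(as.list(expr)[-1], fold_constants, env = env)
    expr <- as.call(c(list(expr[[1]]), args))
    # only attempt evaluation when every argument is already a value
    if (!all(vapply(args, is.atomic, logical(1)))) return(expr)
    # if evaluation fails, keep the (partially folded) call
    tryCatch(eval(expr, env), error = function(e) expr)
}

folded <- fold_constants(quote((y - mean(c(1, 2, 3))) / sd(c(1, 2, 3))))
folded
#> (y - 2)/1
```

The calls to c, mean and sd disappear because all their arguments are known, while the subtraction and division stay because they depend on y.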

Since we use env in the evaluation, we also capture variables from the environment where the function was defined, that is, its closure:

make_rescale <- function(m) {
    function(x, y) (y - m(x))/sd(x)
}
closure_rescale <- make_rescale(mean)
closure_rescale %>% pe(x, test_x)
#> function (y) 
#> (y - 0.46)/0.330605505096331
#> <environment: 0x7f824b4e9888>

Function parameters and local variables

What about functions that are passed as parameters?

rescale_ms <- function(x, y, mean, sd)
    (y - mean(x))/sd(x)

rescale_ms %>% pe(x, test_x)
#> function (y, mean, sd) 
#> (y - 0.46)/0.330605505096331

We evaluate mean(x) and sd(x) in the environment of rescale_ms but not its execution environment. So we evaluate the expressions using the global functions. This will only give the right behaviour if it is these exact two functions that are passed as the arguments.
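A small base R check shows the problem. Here rescale_ms_folded is a hand-written stand-in for the rewritten function (it is not produced by pe), with the global mean and sd already baked in; the choice of median and mad as alternative arguments is mine.

```r
rescale_ms <- function(x, y, mean, sd) (y - mean(x)) / sd(x)

x <- c(0.82, 0.17, 0.39)
y <- c(1, 2, 3)

# stand-in for the rewritten function: the *global* mean and sd are baked in,
# and the mean and sd parameters are silently ignored
m <- mean(x)
s <- sd(x)
rescale_ms_folded <- function(y, mean, sd) (y - m) / s

# agrees when the caller passes the global functions ...
same_when_global <- isTRUE(all.equal(
    rescale_ms(x, y, mean, sd), rescale_ms_folded(y, mean, sd)
))
# ... but not when the caller passes something else
same_when_robust <- isTRUE(all.equal(
    rescale_ms(x, y, median, mad), rescale_ms_folded(y, median, mad)
))

stopifnot(same_when_global, !same_when_robust)
```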

In general, we should not evaluate functions that are local variables, at least not in the function’s environment. We can fix this by checking if a function name is a local variable before we evaluate a call:

is_bound <- function(var, expr) {
    (rlang::is_symbol(var) || is.character(var)) &&
        as.character(var) %in% attr(expr, "bound")
}
call_callback <- function(expr, env, ...) {
    call_args_atomic <- rlang::call_args(expr) %>% 
        Map(rlang::is_atomic, .) %>% unlist
    if (any(!call_args_atomic) || is_bound(expr[[1]], expr)) expr
    else eval_attempt(expr, env) 
}

pe <- function(fn, var, val) {
    var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
    force(val)
    fn %>% rewrite() %>% rewrite_with(
        rewrite_callbacks() %>%
            with_symbol_callback(subst_var_callback(var,val)) %>%
            with_call_callback(call_callback)
    ) %>% remove_formal_(var)
}

rescale %>% pe(x, test_x)
#> function (y) 
#> {
#>     (y - 0.46)/0.330605505096331
#> }
rescale_ms %>% pe(x, test_x)
#> function (y, mean, sd) 
#> (y - mean(c(0.82, 0.17, 0.39)))/sd(c(0.82, 0.17, 0.39))

Since functions are also data, we can do partial evaluation with them:

rescale_ms %>% pe(x, test_x) %>% pe(mean, mean)
#> function (y, sd) 
#> (y - 0.46)/sd(c(0.82, 0.17, 0.39))
rescale_ms %>% pe(x, test_x) %>% pe(mean, mean) %>% pe(sd, sd)
#> function (y) 
#> (y - 0.46)/0.330605505096331

What about local variables?

rescale_nested <- function(x, y) {
    move <- function(x, y) y - mean(x)
    move(x, y) / sd(x)
}
rescale_nested %>% pe(x, test_x)
#> function (y) 
#> {
#>     move <- function(x, y) y - 0.46
#>     move(c(0.82, 0.17, 0.39), y)/0.330605505096331
#> }

Here we are substituting x inside the body of move. It turns out to be okay because we call move with x later, but that is just dumb luck. In general, we don’t want to substitute variables that are local to another local function. We can use a top-down callback to skip a function body if it has the variable we are substituting as a parameter:

skip_overscoped_functions <- function(var_symbol) {
    var <- as.character(var_symbol)
    function(expr, skip, ...) {
        # in `function` expressions, the pair-list is element 2
        if (var %in% names(expr[[2]])) {
            skip(expr)
        }
    }
}
pe <- function(fn, var, val) {
    var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
    force(val)
    fn %>% rewrite() %>% rewrite_with(
        rewrite_callbacks() %>%
            with_symbol_callback(subst_var_callback(var,val)) %>%
            with_call_callback(call_callback) %>%
            add_topdown_callback(`function`, skip_overscoped_functions(var))
    ) %>% remove_formal_(var)
}

rescale_nested %>% pe(x, test_x)
#> function (y) 
#> {
#>     move <- function(x, y) y - mean(x)
#>     move(c(0.82, 0.17, 0.39), y)/0.330605505096331
#> }
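The scoping rule can also be sketched in base R as a recursive substitution that stops at any function expression whose parameters shadow the variable. The name subst_unshadowed is hypothetical, and the sketch ignores default parameter values and calls with missing arguments.

```r
# Hypothetical helper: replace the symbol `var` with `val`, except inside
# function expressions that have a parameter of the same name.
subst_unshadowed <- function(expr, var, val) {
    if (is.symbol(expr)) {
        if (identical(expr, as.symbol(var))) val else expr
    } else if (is.call(expr)) {
        is_fun_def <- identical(expr[[1]], quote(`function`))
        # in `function` expressions, the formals are element 2
        if (is_fun_def && var %in% names(expr[[2]])) return(expr)
        as.call(lapply(as.list(expr), subst_unshadowed, var = var, val = val))
    } else {
        expr  # constants, formals lists, and so on are left alone
    }
}

body_expr <- quote({
    move <- function(x, y) y - mean(x)
    move(x, y) / sd(x)
})
out <- subst_unshadowed(body_expr, "x", 2)
out[[3]]
#> move(2, y)/sd(2)
```

The definition of move is returned untouched, because its own x hides the x we are substituting.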

Evaluating local functions

Ok, so now we avoid substituting variables inside a local function, but in this particular case, we would like to substitute in the call to move so we can get mean(x) evaluated there.

To handle local functions there is a bit of work to be done. First, we need a way to evaluate them, and that means partial evaluation. We want to substitute the parameters we know the values of and let the rest remain. We can do this by iteratively going through the parameters in a call and substituting variables like this:

pe_call <- function(fun, expr) {
    params <- rlang::call_args(match.call(fun, expr))
    substitutable <- Map(rlang::is_atomic, params) %>% unlist
    not_substitutable <- params[!substitutable]
    params <- params[substitutable]
    param_names <- names(params)
    
    for (i in seq_along(params)) {
        var <- as.symbol(param_names[i])
        fun <- fun %>% pe_(var, params[[i]])
    }
    
    rlang::expr((!!fun)(!!!not_substitutable))
}

Here, I assume that I have a pe_ function that works like pe but where var is already quoted. I define it below (and redefine pe in terms of it).
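The first step in pe_call, sorting the arguments of a call into those we know and those we don't, relies only on base R. A small sketch of that step on its own, where I use bquote() to splice a vector into the call in the same way substitution would have done:

```r
move <- function(x, y) y - mean(x)

# a call where the first argument is already a value and the second is not
expr <- bquote(move(.(c(0.82, 0.17, 0.39)), y))

# match.call() names every argument after the formal parameter it binds to
matched <- match.call(move, expr)
args <- as.list(matched)[-1]
known <- vapply(args, is.atomic, logical(1))

names(args)[known]    # parameters we can substitute
#> [1] "x"
names(args)[!known]   # parameters that must remain arguments
#> [1] "y"
```

Because the arguments are named after matching, it does not matter whether the original call passed them by position or by name.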

With this function we can handle a call. Now, we need to collect all local function expressions and evaluate them into actual functions. We do this in a callback to assignments (<- and =). I’m not entirely sure how to deal with variables in a local function’s closure. In theory we should be able to do some substitution inside the function and leave those variables alone, but I plan to evaluate the expressions that define the functions, and I need an environment to do this in. This cannot be the execution environment, since it doesn’t exist before we call the function, so I can only use the function environment. The local variables do not exist there. So I won’t allow the local function to refer to local variables from outside its own scope (parameters and local variables inside the function) other than the variable we are substituting. This function extracts all symbols used in an expression:

collect_all_symbols <- function(expr) {
    expr %>% analyse_expr() %>% analyse_expr_with(
        analysis_callbacks() %>% with_symbol_callback(
            function(expr, ...) list(symbols = as.character(expr))
    )) %>% unlist(use.names = FALSE)
}

and I can use it in this function for getting hold of local functions:

save_local_functions <- function(expr, env, params, var, fun_table, ...) {
    if (!rlang::is_symbol(expr[[2]])) # the assignment target must be a plain name
        return(expr)
    if (!rlang::is_call(expr[[3]]) || length(expr[[3]]) < 2 ||
        expr[[3]][[1]] != "function") 
        return(expr)
    
    fun_name <- as.character(expr[[2]])
    fun_formals <- expr[[3]][[2]]
    fun_body <- expr[[3]][[3]]
    
    fun_parameters <- names(fun_formals)
    fun_used_vars <- collect_all_symbols(fun_body)
    local_vars <- attr(expr, "bound")
    
    # figure out if the function uses local variables that are not
    # passed through the formal parameters
    closure_scope <- setdiff(fun_used_vars, fun_parameters)
    local_closure_scope <- intersect(local_vars, closure_scope)
    if (length(local_closure_scope) > 1) 
        return(expr) 
    if (length(local_closure_scope) == 1 &&
        local_closure_scope != as.character(var))       
        return(expr)
    
    # create the function and save it in the function table
    stopifnot(!exists(fun_name, fun_table))
    fun_table[[fun_name]] <- eval(expr[[3]], env)
    
    expr
}
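The scope check in this callback is plain set arithmetic on names. A base R sketch of the same computation, using all.names() in place of collect_all_symbols; the function expression and the set of local variables are made up for the illustration.

```r
fun_expr <- quote(function(x, y) y - mean(x) + shift)

fun_parameters <- names(fun_expr[[2]])       # "x" "y"
fun_used_names <- all.names(fun_expr[[3]])   # also contains `-`, `+` and mean

# names the function uses but does not receive as parameters
closure_scope <- setdiff(fun_used_names, fun_parameters)

# hypothetical set of names bound in the enclosing function
local_vars <- c("x", "shift", "move")
local_closure_scope <- intersect(local_vars, closure_scope)
local_closure_scope
#> [1] "shift"
```

Here the local function depends on shift from the enclosing scope, so the callback would leave it alone unless shift happened to be the variable we are substituting.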

I store the local functions in a table, fun_table. I will use an environment here, so the callback has side-effects. It is the easiest way to implement this table.
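Environments have reference semantics, which is what makes this work: a function can add to the table without returning it. A minimal sketch in base R, where register_function is a hypothetical helper; I pass inherits = FALSE so the lookup never falls through to a parent environment.

```r
fun_table <- new.env(parent = emptyenv())

# Hypothetical helper: store a function under a name, refusing to overwrite.
register_function <- function(name, fn, table) {
    stopifnot(!exists(name, envir = table, inherits = FALSE))
    assign(name, fn, envir = table)
    invisible(table)
}

register_function("move", function(x, y) y - mean(x), fun_table)

# the caller's table was modified, even though we did not reassign it
exists("move", envir = fun_table, inherits = FALSE)
#> [1] TRUE
```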

In the call callback I then do partial evaluation on calls to local functions:

call_callback <- function(expr, env, fun_table, ...) {
    call_args_atomic <- rlang::call_args(expr) %>% 
        Map(rlang::is_atomic, .) %>% unlist
    
    fun_name <- as.character(expr[[1]])
    if (all(call_args_atomic) && !is_bound(fun_name, expr))
        return(eval_attempt(expr, env))

    # if it exists in the function table, we can try a partial evaluation 
    # of it
    if (is_bound(fun_name, expr) && exists(fun_name, fun_table)) {
        fun <- fun_table[[fun_name]]
        
        # if we have all the arguments, we can just call the function
        # and insert the result
        if (all(call_args_atomic))
            return(do.call(fun, rlang::call_args(expr)))
        
        # otherwise, we have to try a partial substitution of the function
        return(pe_call(fun, expr))
    }
        
    expr
}

We can now put all this together to get a new pe (and pe_) function:

pe_ <- function(fn, var, val) {
    stopifnot(rlang::is_symbol(var))
    force(val)
    fn %>% rewrite() %>% rewrite_with(
        rewrite_callbacks() %>%
            with_symbol_callback(subst_var_callback(var,val)) %>%
            with_call_callback(call_callback) %>%
            add_topdown_callback(`function`, skip_overscoped_functions(var)) %>%
            add_topdown_callback(`<-`, save_local_functions) %>%
            add_topdown_callback(`=`, save_local_functions),
        var = var, fun_table = rlang::new_environment()
    ) %>% remove_formal_(var)
}
pe <- function(fn, var, val) {
    var <- rlang::enexpr(var)
    pe_(fn, var, val)
}

rescale_nested <- function(x, y) {
    move <- function(x, y) y - mean(x)
    move(x, y) / sd(x)
}
rescale_nested %>% pe(x, test_x)
#> function (y) 
#> {
#>     move <- function(x, y) y - mean(x)
#>     (function (y) 
#>     y - 0.46)(y = y)/0.330605505096331
#> }
rescale_nested_pe <- rescale_nested %>% pe(x, test_x)
rescale_nested_pe
#> function (y) 
#> {
#>     move <- function(x, y) y - mean(x)
#>     (function (y) 
#>     y - 0.46)(y = y)/0.330605505096331
#> }

rescale(test_x, test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914
rescale_nested(test_x, test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914
rescale_nested_pe(test_y)
#> [1]  207.4550  376.3745 -762.1280 -334.4448  164.2914

As it turns out, this solution puts an expression in the scope where local variables will be present, so I didn’t have to do all that checking of the scope. So I will remove it again.

save_local_functions <- function(expr, env, params, var, fun_table, ...) {
    if (!rlang::is_symbol(expr[[2]])) # the assignment target must be a plain name
        return(expr)
    if (!rlang::is_call(expr[[3]]) || length(expr[[3]]) < 2 ||
        expr[[3]][[1]] != "function") 
        return(expr)
    
    fun_name <- as.character(expr[[2]])
    stopifnot(!exists(fun_name, fun_table))
    fun_table[[fun_name]] <- eval(expr[[3]], env)
    
    expr
}
rescale_nested <- function(x, y, m) {
    mean <- m
    move <- function(x, y) y - mean(x)
    move(x, y) / sd(x)
}

rescale_nested %>% pe(x, test_x)
#> function (y, m) 
#> {
#>     mean <- m
#>     move <- function(x, y) y - mean(x)
#>     (function (y) 
#>     y - 0.46)(y = y)/0.330605505096331
#> }

This example, however, illustrates that there is still some work that could be done. We could propagate the local value of mean into move. Trying for that, though, leads us into territory I don’t want to enter. It is undecidable to figure out what a local variable is set to at any particular point in a program unless we put some serious restrictions on it or do some serious static analysis to handle special cases where we are able to determine it. For this document, it takes us too far into static analysis, so I will stop here.
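A tiny example of why this is hard in general: the function a local variable refers to can depend on values we only see at runtime. The function center is made up for this illustration.

```r
center <- function(x, robust) {
    m <- mean
    if (robust) m <- median   # which function m names depends on an argument
    m(x)
}

center(c(1, 2, 10), robust = FALSE)  # m is mean here
center(c(1, 2, 10), robust = TRUE)   # m is median here
#> [1] 2
```

If we only know x, we cannot replace m(x) with a value, because we do not know which function m will be when the call is reached.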

In any case, we are now doing so much work with the partial evaluation that we are only making the program slower, and as it is, we only barely got better than the simple rescale function earlier.

I do not know if there is a use for partial evaluation in R. The language is not built for speed in the first place, and optimisations would be better handled in an adaptive virtual machine. If you can find a use for partial evaluation, though, I think you can implement it using foolbox.