Function Currying in C

Example


/**
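 * Function currying in C by generating x86-64 machine code at run time.
 * Works only on x86-64 (System V ABI) with POSIX mprotect(); it is not portable ANSI C.
 * Adapted from an answer to "Is there a way to do currying in C?":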
 * https://stackoverflow.com/questions/1023261/is-there-a-way-to-do-currying-in-c
*/
   
   
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include <sys/mman.h>
#include <stdint.h> 
#include <string.h>


/* x86-64 System V ABI passes the first integer parameters in
rdi, rsi, rdx, rcx, r8, r9 (https://wiki.osdev.org/System_V_ABI), and the
return value goes in %rax.
Binary format of useful opcodes:
      0xbf, [le32] = movl $imm32, %edi (1st param)
      0xbe, [le32] = movl $imm32, %esi (2nd param)
      0xba, [le32] = movl $imm32, %edx (3rd param)
      0xb9, [le32] = movl $imm32, %ecx (4th param)
      0xb8, [le32] = movl $imm32, %eax
0x48, 0x__, [le64] = movq $imm64, %r__
      0xff, 0xe0   = jmpq *%rax
Because the movl forms carry a 32-bit immediate, only 32-bit values can be bound.
*/

typedef uint32_t (*partial_function)(int argc, ...);

/* movl $imm32 opcodes for the 2nd, 3rd and 4th integer parameters, in ABI order. */
static const uint8_t partial_param_opcodes[] = { 0xbe, 0xba, 0xb9 };

enum {
    PARTIAL_MAX_FIXED = 3,  /* one per entry of partial_param_opcodes */
    PARTIAL_MOVL_SIZE = 5,  /* opcode byte + le32 immediate */
    PARTIAL_JUMP_SIZE = 12  /* movq $imm64, %rax (10 bytes) + jmpq *%rax (2 bytes) */
};

/* Write the low n bytes of v to dst, least-significant byte first. */
static void partial_put_le(uint8_t *dst, uint64_t v, size_t n)
{
    for (size_t k = 0; k < n; ++k) {
        dst[k] = (uint8_t)(v >> (8 * k));
    }
}

/**
 * Build a trampoline that calls a function with its trailing integer
 * arguments already bound.
 *
 * @param argc  Number of variadic arguments that follow: 1 for the function
 *              pointer, plus 0..3 values to bind.
 * @param ...   The function to curry, passed as a void *, followed by up to
 *              three values, each read as uint32_t. The values are bound, in
 *              the order given, to the function's 2nd, 3rd and 4th integer
 *              parameters.
 * @return      If argc == 1, the function pointer itself. Otherwise, generated
 *              code: calling it with one int x calls fn(x, v1, v2, ...).
 *              The generated code tail-jumps, so fn returns directly to the caller.
 *              Returns NULL if argc is out of range or if the page size,
 *              the allocation, or mprotect() fails.
 *
 * The generated code lives in its own page-sized, page-aligned heap block
 * that is switched to read+execute. It is never released. It must not be
 * passed to free(), because free() would need to write into the read-only page.
 */
partial_function partial(int argc, ...)
{
    /* Validate before touching the va_list: argc < 1 has no function pointer. */
    if (argc < 1 || argc > 1 + PARTIAL_MAX_FIXED) {
        return NULL;
    }

    va_list ap;
    va_start(ap, argc);
    /* The first variadic argument is the function to be curried. */
    uintptr_t fp = (uintptr_t)va_arg(ap, void *);
    if (argc == 1) {
        va_end(ap);
        return (partial_function)fp;
    }

    /* Assemble the code on the stack first, so that va_end runs exactly once
       and no allocation is outstanding while the arguments are read. */
    uint8_t code[PARTIAL_MAX_FIXED * PARTIAL_MOVL_SIZE + PARTIAL_JUMP_SIZE];
    size_t len = 0;
    size_t n_fixed = (size_t)argc - 1;
    for (size_t j = 0; j < n_fixed; ++j) {
        uint32_t fixed_param = va_arg(ap, uint32_t);
        /* Opcode and immediate belong to the same 5-byte slot, so the j-th
           bound value goes into the j-th parameter register. */
        code[len] = partial_param_opcodes[j];
        partial_put_le(code + len + 1, fixed_param, 4);
        len += PARTIAL_MOVL_SIZE;
    }
    va_end(ap);

    /* movq $fp, %rax ; jmpq *%rax */
    code[len++] = 0x48;
    code[len++] = 0xb8;
    partial_put_le(code + len, (uint64_t)fp, 8);
    len += 8;
    code[len++] = 0xff;
    code[len++] = 0xe0;

    long page = sysconf(_SC_PAGE_SIZE);
    if (page <= 0 || (size_t)page < len) {
        return NULL;
    }
    size_t pagesize = (size_t)page;

    /* A dedicated, page-aligned block means the protection change below
       cannot affect unrelated heap data. That lets us use R+X instead of RWX. */
    uint8_t *buf = aligned_alloc(pagesize, pagesize);
    if (buf == NULL) {
        return NULL;
    }
    memcpy(buf, code, len);
    if (mprotect(buf, pagesize, PROT_READ | PROT_EXEC) != 0) {
        /* Single page: a failed mprotect leaves it writable, so free is safe. */
        free(buf);
        return NULL;
    }
    return (partial_function)(uintptr_t)buf;
}


/*
 * Demo target for partial(): prints all three arguments and returns a + b.
 * The third argument is printed but does not contribute to the result.
 */
int print_both_params(int a, int b, int c)
{
    const int sum_of_first_two = a + b;
    printf("Called with a=%d, b=%d, c=%d\n", a, b, c);
    return sum_of_first_two;
}

int main(int argc, char const *argv[])
{
    partial_function fixed_first = partial(3, print_both_params, 3, 2);
    fixed_first(1);
    return 0;
}