Currying in Python

Posted by mtomassoli on March 18, 2012

What is Currying?

Currying is a kind of incremental binding of function arguments. Let’s define a simple function which takes 5 arguments:

def f(a, b, c, d, e):
    print(a, b, c, d, e)

In a language where currying is supported, f is a function which takes one argument (a) and returns a function which takes 4 arguments. This means that f(5) is the following function:

def g(b, c, d, e):
    f(5, b, c, d, e)

We could emulate this behavior the following way:

def f(a):
    def g(b, c, d, e):
        print(a, b, c, d, e)
    return g

f1 = f(1)
f1(2,3,4,5)

Now, f(1) returns the function g(b, c, d, e) which, for all b, c, d and e, behaves exactly like f(1, b, c, d, e). Since g is a function, it should support currying as well. Then what we really want is this:

def f(a):
    def g(b):
        def h(c, d, e):
            print(a, b, c, d, e)
        return h
    return g

f(1)(2)(3,4,5)

So what is f? It’s a function which takes an argument a and returns another function g which “remembers” that argument. The same way, g is a function which takes an argument b and returns another function h which “remembers” that argument. Basically, h is the function obtained from f by binding the first two arguments to a and b, respectively. Note that the last line of the code above is equivalent to this:

f1 = f(1)
f12 = f1(2)
f12(3,4,5)

I know what you’re thinking: we aren’t done yet! Here’s the final version in all its glory:

def f(a):
    def g(b):
        def h(c):
            def i(d):
                def j(e):
                    print(a, b, c, d, e)
                return j
            return i
        return h
    return g

f(1)(2)(3)(4)(5)

That’s currying for you!
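
Writing all those nested functions by hand gets old quickly. Here is a minimal sketch of a helper that builds the one-argument-at-a-time chain for any function with a known number of arguments. The names curry_n, bind and take_one are made up for this illustration, and the sketch assumes n is at least 1; f returns its arguments instead of printing them so that the result is easy to check.

```python
def curry_n(func, n):
    """Return a version of func that takes its n arguments one at a time.

    curry_n and its parameter n are illustrative names, not part of the article.
    """
    def bind(collected):
        # collected is an immutable tuple, so every partial binding is independent.
        def take_one(arg):
            new_collected = collected + (arg,)
            if len(new_collected) == n:
                return func(*new_collected)   # all n arguments are known
            return bind(new_collected)        # still waiting for more
        return take_one
    return bind(())


def f(a, b, c, d, e):
    return (a, b, c, d, e)

curried_f = curry_n(f, 5)
result = curried_f(1)(2)(3)(4)(5)
print(result)   # (1, 2, 3, 4, 5)
```

Each intermediate value, such as curried_f(1)(2), is an ordinary function that can be stored and reused.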

But what about f(1, 2, 3, 4, 5)? Oops… f doesn’t take 5 arguments anymore :(

Well, there are some functional languages (Haskell, for instance) where function application is indicated by juxtaposition. For instance, f(a, b, c, d, e) would be expressed like this:

f a b c d e

Does that mean that f is a function which takes 5 arguments? Nope. As in the example above, f is a function which takes an argument and returns a function which takes an argument, and so on… Let’s see it this way: the space (‘ ‘) is an operator that applies its left operand (a function) to its right operand. The space is left-associative, meaning that the expression above is grouped this way:

((((f a) b) c) d) e

If you don’t like the operator space, let’s use some other operator:

((((f<-a)<-b)<-c)<-d)<-e

The operator ‘<-‘ passes its right operand to its left operand, which is a function. Indeed, these are functions:

  • f
  • f<-a
  • (f<-a)<-b
  • ((f<-a)<-b)<-c
  • (((f<-a)<-b)<-c)<-d

We can remove the brackets, of course:

f<-a<-b<-c<-d<-e

What am I trying to say? Simply that classic currying is better left to languages which use juxtaposition to pass arguments to functions.

Why currying?

Simple enough: it’s an easy way to get specialized functions from more general functions. Here are a few simple examples:

  • times3 = myMul(3)
    • Usage:
      listTimes3 = map(times3, list)          # evaluates myMul(3, x) for all x in list
  • printErrMsg = printMsg("Error")
    printWrnMsg = printMsg("Warning")
    • Usage:
      printErrMsg("Divide by Zero!")
      printWrnMsg("You should save your document.")
  • lexicSort = mySort(lexicCmp)
    • Usage:
      lexicSort(listOfStrings)
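
The first two examples can be made runnable with hand-curried functions. This is only a sketch: the bodies of myMul and printMsg are my guesses at what such functions might look like, and the "Error: ..." output format is an assumption.

```python
def myMul(a):
    # Hand-curried multiplication: myMul(3) returns a function that multiplies by 3.
    def mul_by_a(b):
        return a * b
    return mul_by_a

def printMsg(kind):
    # Hand-curried message printer; it also returns the line it printed.
    def print_with_kind(text):
        line = "%s: %s" % (kind, text)   # assumed output format
        print(line)
        return line
    return print_with_kind

times3 = myMul(3)
listTimes3 = list(map(times3, [1, 2, 3]))
print(listTimes3)   # [3, 6, 9]

printErrMsg = printMsg("Error")
printWrnMsg = printMsg("Warning")
printErrMsg("Divide by Zero!")
printWrnMsg("You should save your document.")
```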

Currying VS partial application

Partial application is partial binding:

from functools import partial

def f(a, b, c, d):
    print(a, b, c, d)

g = partial(f, 1, 2, 3)
g(4)

Looks familiar?

Is that currying? Not quite. It’s more like manual currying:

from functools import partial

def f(a, b, c, d):
    print(a, b, c, d)

g = partial(partial(partial(f, 1), 2), 3)
g(4)
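
A quick check, as a sketch, that the one-step and the three-step bindings behave the same, and that partial can also bind by keyword. Here f returns its arguments instead of printing them so the results can be compared.

```python
from functools import partial

def f(a, b, c, d):
    return (a, b, c, d)

one_step = partial(f, 1, 2, 3)
three_steps = partial(partial(partial(f, 1), 2), 3)
same = one_step(4) == three_steps(4)
print(same)                     # True

by_keyword = partial(f, d=40)   # binds the last argument by name
print(by_keyword(1, 2, 3))      # (1, 2, 3, 40)
```

The difference from currying is that every binding step has to be spelled out with an explicit call to partial.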

Is currying right for us?

We said that, in Haskell, we can write

f a b c d e

while, in Python, we would have to write

f(1)(2)(3)(4)(5)

I don’t know about you, but I prefer this one:

f(1, 2, 3, 4, 5)

We should also keep in mind that Python has keyword arguments, so this should also be allowed:

f(1, 2, e = 5)

Is that real currying? It depends on whom you ask. But if you like it and find it useful, who cares?

Incremental binding

So let’s call it incremental binding, because we can bind arguments in any order we like and as many at a time as we like. Let’s pretend we have implemented a magical function called cur. Here’s a simple function we can use to test cur:

# Simple Function.
def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

Now let’s create a curried version of that function:

# f is the "curried" version of func.
f = cur(func)

Now we would like to be able to write the following (and get the expected result, of course!):

c1 = f(1)
c2 = c1(2, d = 4)               # Note that c is still unbound
c3 = c2(3)(f = 6)(e = 5)        # now c = 3
c3()                            # () forces the evaluation              <====
                                #   it prints "1 2 3 4 5 6 100"
c4 = c2(30)(f = 60)(e = 50)     # now c = 30
c4()                            # () forces the evaluation              <====
                                #   it prints "1 2 30 4 50 60 100"

The lines that print to the screen are marked with the double arrow ‘<====’.

The first line binds the argument a to 1 and saves this less-free (i.e. more-bound) function in c1. Now we have two currying-able functions: f and c1.

Here comes the interesting part: the second line binds b to 2 and d to 4, while c remains unbound. Basically, at each step we can bind some positional arguments and some keyword arguments. If we have bound 2 positional arguments so far, then the 3rd argument will be the next positional argument to be bound. Indeed, line 3 binds c to 3 and then binds the two keyword arguments f and e (the order doesn’t matter, as you can see). Note that I talk of positional arguments and keyword arguments, but here positional and keyword don’t describe the arguments themselves but the way they are bound.

Note how line 4 forces the evaluation.

But… shouldn’t the evaluation be automatic? I would say no and here’s why.

As you may know, a function is called Referentially Transparent if its behavior depends on its arguments alone. This means that a function without arguments is a constant: it behaves the same way every time. Since functions don’t have side effects in pure functional languages, “to behave” is not the most appropriate verb: I should have said that a constant function always returns the same value. Now let’s look at c3 in the previous example. Is that a function? In a pure functional language the answer would be no: it’s a constant. But with automatic evaluation, that would mean we couldn’t evaluate c3 more than once. As you can see, func has side effects: it prints to the screen! Why shouldn’t we be able to call c3 more than once? So, manual evaluation is really a feature! Learning from other languages is a good thing, but trying to emulate the exact same behavior at all costs is a mistake.
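
To make the point concrete without the still imaginary cur, here is a small sketch that uses functools.partial as a stand-in. The function log_and_print and the list calls are made up for the illustration. All the arguments are bound, yet nothing happens until we call the result, and we can call it as often as we like.

```python
from functools import partial

calls = []

def log_and_print(a, b):
    # A function with side effects: it prints and records every call.
    print(a, b)
    calls.append((a, b))

fully_bound = partial(log_and_print, 1, 2)  # every argument is bound, nothing runs yet
count_before = len(calls)
fully_bound()   # first evaluation
fully_bound()   # second evaluation: the side effect happens again
print(len(calls) - count_before)   # 2
```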

The last two lines in the example above show another binding starting from c2.

Here’s our second example:

f = curr(func)                  # f is a "curried" version of func
                                # curr = cur with possibly repeated
                                #   keyword args
c1 = f(1, 2)(3, 4)
c2 = c1(e = 5)(f = 6)(e = 10)() # oops... we repeated 'e' because we    <====
                                #   changed our mind about it!
                                #   again, () forces the evaluation
                                #   it prints "1 2 3 4 10 6 100"

What is curr? It’s a variant of cur which lets you rebind arguments by keyword. In the example above, we first bind e to 5 and then rebind it to 10. Note that, even if that’s a one-line statement, we’re binding in 3 different steps. Finally, we force the evaluation with ‘()’.

But I can feel some of you still want automatic evaluation (Why???). Ok, me too. So how about this:

f = cur(func, 6)        # forces the evaluation after 6 arguments
c1 = f(1, 2, 3)         # num args = 3
c2 = c1(4, f = 6)       # num args = 5
c3 = c2(5)              # num args = 6 ==> evaluation                   <====
                        #   it prints "1 2 3 4 5 6 100"
c4 = c2(5, g = -1)      # num args = 7 ==> evaluation                   <====
                        #   we can specify more than 6 arguments, but
                        #   6 are enough to force the evaluation
                        #   it prints "1 2 3 4 5 6 -1"

You see how nicely default-valued arguments are handled? In the example above, 6 arguments are enough to force the evaluation, but we can also define all 7 of them (if we’re quick enough!). If we feel like it, we could even use some reflection to find out the number of non default-valued arguments (or at least I think so: I’ve been programming in Python for just 3 days and I haven’t looked at reflection yet). I don’t feel like it, but it’s simple to write a little wrapper around cur and curr.
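
For the curious, here is one way the reflection part might look, using the standard inspect module. This is a sketch, and required_arg_count is a name I made up: it counts the parameters that have no default value and ignores *args and **kwargs parameters.

```python
import inspect

def required_arg_count(func):
    """Count the parameters of func that have no default value.

    required_arg_count is a hypothetical helper name; *args and **kwargs
    parameters are ignored because they never have to be supplied.
    """
    countable_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return sum(
        1
        for p in inspect.signature(func).parameters.values()
        if p.kind in countable_kinds and p.default is inspect.Parameter.empty
    )

def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

print(required_arg_count(func))   # 6
```

The result could then be passed as the second argument of cur, as in cur(func, required_arg_count(func)).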

Are we sure that our (still imaginary) cur works as expected? Let’s try something more convoluted:

def printTree(func, level = -1):
    if level == -1:
        printTree(cur(func), level + 1)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

What does it do? It prints the first 64 non-negative integers in base two. Here’s the exact output:

0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 1 0
0 0 0 0 1 1
0 0 0 1 0 0
0 0 0 1 0 1
0 0 0 1 1 0
0 0 0 1 1 1
0 0 1 0 0 0
0 0 1 0 0 1
0 0 1 0 1 0
0 0 1 0 1 1
0 0 1 1 0 0
0 0 1 1 0 1
0 0 1 1 1 0
0 0 1 1 1 1
0 1 0 0 0 0
0 1 0 0 0 1
0 1 0 0 1 0
0 1 0 0 1 1
0 1 0 1 0 0
0 1 0 1 0 1
0 1 0 1 1 0
0 1 0 1 1 1
0 1 1 0 0 0
0 1 1 0 0 1
0 1 1 0 1 0
0 1 1 0 1 1
0 1 1 1 0 0
0 1 1 1 0 1
0 1 1 1 1 0
0 1 1 1 1 1
1 0 0 0 0 0
1 0 0 0 0 1
1 0 0 0 1 0
1 0 0 0 1 1
1 0 0 1 0 0
1 0 0 1 0 1
1 0 0 1 1 0
1 0 0 1 1 1
1 0 1 0 0 0
1 0 1 0 0 1
1 0 1 0 1 0
1 0 1 0 1 1
1 0 1 1 0 0
1 0 1 1 0 1
1 0 1 1 1 0
1 0 1 1 1 1
1 1 0 0 0 0
1 1 0 0 0 1
1 1 0 0 1 0
1 1 0 0 1 1
1 1 0 1 0 0
1 1 0 1 0 1
1 1 0 1 1 0
1 1 0 1 1 1
1 1 1 0 0 0
1 1 1 0 0 1
1 1 1 0 1 0
1 1 1 0 1 1
1 1 1 1 0 0
1 1 1 1 0 1
1 1 1 1 1 0
1 1 1 1 1 1

Was it really necessary to cut and paste all that? Probably not… I could’ve written it by hand.

Here’s the code again (that’s what long unneeded listings do to your articles):

def printTree(func, level = -1):
    if level == -1:
        printTree(cur(func), level + 1)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

Our printTree is a function which takes a function func and an optional level which defaults to -1. There are 3 cases:

  • level is -1:
    • we didn’t call ourselves, so someone has just called us and passed us (we hope so) a “normal” function;
    • we call ourselves recursively, but this time the way we like: with a curried version of func and with level 0. That’s better.
  • level is 6:
    • we called ourselves recursively many times and we’re finally at level 6;
    • our (very own private) func should have the first 6 arguments already bound;
    • we bind the last argument g to something that won’t show up on the screen: the empty string (if you’re wondering why I disregard the last argument this way, you should try to include a monotonous 128-line output in one of your articles!);
    • we evaluate func and print a number on the screen!
  • level is between 0 and 5 (limits included):
    • our func is a function which has the first level arguments bound and the remaining arguments unbound;
    • we can’t evaluate func: we need more arguments;
    • we control the argument x in position level;
    • we bind x to 0 and let another instance of printTree handle the remaining arguments;
    • we now bind x to 1 and call yet another instance of printTree.
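
If it helps to see what the recursion enumerates, the same 64 combinations can be produced without any currying at all. This sketch uses itertools.product, which is not what printTree does internally; it only reproduces the order in which the arguments end up bound.

```python
from itertools import product

def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

# product((0, 1), repeat=6) yields the 64 combinations in the same order
# the recursion builds them: 0 first, then 1, at every level.
rows = list(product((0, 1), repeat=6))
for bits in rows:
    func(*bits, g='')
```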

If you hate recursion, you can very well skip this example and go to the last one which is… oops… recursive as well:

def f2(*args):
    print(", ".join(["%3d"%(x) for x in args]))

def stress(f, n):
    if n: stress(f(n), n - 1)
    else: f()               # enough is enough

stress(cur(f2), 100)

It prints this:

100,  99,  98,  97,  96,  95,  94,  93,  92,  91,  90,  89,  88,  87,  86,  85,
 84,  83,  82,  81,  80,  79,  78,  77,  76,  75,  74,  73,  72,  71,  70,  69,
 68,  67,  66,  65,  64,  63,  62,  61,  60,  59,  58,  57,  56,  55,  54,  53,
 52,  51,  50,  49,  48,  47,  46,  45,  44,  43,  42,  41,  40,  39,  38,  37,
 36,  35,  34,  33,  32,  31,  30,  29,  28,  27,  26,  25,  24,  23,  22,  21,
 20,  19,  18,  17,  16,  15,  14,  13,  12,  11,  10,   9,   8,   7,   6,   5,
  4,   3,   2,   1

Why did I use recursion? Because I was so caught up in implementing cur that I forgot Python has loop constructs :)
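
For the record, here is roughly what the loop version could look like. It is only a sketch: the cur below is a minimal positional-only stand-in (the real one is developed later in this article), stress_loop is a made-up name, and f2 also returns the line it prints so that it can be checked.

```python
def cur(func):
    # Minimal positional-only stand-in for the article's cur:
    # calling with arguments binds them, calling with none evaluates.
    def g(*myArgs):
        def f(*args):
            if args:
                return g(*(myArgs + args))
            return func(*myArgs)
        return f
    return g

def f2(*args):
    line = ", ".join(["%3d" % x for x in args])
    print(line)
    return line

def stress_loop(f, n):
    # Same binding order as the recursive stress: n, n - 1, ..., 1.
    for k in range(n, 0, -1):
        f = f(k)
    return f()              # enough is enough

line = stress_loop(cur(f2), 5)
```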

Beginners shouldn’t have too much of a problem with this example and I’m getting a little bored myself, so let’s proceed!

Wonderful, but how do we implement cur???

Let’s proceed a step at a time.

We might implement a function which remembers the arguments it has received so far:

myArgs = []
def f(*args):
    global myArgs           # we will refer to the global var
    if len(args):           # some more args!
        myArgs += args
    else:                   # time to evaluate...
        print(*myArgs)

f(1,2)
f(3,4)
f()

The last line prints “1 2 3 4”.

Our function accepts an arbitrary number of arguments, packed in the local tuple args, and then there are two cases:

  1. if len(args) is non-zero, we have received some arguments from the caller so we append them to our global list myArgs;
  2. if len(args) is zero, the caller wants to force the evaluation so we act as if we were called with all the arguments in myArgs from the start.

Any objections? Like

  1. Why aren’t we writing g = f(1,2) but just f(1,2)?
  2. What about f(1,2)(3,4)()?
  3. What if we want two different bindings of f at the same time?

Objection 2 is easy to deal with:

myArgs = []
def f(*args):
    global myArgs           # we will refer to the global var
    if len(args):           # some more args!
        myArgs += args
        return f
    else:                   # time to evaluate...
        print(*myArgs)

f(1, 2)(3, 4)()

We just needed to add return f so now f(1, 2) returns f and so on…

What about objection 3? It is closely related to the first objection: we don’t need assignments because we are allowing a single global binding per function, and that defeats the purpose of having currying. The solution is easy: let’s give each partially bound f its own private myArgs.

We can do that by wrapping myArgs and f inside another function so that myArgs is not shared anymore:

def g(*args):
    myArgs = []
    def f(*args):
        nonlocal myArgs         # now it's non-local: it isn't global
        if len(args):           # some more args!
            myArgs += args
            return f
        else:                   # time to evaluate...
            print(*myArgs)
    return f(*args)

g1 = g(1, 2)
g2 = g(10, 20)
g1(3, 4)()
g2(30, 40)()

The last two lines print “1 2 3 4” and “10 20 30 40”, respectively. We did it!

Let’s try something a little different:

a = g(1)
b = g(2)
a1 = a(1, 1)()
a2 = a(2, 2)()

It prints “1 1 1” and… oops… “1 1 1 2 2”. That’s right: functions a and b behave exactly like our old f with its global myArgs. It’s true that myArgs isn’t global anymore, but while a and b receive two different myArgs, every successive use of a will use that same list. So the last line adds 2, 2 to a’s myArgs which already contains 1, 1, 1. The second line is just there for symmetry :)

How do we solve this problem? Hmm… more recursion!!! … Nope… that wouldn’t work. How about less recursion?

Let’s look at the code again: “a = g(1)” returns an f-type (so to speak) function so the next “a(1, 1)” doesn’t work correctly (we’ve just postponed the problem we were facing from the start with f). What if a itself returned “g(1, 1, 1)”? It’d be as if the caller had performed a single binding by calling g. History would be forgotten and totally irrelevant. Every binding would look like the first and only binding!

Before proceeding, we should get rid of another flaw. Look at this example:

a = g(1)
b = g(2)
a(1, 1)
a(2, 2)()

This prints “1 1 1 2 2”. Our old f, here referred to by a, still likes to remember things it shouldn’t. We can solve both problems at the same time:

def g(*args):
    myArgs = args
    def f(*args):
        nonlocal myArgs
        if len(args):           # some more args!
            return g(*(myArgs + args))
        else:                   # time to evaluate...
            print(*myArgs)
    return f

a = g(1)
b = g(2)
a(1, 1)
a(2, 2)()

Now, g gives each f its arguments and no f can change them in any way. When f receives more arguments (args), it creates and returns a new version of f which includes all the arguments seen so far. As you can see, the new f is created by calling g one more time at line 6.

We can do even better:

def g(*myArgs):
    def f(*args):
        if len(args):           # some more args!
            return g(*(myArgs + args))
        else:                   # time to evaluate...
            print(*myArgs)
    return f

The argument args was already a local variable of g (in a sense), so why create another local variable myArgs? Also, we don’t need nonlocal because we won’t modify myArgs: we’ll just read it.
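
Here is a quick way to convince ourselves that history really is forgotten. In this sketch the only change from the version above is that f returns myArgs instead of printing it, so we can look at the results.

```python
def g(*myArgs):
    def f(*args):
        if args:                        # some more args!
            return g(*(myArgs + args))  # a brand new f; this one is untouched
        else:                           # time to evaluate...
            return myArgs               # returned instead of printed, for checking
    return f

a = g(1)
first = a(1, 1)()
second = a(2, 2)()
print(first)    # (1, 1, 1)
print(second)   # (1, 2, 2)
```

Since myArgs is a tuple, myArgs + args builds a new tuple every time and a keeps its own (1,) forever.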

We’re still missing a step: I hope you won’t code every single function by hand like that! We should implement cur so that it produces a curried version of any given function. That’s easy enough:

def cur(func):
    def g(*myArgs):
        def f(*args):
            if len(args):           # some more args!
                return g(*(myArgs + args))
            else:                   # time to evaluate...
                func(*myArgs)
        return f
    return g

def f(a, b, c, d, e):
    print(a, b, c, d, e)
cf = cur(f)

a = cf(1)
b = cf(2)
a(1, 1)
a(2, 2)(3, 4)()

Now we can include keyword arguments:

def cur(func):
    def g(*myArgs, **myKwArgs):
        def f(*args, **kwArgs):
            if len(args) or len(kwArgs):    # some more args!
                newArgs = myArgs + args
                newKwArgs = dict.copy(myKwArgs)
                newKwArgs.update(kwArgs)
                return g(*newArgs, **newKwArgs)
            else:                           # time to evaluate...
                func(*myArgs, **myKwArgs)
        return f
    return g

def f(a, b, c, d, e):
    print(a, b, c, d, e)
cf = cur(f)

a = cf(1, e = 10)
b = cf(2)
a(1, 1)
a(2, d = 9)(3)()
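
The keyword handling boils down to copy, then update. Here is that step in isolation, as a sketch with a made-up helper called bind_keywords. It shows two things: earlier bindings are never modified, and a repeated keyword silently replaces the old value.

```python
def bind_keywords(bound, new):
    # bind_keywords is a hypothetical helper mirroring the copy-then-update step.
    merged = dict(bound)    # copy, so the earlier binding stays untouched
    merged.update(new)      # keys in new overwrite keys already bound
    return merged

step1 = {"e": 10}
step2 = bind_keywords(step1, {"d": 9})
step3 = bind_keywords(step2, {"e": 5})
print(step1)   # {'e': 10}
print(step2)   # {'e': 10, 'd': 9}
print(step3)   # {'e': 5, 'd': 9}
```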

Final Version

Now you should be able to understand the final version, which has some additional features. Here’s the entire code (examples included):

# Coded by Massimiliano Tomassoli, 2012.
#
# - Thanks to b49P23TIvg for suggesting that I should use a set operation
#     instead of repeated membership tests.
# - Thanks to Ian Kelly for pointing out that
#     - "minArgs = None" is better than "minArgs = -1",
#     - "if args" is better than "if len(args)", and
#     - I should use "isdisjoint".
#
def genCur(func, unique = True, minArgs = None):
    """ Generates a 'curried' version of a function. """
    def g(*myArgs, **myKwArgs):
        def f(*args, **kwArgs):
            if args or kwArgs:                  # some more args!
                # Allocates data to assign to the next 'f'.
                newArgs = myArgs + args
                newKwArgs = dict.copy(myKwArgs)

                # If unique is True, we don't want repeated keyword arguments.
                if unique and not kwArgs.keys().isdisjoint(newKwArgs):
                    raise ValueError("Repeated kw arg while unique = True")

                # Adds/updates keyword arguments.
                newKwArgs.update(kwArgs)

                # Checks whether it's time to evaluate func.
                if minArgs is not None and minArgs <= len(newArgs) + len(newKwArgs):
                    return func(*newArgs, **newKwArgs)  # time to evaluate func
                else:
                    return g(*newArgs, **newKwArgs)     # returns a new 'f'
            else:                               # the evaluation was forced
                return func(*myArgs, **myKwArgs)
        return f
    return g

def cur(f, minArgs = None):
    return genCur(f, True, minArgs)

def curr(f, minArgs = None):
    return genCur(f, False, minArgs)

# Simple Function.
def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

# NOTE: '<====' means "this line prints to the screen".

# Example 1.
f = cur(func)                   # f is a "curried" version of func
c1 = f(1)
c2 = c1(2, d = 4)               # Note that c is still unbound
c3 = c2(3)(f = 6)(e = 5)        # now c = 3
c3()                            # () forces the evaluation              <====
                                #   it prints "1 2 3 4 5 6 100"
c4 = c2(30)(f = 60)(e = 50)     # now c = 30
c4()                            # () forces the evaluation              <====
                                #   it prints "1 2 30 4 50 60 100"

print("\n------\n")

# Example 2.
f = curr(func)                  # f is a "curried" version of func
                                # curr = cur with possibly repeated
                                #   keyword args
c1 = f(1, 2)(3, 4)
c2 = c1(e = 5)(f = 6)(e = 10)() # oops... we repeated 'e' because we    <====
                                #   changed our mind about it!
                                #   again, () forces the evaluation
                                #   it prints "1 2 3 4 10 6 100"

print("\n------\n")

# Example 3.
f = cur(func, 6)        # forces the evaluation after 6 arguments
c1 = f(1, 2, 3)         # num args = 3
c2 = c1(4, f = 6)       # num args = 5
c3 = c2(5)              # num args = 6 ==> evaluation                   <====
                        #   it prints "1 2 3 4 5 6 100"
c4 = c2(5, g = -1)      # num args = 7 ==> evaluation                   <====
                        #   we can specify more than 6 arguments, but
                        #   6 are enough to force the evaluation
                        #   it prints "1 2 3 4 5 6 -1"

print("\n------\n")

# Example 4.
def printTree(func, level = None):
    if level is None:
        printTree(cur(func), 0)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

print("\n------\n")

def f2(*args):
    print(", ".join(["%3d"%(x) for x in args]))

def stress(f, n):
    if n: stress(f(n), n - 1)
    else: f()               # enough is enough

stress(cur(f2), 100)
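
Since genCur returns whatever func returns, it also works with functions that compute a value instead of printing. The sketch below repeats a compact copy of genCur so that it runs on its own; add3 is a made-up example function. It checks forced evaluation, automatic evaluation with minArgs, rebinding with unique = False, and the error raised when a keyword is repeated with unique = True.

```python
def genCur(func, unique = True, minArgs = None):
    """ Compact copy of the article's genCur, repeated so the sketch runs alone. """
    def g(*myArgs, **myKwArgs):
        def f(*args, **kwArgs):
            if args or kwArgs:
                newArgs = myArgs + args
                newKwArgs = dict(myKwArgs)
                if unique and not kwArgs.keys().isdisjoint(newKwArgs):
                    raise ValueError("Repeated kw arg while unique = True")
                newKwArgs.update(kwArgs)
                if minArgs is not None and minArgs <= len(newArgs) + len(newKwArgs):
                    return func(*newArgs, **newKwArgs)
                return g(*newArgs, **newKwArgs)
            return func(*myArgs, **myKwArgs)
        return f
    return g

def add3(a, b, c):
    # add3 is a hypothetical example function that returns instead of printing.
    return a + b + c

forced = genCur(add3)(1)(2)(c = 3)()          # evaluation forced by ()
automatic = genCur(add3, True, 3)(1)(2)(3)    # evaluation triggered by minArgs
rebound = genCur(add3, False)(1, 2)(c = 3)(c = 30)()

try:
    genCur(add3)(1)(c = 3)(c = 30)
    repeated_rejected = False
except ValueError:
    repeated_rejected = True

print(forced, automatic, rebound, repeated_rejected)   # 6 6 33 True
```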

That’s all!

Please feel free to let me know what you think by leaving a comment!

5 Responses to “Currying in Python”

  1. Ling said

    Don’t you need to add “set” for “if unique and not set(kwArgs.keys()).isdisjoint(newKwArgs):”?

  2. smallmall said

    I’ve done similar solution but using inspection, I had quite a lot of problems with different types of arguments and methods, I didn’t tested your solution that thoroughly though.
    http://code.activestate.com/recipes/577928-indefinite-currying-decorator-with-greedy-call-and/

  3. Alexander Alvonellos said

    I liked your article so much that I even clicked an ad for you. I hope that helps. This is a very deep and thorough explanation of the concept and was exactly what I was looking for. Thank you so much.
