Computer Science

Here you'll find some information about programming languages and software development in general.

Code blocks in Python

Posted by mtomassoli on April 20, 2012

Introduction

As everyone knows, Python doesn’t support code block objects: you can’t create and pass around anonymous blocks of code. That’s a pity, and in fact many coders have suggested workarounds. Some tried to introduce proper code blocks by manipulating byte code, while others proposed using lambdas with concatenated expressions to emulate multi-statement blocks.
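For instance, one lambda-based workaround abuses tuples, whose elements are evaluated left to right, so several side-effecting calls can be packed into a single expression. A minimal sketch of the idea (not any specific proposal from those discussions):

```python
# a two-statement "block" as a single lambda: a tuple evaluates its
# elements left to right, so both side effects run in order
log = []
twice = lambda x: (log.append(x), log.append(x * x))

for i in range(3):
    twice(i)

assert log == [0, 0, 1, 1, 2, 4]
```

It works for simple cases, but as soon as you need assignments, loops or early returns inside the "block", the trick falls apart.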

Unfortunately, these attempts, while very interesting and clever, can’t be used in production code because of their many limitations and oddities. I tried many approaches myself, but none was perfect: one didn’t support nonlocal, another was slow, and so on…

Finally, I went for source code rewriting and I’m quite satisfied with the result.

Here are the main features:

  • code blocks act as normal functions/closures (you can use return, yield, global and nonlocal inside of them)
  • the source code is rewritten, on-the-fly, in RAM: no files are created or modified
  • syntax errors related to the new syntax generate meaningful error messages
  • error messages refer to your code and not the module codeblocks
  • debuggers break on your code as it should be
  • stepping and tracing through your code behave as expected
  • codeblocks doesn’t mess with the import process
  • codeblocks doesn’t manipulate byte code and should be quite portable
  • codeblocks works with both Python 2.7 and Python 3.2 (but nonlocal is not supported in Python 2.7)

Repository

You can download (and contribute to, if you wish) this module from bitbucket.

The docstring of the module offers a more technical documentation and the exact rewriting rules applied to the original code.

A little example

def each(iterable, block):
    for e in iterable:
        block(e)                        # step into -> 6:

with each(range(0, 10)) << 'x':
    print('element ' + str(x))

That with statement is a special statement that

  1. creates an anonymous function anon_func which takes an argument x and has line 6 as its body,
  2. passes anon_func as a positional argument to the function each,
  3. executes each with all its arguments (the one given at line 5 and the block)
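In plain Python, the rewritten with statement corresponds roughly to this (anon_func is a hypothetical name; the module generates its own):

```python
def each(iterable, block):
    for e in iterable:
        block(e)

# the body of the with becomes the body of an anonymous function
def anon_func(x):
    print('element ' + str(x))

# the block is appended to the arguments already given to each
each(range(0, 10), anon_func)
```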

If you issue the step into command when at line 3, you’ll immediately jump to line 6. Likewise, step into will take you from line 5 to line 2. Any other command (step over, step out, …) will also work as expected.

Usage

Here’s the previous example with all the ugly stuff included:

import codeblocks

codeblocks.rewrite()

def each(iterable, block):
    for e in iterable:
        block(e)                        # step into -> 6:

with each(range(0, 10)) << 'x':
    print('element ' + str(x))

That’s all. Just make sure that you import codeblocks and call rewrite right at the start of your code. What really happens is that rewrite rewrites the code, executes it and quits. This means that you should put a breakpoint on a line of code which follows your invocation of rewrite. If you stepped into rewrite you’d come across a call to exec. Stepping into it would also work, but setting a breakpoint is the recommended way.

If you want to use codeblocks in a module of yours, you’ll need to call end_of_module at the end of your module. If you forget it, an exception will be raised during the rewriting. The code above, if in a module, would thus look as follows:

import codeblocks

codeblocks.rewrite()

def each(iterable, block):
    for e in iterable:
        block(e)                        # step into -> 6:

with each(range(0, 10)) << 'x':
    print('element ' + str(x))

codeblocks.end_of_module()

Nothing too invasive, I hope.

Creating a block

Unfortunately, Python syntax doesn’t let us create blocks as part of normal expressions: we need to use a statement which already takes a block of code. I chose with over for, if, etc…

Here’s the fastest way to create a block and assign it to a variable:

def id(x): return x

with my_block << id() << 'x, y = 3':
    print(x, y)

my_block(4)
my_block(1, 2)
print(my_block)

which prints

4 3
1 2
<function codeblock_1b4c94d3_1 at 0x02E7AC90>

I could’ve added special syntax for this very case, but I don’t think it’s worth it.
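In plain Python, that my_block corresponds roughly to an ordinary function with a default-valued argument:

```python
# roughly what the rewriter generates for the block above
def my_block(x, y=3):
    print(x, y)

my_block(4)      # uses the default y = 3
my_block(1, 2)
```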

Bug in Python

I’ve just discovered a bug in Python. A patch is already available, but your version could still suffer from this bug.

Here’s the bug. Create a file script.py with this content:

with open('test') as f:
	data = f.read()
with open('test') as f:
	data = f.read()

Now start debugging it with a debugger of your choice and, from the first line, jump to line 3. If it crashes, the bug is still there.

The easiest way to see this is by using pdb. Run it with python -m pdb script.py and, once in the debugger, issue the command j 3.

This bug is triggered only if you use with statements at the module level in imported modules which use codeblocks. Because this is not too big a limitation, I’ve decided to insert a bug check in codeblocks which raises an exception (with a meaningful error message) if the user writes code that would trigger this bug. When the bug is finally gone from official Python releases, I’ll remove the bug check for patched versions.

Let’s move on.

Pseudo-keywords

When Python doesn’t let you use a keyword inside a with statement, you can use a pseudo-keyword.

Here’s the short list of keywords and corresponding pseudo-keywords:

  keyword                 pseudo-keyword
  global identifiers      _global(identifiers)
  nonlocal identifiers    _nonlocal(identifiers)
  return expression       _return(expression)
  yield expression        _yield(expression)

Passing a block as a positional argument

To pass a block as a positional argument to a function you use this syntax:

with each(range(0, 10)) << 'x':
    print('element ' + str(x))

where each is called as each(range(0, 10), block).

Passing a block as a keyword argument

To pass the same block as a keyword argument of name block_name, you use a similar syntax:

with each(range(0, 10)) << block_name << 'x':
    print('element ' + str(x))

or also

with each(range(0, 10)) << 'block_name' << 'x':
    print('element ' + str(x))

This calls each as each(range(0, 10), block_name = block).

Notice that block_name can be a literal or an identifier, but the arguments for the block (just x in this example) must be given as a literal. There are two reasons for this:

  1. It’s harder for a normal with to be misidentified as a special with and then erroneously rewritten.
  2. Default-valued arguments aren’t allowed by Python’s syntax, so a literal is needed in that case.

All the other parameters (such as block_name, in the example) can be both literals and identifiers.

Passing a block which takes no arguments

If the block takes no arguments, you should use the empty literal:

with each(range(0, 10)) << block_name << '':
    print('element ' + str(1))

If you don’t like it, you can use _ (underscore) or None:

with each(range(0, 10)) << block_name << _:
    print('element ' + str(1))

with each(range(0, 10)) << block_name << None:
    print('element ' + str(1))

Passing more than one block to a function

If a function takes more than one block, you need to use a multi with:

def take3(block1, block2, block3):
    for b in (block1, block2, block3):
        print('=== {} ===\n{}'.format(b, b()))

with take3() << ':multi':
    with '':                    # 1st pos arg
        _return("I'm block 1!")
    with block3 << '':
        _return("I'm block 3!")
    with '':                    # 2nd pos arg
        _return("I'm block 2!")

This prints

=== <function codeblock_1b4c94d3_6 at 0x02E5A930> ===
I'm block 1!
=== <function codeblock_1b4c94d3_8 at 0x02E5AE40> ===
I'm block 2!
=== <function codeblock_1b4c94d3_7 at 0x02E5AA98> ===
I'm block 3!

The last three with statements are internal withs and have a simpler syntax: they can only take an optional block_name and a required code_args, i.e. the argument string that appears to the right of the function in non-internal withs.

In the example above, take3 is called as take3(anon_block1, anon_block2, block3 = anon_block3).
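Stripped of the special syntax, that call corresponds roughly to this plain Python (the anon_block* names are hypothetical; the module generates its own):

```python
def take3(block1, block2, block3):
    for b in (block1, block2, block3):
        print('=== {} ===\n{}'.format(b, b()))

def anon_block1():
    return "I'm block 1!"

def anon_block2():
    return "I'm block 2!"

def anon_block3():
    return "I'm block 3!"

# the two nameless internal withs become positional args, in order;
# the named one becomes a keyword arg
take3(anon_block1, anon_block2, block3=anon_block3)
```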

Note that the keywords can be arbitrary names in some circumstances:

def take3(**kwargs):
    for name in kwargs:
        print('=== {} ===\n{}'.format(name, kwargs[name]()))

with take3() << ':multi':
    with 'num #1' << '':             # arbitrary strings!
        _return("I'm block 1!")
    with 'num #3' << '':
        _return("I'm block 3!")
    with 'num #2' << '':
        _return("I'm block 2!")

This prints

=== num #3 ===
I'm block 3!
=== num #2 ===
I'm block 2!
=== num #1 ===
I'm block 1!

Passing one or more blocks to a function as a list

You can also pass one or more blocks to a function as a single list:

def take3(title, blocks):
    print(title)
    for b in blocks:
        print('    ' + b())

with take3('List of blocks:') << ':list':
    with '':
        _return("I'm block 1!")
    with '':
        _return("I'm block 3!")
    with '':
        _return("I'm block 2!")

This prints

List of blocks:
    I'm block 1!
    I'm block 3!
    I'm block 2!

Internal withs can take only one argument, in this case. This is what happens if you give one of them two arguments:

  File "C:\...\test.py", line 124
    with wrong_arg << '':             # arbitrary strings!
         ^
SyntaxError: :list forbids internal Withs with codekw args.

Passing one or more blocks to a function as a dictionary

You can pass one or more blocks to a function as a single dictionary as well:

def take3(dict, silent = True):
    for key in sorted(dict.keys()):
        print('=== {} ===\n{}'.format(key, dict[key]()))

with take3(silent = False) << ':dict':
    with 'block 1' << '':
        _return("I'm block 1!")
    with 'block 3' << '':
        _return("I'm block 3!")
    with 'block 2' << '':
        _return("I'm block 2!")

This prints:

=== block 1 ===
I'm block 1!
=== block 2 ===
I'm block 2!
=== block 3 ===
I'm block 3!

Here’s another example:

import logging
import random

logging.basicConfig(level = logging.DEBUG)

with operations << dict() << ':dict':
    with add << 'x, y':
        logging.debug('doing {} + {}'.format(x, y))
        _return(x + y)
    with sub << 'x, y':
        logging.debug('doing {} - {}'.format(x, y))
        _return(x - y)
    with mul << 'x, y':
        logging.debug('doing {} * {}'.format(x, y))
        _return(x * y)
    with div << 'x, y':
        logging.debug('doing {} / {}'.format(x, y))
        _return(x / y)

opnd1 = random.randint(1, 100)
opnd2 = random.randint(1, 100)
op = random.choice(('add', 'sub', 'mul', 'div'))
print(operations[op](opnd1, opnd2))

Here are a few runs:

DEBUG:root:doing 10 / 87
0.11494252873563218

DEBUG:root:doing 34 + 50
84

DEBUG:root:doing 67 / 15
4.466666666666667

DEBUG:root:doing 5 - 51
-46

DEBUG:root:doing 10 * 52
520

You get the idea.

The result of the function in a with can be assigned to variables

You can assign the result of the function called in a with statement, but you can’t just use the assignment operator:

def take3(list, silent = True):
    for i in range(len(list)):
        print('=== {} ===\n{}'.format('block ' + str(i + 1), list[i]()))
    return list[0](), [list[1](), list[2]()]           # just a test

with (r1, [r2, r3]) << take3(silent = False) << list << ':list':
    with '':
        _return("I'm block 1!")
    with '':
        _return("I'm block 2!")
    with '':
        _return("I'm block 3!")

print('---')
print(r1, r2, r3)

This prints

=== block 1 ===
I'm block 1!
=== block 2 ===
I'm block 2!
=== block 3 ===
I'm block 3!
---
I'm block 1! I'm block 2! I'm block 3!

With statements can be nested

Of course they can be nested! Here we go:

import re
import random

def take3(dict, silent = True):
    for key in sorted(dict.keys()):
        print('=== {} ===\n{}'.format(key, dict[key]()))

with take3(silent = False) << ':dict':
    with 'block 1' << '':
        _return("I'm block 1!")
    with 'block 3' << '':
        text = "I'm surely codeblock number three!"
        with ris << re.sub(r'(\w)(\w+)(\w)', string = text) << repl << 'm':
            # From Python's doc.
            inner_word = list(m.group(2))
            random.shuffle(inner_word)
            _return (m.group(1) + "".join(inner_word) + m.group(3))
        _return (ris)
    with 'block 2' << '':
        _return("I'm block 2!")

Here are a few runs:

=== block 1 ===
I'm block 1!
=== block 2 ===
I'm block 2!
=== block 3 ===
I'm seurly cbooledck nmuebr three!

=== block 1 ===
I'm block 1!
=== block 2 ===
I'm block 2!
=== block 3 ===
I'm srluey clbceoodk nubmer terhe!

=== block 1 ===
I'm block 1!
=== block 2 ===
I'm block 2!
=== block 3 ===
I'm srluey cbodoelck nmeubr there!

Code blocks allow global and nonlocal declarations

Because code blocks are implemented as real functions, you can use global and nonlocal in the usual way:

my_var = None

def each(iterable, block):
	for e in iterable:
		block(e)

def gen_funcs():
	sum = 0
	def acc_elems(*elems):
		# You would never use it like this: this is just a test, as always!
		with each(elems) << 'x':
			global my_var
			nonlocal sum
			sum += x
			my_var = 'modified from inside the block'
	def get_sum():
		return sum
	return get_sum, acc_elems

get_sum, acc_elems = gen_funcs()

print('my_var: ', my_var)
print('sum:   ', get_sum())

acc_elems(1, 2, 3)
print('after acc_elems(1, 2, 3):')
print('    my_var: ', my_var)
print('    sum:   ', get_sum())

acc_elems(4, 5, 6)
print('after acc_elems(4, 5, 6):')
print('    my_var: ', my_var)
print('    sum:   ', get_sum())

This prints

my_var:  None
sum:    0
after acc_elems(1, 2, 3):
    my_var:  None
    sum:    6
after acc_elems(4, 5, 6):
    my_var:  None
    sum:    21

Code blocks allow yield

No surprises here either:

def print_all(gen):
    for msg in gen(verb = True):
        print(msg)

with print_all() << 'n = 10, e = 3, verb = False':
    if verb:
        for i in range(n):
            _yield('{}^{} is {}'.format(i, e, i ** e))
    else:
        for i in range(n):
            _yield(i, i ** e)

This prints

0^3 is 0
1^3 is 1
2^3 is 8
3^3 is 27
4^3 is 64
5^3 is 125
6^3 is 216
7^3 is 343
8^3 is 512
9^3 is 729

That’s all!

As always, comments and constructive criticism are greatly appreciated.


Pipelining in Python

Posted by mtomassoli on March 29, 2012

Pipelining in F#

F# is a multi-paradigm language which targets the .NET platform and is based on OCaml, which is a variant of ML. Even though F# supports imperative and object-oriented programming, it’s mainly a functional programming language.

Pipelining isn’t a new idea, no doubt about it. Unix and Linux folks have been using it for a very long time. It’s also known as forward-chaining and is basically function application in reverse. In F#, transformations can be chained with the operator ‘|>’.

Let’s start with a simple example:

Seq.filter (fun i -> i%2 <> 0) (Seq.map (fun i -> i*i) (seq {0..40}))

Here’s one way to express that in Python:

filter(lambda i : i%2, map(lambda i : i*i, range(41)))

We simply start with the first 41 non-negative integers (0 through 40), square them and filter out the even ones. Therefore, we get an iterable which contains the following integers:

1, 9, 25, 49, 81, 121, 169, 225, 289, 361, 441, 529, 625, 729, 841, 961, 1089, 1225, 1369, 1521

Now let’s rewrite the F# version using pipelining:

seq {0..40} |> Seq.map (fun i -> i*i) |> Seq.filter (fun i -> i%2 <> 0)

The first thing we notice is that there are fewer parentheses, which is always a good thing (I probably shouldn’t have said that…). The next thing is that we can understand that line by reading it from left to right. I don’t know about you, but I had to reread the first version from right to left to fully understand it. Ok, that was simple, and I wrote it, so I already knew what it does; but consider the situation where we have some highly “overloaded” functions, i.e. functions which do different things depending on the type of their arguments (actually, we shouldn’t have such functions). How can you tell what a function does without knowing the right side of the expression? That suggests that reading such an expression from left to right is less efficient.

As a side note, F# is a typed functional language which uses type inference so the programmer needs to provide types only when they can’t be derived by analyzing the expression. Well, F# likes pipelining because it lets type information flow from input objects to the functions which transform those objects into output objects.

Let’s get back to our expression: did you notice the currying?

Here’s how F# defines the operator ‘|>’:

let (|>) x f = f x

That’s all! It says that

arg |> func

is equivalent to

func arg

or, if you prefer,

func(arg)

Let’s analyze the first part of the code above:

seq {0..40} |> Seq.map (fun i -> i*i)

That must be equivalent to

Seq.map (fun i -> i*i) (seq {0..40})

It should be obvious that

Seq.map (fun i -> i*i)

is a curried version of Seq.map. We might say that the operator ‘|>’ lets us pass a function its last argument. By the way, now you know why map takes a function and a sequence in that particular order :)

Pipelining in Python

What should pipelining look like in Python? I don’t know, but here’s what it looks like in my code:

range(41) >> filter(lambda i : i%2) >> map(lambda i : i*i)

Not too shabby, eh? If you haven’t done that already, please read my previous post “Currying in Python” because filter and map in the code above are the currying-able versions of the built-in functions you usually use in your code.

Composing functions in F#

You can also compose functions in F#. The usual composition, i.e. the one your math teacher taught you, is defined (in math-like language) as follows:

(g o f)(x) := g(f(x))

where that ‘o’ should really be a small circle.

Since pipelining is a sort of function application in reverse, it’s only natural for us to reverse the composition of functions as well:

(f >> g)(x) := g(f(x))

Again, that’s not an F# definition: it’s just my way to explain it to you.

Yes, F# uses ‘>>’ for composition, but we’ve already used that for pipelining in Python. We’re going to use some other operator in Python. Here’s how function composition is used in F#:

let comp : int seq -> int seq = Seq.map (fun i -> i*i) >> Seq.filter (fun i -> i%2 <> 0)
seq {0..40} |> comp

Please ignore the type annotations (i.e. int seq -> int seq) or, better, let’s write it another way:

seq {0..40} |> (Seq.map (fun i -> i*i) >> Seq.filter (fun i -> i%2 <> 0))

So what’s the difference?

Basically, we aren’t passing a sequence “through” two individual functions anymore, but through their composition, i.e. through a single function. That’s useful when we want to reuse a “chain” of functions. For instance (in pseudo-code) we can rewrite

out1 = input1 |> func1 |> func2 |> func3 |> func4
out2 = input2 |> func1 |> func2 |> func3 |> func4
out3 = input3 |> func1 |> func2 |> func3 |> func4
out4 = input4 |> func1 |> func2 |> func3 |> func4

as

compFunc = func1 >> func2 >> func3 >> func4
out1 = input1 |> compFunc
out2 = input2 |> compFunc
out3 = input3 |> compFunc
out4 = input4 |> compFunc

which is much better, I think. The beauty of ‘>>’ is that it lets us compose functions in the same reversed order they appear in our pipelines.
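In plain Python, reversed composition is a one-liner; a minimal sketch (rcompose is a hypothetical name, not part of the module):

```python
def rcompose(f, g):
    # reversed composition: (f >> g)(x) == g(f(x))
    return lambda x: g(f(x))

inc = lambda x: x + 1
dbl = lambda x: x * 2
inc_then_dbl = rcompose(inc, dbl)   # applies inc first, then dbl
assert inc_then_dbl(3) == 8         # dbl(inc(3))
```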

Composing functions in Python

You won’t like this, I’m afraid:

compFunc = filter(lambda i : i%2) - map(lambda i : i*i)
range(0,41) >> compFunc

Yep: a minus. Why? Because I couldn’t think of anything better, but I’m open to suggestions. My first idea was to use a plus, then I used an asterisk and, finally, a minus. That minus is not really a minus but a sort of link which should remind us that functions are not composed the usual way.

Let’s create a few tests that our implementation should be able to handle.

Some tests

To make our tests a little more readable, we’ll import some definitions and redefine a few symbols:

import sys
from urllib.request import urlopen
from re import findall
from math import sqrt, floor
from functools import reduce

map = cur(map, 2)
filter = cur(filter, 2)
urlopen = cur(urlopen)
findall = cur(findall)
my_print = cur(lambda list : print(*list))
reduce = cur(reduce, 2)

As I said in my previous article, cur is a function which takes another function and returns a curried version of it. The function cur also accepts some optional arguments such as minArgs which tells cur to force the evaluation of the curried function after at least minArgs arguments have been provided.

Test 1 (Example 5, in the code)

range(0,50) >> filter(lambda i : i%2) >> map(lambda i : i*i) >> my_print

It prints

1 9 25 49 81 121 169 225 289 361 441 529 625 729 841 961 1089 1225 1369 1521 1681 1849 2025 2209 2401

Test 2 (Example 6, in the code)

Now we check that function composition works as expected:

compFunc = filter(lambda i : i%2) - map(lambda i : i*i)
range(0,50) >> compFunc >> my_print

It prints the same list of numbers as Test 1.

Test 3 (Example 7, in the code)

This test is more complex than the first two:

# Tells whether x is not a proper multiple of n.
notPropMult = cur(lambda n, x : x <= n or x % n, 2)

def findPrimes(upTo):
    if (upTo <= 5): return [2, 3, 5]
    filterAll = (findPrimes(floor(sqrt(upTo)))
                 >> map(lambda x : filter(notPropMult(x)))
                 >> reduce(lambda f, g : f - g))
    return list(range(2, upTo + 1)) >> filterAll

findPrimes(1000) >> my_print

This prints:

2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53 59 61 67 71 73 79 83 89 97 101 103 107 109 113 127 131 137 139 149 151 157 163 167 173 179 181 191 193 197 199
211 223 227 229 233 239 241 251 257 263 269 271 277 281 283 293 307 311 313 317 331 337 347 349 353 359 367 373 379 383 389 397 401 409 419 421 431 433 439 443 449
457 461 463 467 479 487 491 499 503 509 521 523 541 547 557 563 569 571 577 587 593 599 601 607 613 617 619 631 641 643 647 653 659 661 673 677 683 691
701 709 719 727 733 739 743 751 757 761 769 773 787 797 809 811 821 823 827 829 839 853 857 859 863 877 881 883 887 907 911 919 929 937 941 947 953 967 971 977 983 991 997

A little bit of math

As all of you already know, to test whether n is a prime we can try to divide it by the integers 2, 3, …, n-1.
As some of you already know, we should really limit ourselves to the integers 2, 3, …, floor(sqrt(n)). Why is that? Because if some integer x divides n and x > sqrt(n), then there must be some positive integer y < sqrt(n) which divides n. So, if we didn’t find any y < sqrt(n) which divides n, then we can’t possibly find any x > sqrt(n).

The Sieve of Eratosthenes consists of taking the numbers 2, 3, …, upTo and filtering out all the numbers which aren’t primes. Starting with 2, we cross out all multiples of 2 greater than 2. That gives us 2, 3, 5, 7, 9, 11, …, upTo (assuming that upTo is odd). We know for sure that 3 is a prime, so we keep it and start crossing out all multiples of 3 greater than 3, and so on…

When do we stop? We stop when we reach floor(sqrt(upTo)).
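For comparison, here’s the classic iterative sieve in plain Python (the standard algorithm, not the recursive variant this post builds):

```python
def sieve(up_to):
    # classic Sieve of Eratosthenes
    is_prime = [True] * (up_to + 1)
    is_prime[0:2] = [False, False]          # 0 and 1 aren't primes
    p = 2
    while p * p <= up_to:                   # stop at floor(sqrt(up_to))
        if is_prime[p]:
            # cross out the proper multiples of p
            for m in range(p * p, up_to + 1, p):
                is_prime[m] = False
        p += 1
    return [n for n in range(2, up_to + 1) if is_prime[n]]

assert sieve(20) == [2, 3, 5, 7, 11, 13, 17, 19]
```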

If you’ve already looked at the code above, you should have noticed that it’s recursive. Why? Because I like recursion and I wanted to do something more interesting than the usual sieve. Here’s what I came up with:

the primes up to upTo are

  • (if upTo <= 5)      [2, 3, 5]
  • (otherwise)           [2, …, upTo] without the proper multiples of the primes up to floor(sqrt(upTo))

where a proper multiple of x is a multiple of x greater than x (we don’t want to cross out the prime numbers themselves!). Indeed, to tell if a number n is a prime, we don’t need to consider all the divisors up to floor(sqrt(n)), but just the prime numbers up to floor(sqrt(n)). If you think about it, that’s exactly what the iterative Sieve of Eratosthenes already does. With recursion, for once, we need to be a little more explicit.

So, if upTo = 20, we have:

    primes up to 20 =
    [2, …, 20] without the proper multiples of the primes up to 4 =
    [2, …, 20] without the proper multiples of [2, 3, 5] =
    [2, 3, 5, 7, 11, 13, 17, 19]

That 5 wasn’t needed, but I decided to pick that as the base case.

The actual implementation

This is the interesting part:

    filterAll = (findPrimes(floor(sqrt(upTo)))
                 >> map(lambda x : filter(notPropMult(x)))
                 >> reduce(lambda f, g : f - g))

Let’s see what it does by looking at an example. Let’s assume that findPrimes(floor(sqrt(upTo))) is the list [2, 3, 5, 7]. Here, ‘->’ means “is transformed into” and corresponds to the occurrences of the ‘>>’ operator in the code above. I will ignore the differences between lists and iterators for the sake of clarity:

    [2, 3, 5, 7] ->
    [filter(notPropMult(2)), filter(notPropMult(3)), filter(notPropMult(5)), filter(notPropMult(7))] ->
    filter(notPropMult(2)) - filter(notPropMult(3)) - filter(notPropMult(5)) - filter(notPropMult(7))

So, filterAll is really the filter we need to select all the prime numbers from the first upTo integers (starting from 2).
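Unrolled into plain Python, filterAll amounts to applying one “keep non-proper-multiples of p” filter per prime. A hypothetical loop version, just to check the logic:

```python
def not_prop_mult(n, x):
    # True unless x is a proper multiple of n
    return x <= n or x % n != 0

def filter_all(seq, primes):
    # apply one filter per prime, mimicking the composed chain
    for p in primes:
        seq = [x for x in seq if not_prop_mult(p, x)]
    return seq

assert filter_all(list(range(2, 21)), [2, 3, 5, 7]) == [2, 3, 5, 7, 11, 13, 17, 19]
```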

Test 4 (Example 8, in the code)

This test is easy but interesting nonetheless:

def do(proc, arg):
   proc()
   return arg
do = cur(do)

cprint = cur(print)

("http://python.org"
 >> do(cprint("The page http://python.org has about... ", end = ''))
 >> do(sys.stdout.flush)
 >> urlopen
 >> cur(lambda x : x.read())
 >> findall(b"href=\"")
 >> cur(len)
 >> cur("{} hrefs.".format)
 >> cprint)

It prints

The page http://python.org has about... 121 hrefs.

with a small pause after “about…” (that delay is due to urlopen).

The function do is basically an identity transformation with a side effect. It lets us do something (like printing to the screen) without breaking or altering the flow. I won’t point out that that code is full of currying… oops. On second thought, I think that code doesn’t even need an explanation: every Python programmer should be able to understand what it does and fill in the details by himself or herself.

How do we implement Pipelining?

That’s very easy. We need to write a class to overload the operators ‘>>’ and ‘-’ so it makes sense to rewrite the code in the article “Currying in Python” like this:

class CurriedFunc:
    def __init__(self, func, args = (), kwArgs = {}, unique = True, minArgs = None):
        self.__func = func
        self.__myArgs = args
        self.__myKwArgs = kwArgs
        self.__unique = unique
        self.__minArgs = minArgs

    def __call__(self, *args, **kwArgs):
        if args or kwArgs:                  # some more args!
            # Allocates data to assign to the next CurriedFunc.
            newArgs = self.__myArgs + args
            newKwArgs = dict.copy(self.__myKwArgs)

            # If unique is True, we don't want repeated keyword arguments.
            if self.__unique and not kwArgs.keys().isdisjoint(newKwArgs):
                raise ValueError("Repeated kw arg while unique = True")

            # Adds/updates keyword arguments.
            newKwArgs.update(kwArgs)

            # Checks whether it's time to evaluate func.
            if self.__minArgs is not None and self.__minArgs <= len(newArgs) + len(newKwArgs):
                return self.__func(*newArgs, **newKwArgs)       # time to evaluate func
            else:
                return CurriedFunc(self.__func, newArgs, newKwArgs, self.__unique, self.__minArgs)
        else:                               # the evaluation was forced
            return self.__func(*self.__myArgs, **self.__myKwArgs)

def cur(f, minArgs = None):
    return CurriedFunc(f, (), {}, True, minArgs)

def curr(f, minArgs = None):
    return CurriedFunc(f, (), {}, False, minArgs)

By the way, while I was recoding genCur into CurriedFunc, I noticed that a “return” was missing! I’ve already updated my previous article.

Now, we just need to add a couple of methods to CurriedFunc:

    def __rrshift__(self, arg):
        return self.__func(*(self.__myArgs + (arg,)), **self.__myKwArgs)      # forces evaluation

    def __sub__(self, other):
        if not isinstance(other, CurriedFunc):
            raise TypeError("Cannot compose a CurriedFunc with another type")

        def compFunc(*args, **kwArgs):
            return other.__func(*(other.__myArgs + (self.__func(*args, **kwArgs),)),
                                **other.__myKwArgs)

        return CurriedFunc(compFunc, self.__myArgs, self.__myKwArgs,
                           self.__unique, self.__minArgs)

As you can see, we added __rrshift__ and not __rshift__. The reason is that instances of CurriedFunc will always appear to the right of the operator ‘>>’. The method __rrshift__ calls function __func with the new positional argument arg without forgetting (stress on the word “forgetting”) the previously added arguments __myArgs and __myKwArgs.

The method __sub__ is not so straightforward. After a simple sanity check, we define a new function called compFunc and return a curried version of that function. Notice that the functions self.__func and other.__func are composed in reverse order. To simplify things, let’s say we want to compose two CurriedFunc objects called self and other. We write “self - other” and we expect a CurriedFunc object representing a function which first does what self does and then, on the result, does what other does. Basically, we want “other(self(.))”. From the expression

            return other.__func(*(other.__myArgs + (self.__func(*args, **kwArgs),)),
                                **other.__myKwArgs)

it should be clear that the result of self.__func is passed to other.__func as its last positional argument. Also notice that the arguments of other.__func are frozen (except for the last positional argument coming from self.__func, as we’ve just said).

The full source code

Here’s the entire source code. The new examples start from number 5.

# Coded by Massimiliano Tomassoli, 2012.

class CurriedFunc:
    def __init__(self, func, args = (), kwArgs = {}, unique = True, minArgs = None):
        self.__func = func
        self.__myArgs = args
        self.__myKwArgs = kwArgs
        self.__unique = unique
        self.__minArgs = minArgs

    def __call__(self, *args, **kwArgs):
        if args or kwArgs:                  # some more args!
            # Allocates data to assign to the next CurriedFunc.
            newArgs = self.__myArgs + args
            newKwArgs = dict.copy(self.__myKwArgs)

            # If unique is True, we don't want repeated keyword arguments.
            if self.__unique and not kwArgs.keys().isdisjoint(newKwArgs):
                raise ValueError("Repeated kw arg while unique = True")

            # Adds/updates keyword arguments.
            newKwArgs.update(kwArgs)

            # Checks whether it's time to evaluate func.
            if self.__minArgs is not None and self.__minArgs <= len(newArgs) + len(newKwArgs):
                return self.__func(*newArgs, **newKwArgs)       # time to evaluate func
            else:
                return CurriedFunc(self.__func, newArgs, newKwArgs, self.__unique, self.__minArgs)
        else:                               # the evaluation was forced
            return self.__func(*self.__myArgs, **self.__myKwArgs)

    def __rrshift__(self, arg):
        return self.__func(*(self.__myArgs + (arg,)), **self.__myKwArgs)      # forces evaluation

    def __sub__(self, other):
        if not isinstance(other, CurriedFunc):
            raise TypeError("Cannot compose a CurriedFunc with another type")

        def compFunc(*args, **kwArgs):
            return other.__func(*(other.__myArgs + (self.__func(*args, **kwArgs),)),
                                **other.__myKwArgs)

        return CurriedFunc(compFunc, self.__myArgs, self.__myKwArgs,
                           self.__unique, self.__minArgs)

def cur(f, minArgs = None):
    return CurriedFunc(f, (), {}, True, minArgs)

def curr(f, minArgs = None):
    return CurriedFunc(f, (), {}, False, minArgs)

# Simple Function.
def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

# NOTE: '<====' means "this line prints to the screen".

# Example 1.
f = cur(func)                   # f is a "curried" version of func
c1 = f(1)
c2 = c1(2, d = 4)               # Note that c is still unbound
c3 = c2(3)(f = 6)(e = 5)        # now c = 3
c3()                            # () forces the evaluation              <====
                                #   it prints "1 2 3 4 5 6 100"
c4 = c2(30)(f = 60)(e = 50)     # now c = 30
c4()                            # () forces the evaluation              <====
                                #   it prints "1 2 30 4 50 60 100"

print("\n------\n")

# Example 2.
f = curr(func)                  # f is a "curried" version of func
                                # curr = cur with possibly repeated
                                #   keyword args
c1 = f(1, 2)(3, 4)
c2 = c1(e = 5)(f = 6)(e = 10)() # oops... we repeated 'e' because we    <====
                                #   changed our mind about it!
                                #   again, () forces the evaluation
                                #   it prints "1 2 3 4 10 6 100"

print("\n------\n")

# Example 3.
f = cur(func, 6)        # forces the evaluation after 6 arguments
c1 = f(1, 2, 3)         # num args = 3
c2 = c1(4, f = 6)       # num args = 5
c3 = c2(5)              # num args = 6 ==> evaluation                   <====
                        #   it prints "1 2 3 4 5 6 100"
c4 = c2(5, g = -1)      # num args = 7 ==> evaluation                   <====
                        #   we can specify more than 6 arguments, but
                        #   6 are enough to force the evaluation
                        #   it prints "1 2 3 4 5 6 -1"

print("\n------\n")

# Example 4.
def printTree(func, level = None):
    if level is None:
        printTree(cur(func), 0)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

print("\n------\n")

def f2(*args):
    print(", ".join(["%3d"%(x) for x in args]))

def stress(f, n):
    if n: stress(f(n), n - 1)
    else: f()               # enough is enough

stress(cur(f2), 100)

# Pipelining and Function Composition
print("\n--- Pipelining & Composition ---\n")

import sys
from urllib.request import urlopen
from re import findall
from math import sqrt, floor
from functools import reduce

map = cur(map, 2)
filter = cur(filter, 2)
urlopen = cur(urlopen)
findall = cur(findall)
my_print = cur(lambda list : print(*list))
reduce = cur(reduce, 2)

# Example 5

range(0,50) >> filter(lambda i : i%2) >> map(lambda i : i*i) >> my_print

print("---")

# Example 6

compFunc = filter(lambda i : i%2) - map(lambda i : i*i)
range(0,50) >> compFunc >> my_print

print("---")

# Example 7

# Tells whether x is not a proper multiple of n.
notPropMult = cur(lambda n, x : x <= n or x % n, 2)

def findPrimes(upTo):
    if (upTo <= 5): return [2, 3, 5]
    filterAll = (findPrimes(floor(sqrt(upTo)))
                 >> map(lambda x : filter(notPropMult(x)))
                 >> reduce(lambda f, g : f - g))
    return list(range(2, upTo + 1)) >> filterAll

findPrimes(1000) >> my_print

print("---")

# Example 8
# Finds the approximate number of hrefs in a web page.

def do(proc, arg):
    proc()
    return arg
do = cur(do)

cprint = cur(print)

("http://python.org"
 >> do(cprint("The page http://python.org has about... ", end = ''))
 >> do(sys.stdout.flush)
 >> urlopen
 >> cur(lambda x : x.read())
 >> findall(b"href=\"")
 >> cur(len)
 >> cur("{} hrefs.".format)
 >> cprint)

That’s all!

Please feel free to let me know what you think by leaving a comment!

Posted in Python | 1 Comment »

Currying in Python

Posted by mtomassoli on March 18, 2012

What is Currying?

Currying is a kind of incremental binding of function arguments. Let’s define a simple function which takes 5 arguments:

def f(a, b, c, d, e):
    print(a, b, c, d, e)

In a language where currying is supported, f is a function which takes one argument (a) and returns a function which takes 4 arguments. This means that f(5) is the following function:

def g(b, c, d, e):
    f(5, b, c, d, e)

We could emulate this behavior the following way:

def f(a):
    def g(b, c, d, e):
        print(a, b, c, d, e)
    return g

f1 = f(1)
f1(2,3,4,5)

Now, f(1) returns the function g(b, c, d, e) which, for all b, c, d and e, behaves exactly like f(1, b, c, d, e). Since g is a function, it should support currying as well. Then what we really want is this:

def f(a):
    def g(b):
        def h(c, d, e):
            print(a, b, c, d, e)
        return h
    return g

f(1)(2)(3,4,5)

So what is f? It’s a function which takes an argument a and returns another function g which “remembers” that argument. The same way, g is a function which takes an argument b and returns another function h which “remembers” that argument. Basically, h is the function obtained from f by binding the first two arguments to a and b, respectively. Note that the last line of the code above is equivalent to this:

f1 = f(1)
f12 = f1(2)
f12(3,4,5)

I know what you’re thinking: we aren’t done yet! Here’s the final version in all its glory:

def f(a):
    def g(b):
        def h(c):
            def i(d):
                def j(e):
                    print(a, b, c, d, e)
                return j
            return i
        return h
    return g

f(1)(2)(3)(4)(5)

That’s currying for you!

But what about f(1, 2, 3, 4, 5)? Oops… f doesn’t take 5 arguments anymore :(

Well, there are some functional languages (Haskell, for instance) where function application is indicated by juxtaposition. For instance, f(a, b, c, d, e) would be expressed like this:

f a b c d e

Does that mean that f is a function which takes 5 arguments? Nope. Like in the example above, f is a function which takes an argument and returns a function which takes an argument, and so on… Let’s see it this way: the space (‘ ‘) is an operator that passes the right operand to the left operand (a function). The space is left-associative, meaning that the expression above is grouped this way:

((((f a) b) c) d) e

If you don’t like the operator space, let’s use some other operator:

((((f<-a)<-b)<-c)<-d)<-e

The operator ‘<-‘ passes its right operand to its left operand, which is a function. Indeed, these are functions:

  • f
  • f<-a
  • (f<-a)<-b
  • ((f<-a)<-b)<-c
  • (((f<-a)<-b)<-c)<-d

We can remove the brackets, of course:

f<-a<-b<-c<-d<-e

What am I trying to say? Simply that classic currying is better left to languages which use juxtaposition to pass arguments to functions.

Why currying?

Simple enough: it’s an easy way to get specialized functions from more general functions. Here are a few simple examples:

  • times3 = myMul(3)
    • Usage:
      listTimes3 = map(times3, list)          # evaluates myMul(3, x) for all x in list
  • printErrMsg = printMsg("Error")
    printWrnMsg = printMsg("Warning")
    • Usage:
      printErrMsg("Divide by Zero!")
      printWrnMsg("You should save your document.")
  • lexicSort = mySort(lexicCmp)
    • Usage:
      lexicSort(listOfStrings)
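In plain Python, specializations like these can already be obtained with functools.partial; a minimal sketch (myMul, printMsg and the derived names are hypothetical stand-ins for the functions above):

```python
from functools import partial

def myMul(n, x):
    return n * x

def printMsg(kind, msg):
    print("%s: %s" % (kind, msg))

times3 = partial(myMul, 3)               # specialized: multiplies by 3
printErrMsg = partial(printMsg, "Error")
printWrnMsg = partial(printMsg, "Warning")

listTimes3 = list(map(times3, [1, 2, 3]))
print(listTimes3)                        # [3, 6, 9]
printErrMsg("Divide by Zero!")           # Error: Divide by Zero!
printWrnMsg("You should save your document.")
```

The difference from real currying is that partial binds a fixed batch of arguments once, as we’ll see below.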

Currying VS partial application

Partial application is partial binding:

from functools import partial

def f(a, b, c, d):
    print(a, b, c, d)

g = partial(f, 1, 2, 3)
g(4)

Looks familiar?

Is that currying? Not quite. It’s more like manual currying:

from functools import partial

def f(a, b, c, d):
    print(a, b, c, d)

g = partial(partial(partial(f, 1), 2), 3)
g(4)

Is currying right for us?

We said that, in Haskell, we can write

f a b c d e

while, in Python, we would have to write

f(1)(2)(3)(4)(5)

I don’t know about you, but I prefer this one:

f(1, 2, 3, 4, 5)

We should also keep in mind that Python has keyword arguments, so this should also be allowed:

f(1, 2, e = 5)

Is that real currying? It depends on whom you ask. But if you like it and find it useful, who cares?

Incremental binding

So let’s call it incremental binding because we can bind arguments in any order we like, and bind as many arguments as we feel like at once. Let’s pretend we have implemented a magical function called cur. Here’s a simple function we can use to test cur:

# Simple Function.
def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

Now let’s create a curried version of that function:

# f is the "curried" version of func.
f = cur(func)

Now we would like to be able to write the following (and get the expected result, of course!):

c1 = f(1)
c2 = c1(2, d = 4)               # Note that c is still unbound
c3 = c2(3)(f = 6)(e = 5)        # now c = 3
c3()                            # () forces the evaluation              <====
                                #   it prints "1 2 3 4 5 6 100"
c4 = c2(30)(f = 60)(e = 50)     # now c = 30
c4()                            # () forces the evaluation              <====
                                #   it prints "1 2 30 4 50 60 100"

The lines that print to the screen are marked with the double arrow ‘<====’.

The first line binds the argument a to 1 and saves this less-free (i.e. more-bound) function in c1. Now we have two currying-able functions: f and c1.

Here comes the interesting part: the second line binds b to 2 and d to 4, while c remains unbound. Basically, at each step we can bind some positional arguments and some keyword arguments. If we have bound 2 arguments so far, then the 3rd argument will be the next positional argument to be bound. Indeed, line 3 binds c to 3 and then binds the two keyword arguments f and e (the order doesn’t matter, as you can see). Note that I talk of positional arguments and keyword arguments, but here positional and keyword don’t describe the arguments but the way the arguments are bound.

Note how line 4 forces the evaluation.

But… shouldn’t the evaluation be automatic? I would say no and here’s why.

As you may know, a function is called Referentially Transparent if its behavior depends on its arguments alone. This means that a function without arguments is a constant: it behaves the same way every time. Since functions don’t have side effects in pure functional languages, “to behave” is not the most appropriate verb: I should have said that a constant function always returns the same value. Now let’s look at c3 in the previous example. Is that a function? In a pure functional language the answer would be no: it’s a constant. But that would mean we can’t evaluate it more than once. As you can see, func has side effects: it prints to the screen! Why shouldn’t we be able to call c3 more than once? So, manual evaluation is really a feature! Learning from other languages is a good thing, but trying to emulate the exact same behavior at all costs is a mistake.
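The point can be made concrete with a toy sketch (this c3 is a hypothetical stand-in with a visible side effect, not the real curried c3):

```python
calls = []

def c3():
    # A zero-argument function with a side effect is not a constant:
    # each evaluation changes observable state.
    calls.append("1 2 3 4 5 6 100")

c3()  # first evaluation
c3()  # second evaluation: same call, new side effect
print(len(calls))  # 2
```

Being able to trigger each evaluation explicitly, with ‘()’, is exactly what lets us call it twice.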

The last two lines in the example above show another binding starting from c2.

Here’s our second example:

f = curr(func)                  # f is a "curried" version of func
                                # curr = cur with possibly repeated
                                #   keyword args
c1 = f(1, 2)(3, 4)
c2 = c1(e = 5)(f = 6)(e = 10)() # oops... we repeated 'e' because we    <====
                                #   changed our mind about it!
                                #   again, () forces the evaluation
                                #   it prints "1 2 3 4 10 6 100"

What is curr? It’s a variant of cur which lets you rebind arguments by keyword. In the example above, we first bind e to 5 and then rebind it to 10. Note that, even if that’s a one-line statement, we’re binding in 3 different steps. Finally, we force the evaluation with ‘()’.

But I can feel some of you still want automatic evaluation (Why???). Ok, me too. So how about this:

f = cur(func, 6)        # forces the evaluation after 6 arguments
c1 = f(1, 2, 3)         # num args = 3
c2 = c1(4, f = 6)       # num args = 5
c3 = c2(5)              # num args = 6 ==> evaluation                   <====
                        #   it prints "1 2 3 4 5 6 100"
c4 = c2(5, g = -1)      # num args = 7 ==> evaluation                   <====
                        #   we can specify more than 6 arguments, but
                        #   6 are enough to force the evaluation
                        #   it prints "1 2 3 4 5 6 -1"

You see how nicely default-valued arguments are handled? In the example above, 6 arguments are enough to force the evaluation, but we can also define all 7 of them (if we’re quick enough!). If we feel like it, we could even use some reflection to find out the number of non default-valued arguments (or at least I think so: I’ve been programming in Python for just 3 days and I haven’t looked at reflection yet). I don’t feel like it, but it’s simple to write a little wrapper around cur and curr.
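Such reflection is indeed possible; a hedged sketch using the standard inspect module (note that inspect.signature needs Python 3.3+, and count_nondefault_args is a hypothetical helper whose result one could pass to cur as the argument threshold):

```python
import inspect

def count_nondefault_args(f):
    # Counts positional parameters that have no default value.
    params = inspect.signature(f).parameters.values()
    return sum(1 for p in params
               if p.default is inspect.Parameter.empty
               and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD))

def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

n = count_nondefault_args(func)
print(n)  # 6
```

A wrapper could then simply do cur(func, count_nondefault_args(func)).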

Are we sure that our (still imaginary) cur works as expected? Let’s try something more convoluted:

def printTree(func, level = -1):
    if level == -1:
        printTree(cur(func), level + 1)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

What does it do? It prints the first 64 non-negative integers in base two. Here’s the exact output:

0 0 0 0 0 0
0 0 0 0 0 1
0 0 0 0 1 0
0 0 0 0 1 1
0 0 0 1 0 0
0 0 0 1 0 1
0 0 0 1 1 0
0 0 0 1 1 1
0 0 1 0 0 0
0 0 1 0 0 1
0 0 1 0 1 0
0 0 1 0 1 1
0 0 1 1 0 0
0 0 1 1 0 1
0 0 1 1 1 0
0 0 1 1 1 1
0 1 0 0 0 0
0 1 0 0 0 1
0 1 0 0 1 0
0 1 0 0 1 1
0 1 0 1 0 0
0 1 0 1 0 1
0 1 0 1 1 0
0 1 0 1 1 1
0 1 1 0 0 0
0 1 1 0 0 1
0 1 1 0 1 0
0 1 1 0 1 1
0 1 1 1 0 0
0 1 1 1 0 1
0 1 1 1 1 0
0 1 1 1 1 1
1 0 0 0 0 0
1 0 0 0 0 1
1 0 0 0 1 0
1 0 0 0 1 1
1 0 0 1 0 0
1 0 0 1 0 1
1 0 0 1 1 0
1 0 0 1 1 1
1 0 1 0 0 0
1 0 1 0 0 1
1 0 1 0 1 0
1 0 1 0 1 1
1 0 1 1 0 0
1 0 1 1 0 1
1 0 1 1 1 0
1 0 1 1 1 1
1 1 0 0 0 0
1 1 0 0 0 1
1 1 0 0 1 0
1 1 0 0 1 1
1 1 0 1 0 0
1 1 0 1 0 1
1 1 0 1 1 0
1 1 0 1 1 1
1 1 1 0 0 0
1 1 1 0 0 1
1 1 1 0 1 0
1 1 1 0 1 1
1 1 1 1 0 0
1 1 1 1 0 1
1 1 1 1 1 0
1 1 1 1 1 1

Was it really necessary to cut and paste all that? Probably not… I could’ve written it by hand.

Here’s the code again (that’s what long unneeded listings do to your articles):

def printTree(func, level = -1):
    if level == -1:
        printTree(cur(func), level + 1)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

Our printTree is a function which takes a function func and an optional level which defaults to -1. There are 3 cases:

  • level is -1:
    • we didn’t call ourselves, so someone has just called us and passed to us (we hope so) a “normal” function;
    • we call ourselves recursively, but this time the way we like: with a curried version of func and with level 0. That’s better.
  • level is 6:
    • we called ourselves recursively many times and we’re finally at level 6;
    • our (very own private) func should have the first 6 arguments already bound;
    • we bind the last argument g to something that won’t show up on the screen: the empty string (if you’re wondering why I disregard the last argument this way, you should try to include a monotonous 64-line output in one of your articles!);
    • we evaluate func and print a number on the screen!
  • level is between 0 and 5 (limits included):
    • our func is a function which has the first level arguments bound and the remaining arguments unbound;
    • we can’t evaluate func: we need more arguments;
    • we are in charge of the argument x in position level;
    • we bind x to 0 and let another instance of printTree handle the remaining arguments;
    • we then bind x to 1 and call yet another instance of printTree.

If you hate recursion, you can very well skip this example and go to the last one which is… oops… recursive as well:

def f2(*args):
    print(", ".join(["%3d"%(x) for x in args]))

def stress(f, n):
    if n: stress(f(n), n - 1)
    else: f()               # enough is enough

stress(cur(f2), 100)

It prints this:

100,  99,  98,  97,  96,  95,  94,  93,  92,  91,  90,  89,  88,  87,  86,  85,
 84,  83,  82,  81,  80,  79,  78,  77,  76,  75,  74,  73,  72,  71,  70,  69,
 68,  67,  66,  65,  64,  63,  62,  61,  60,  59,  58,  57,  56,  55,  54,  53,
 52,  51,  50,  49,  48,  47,  46,  45,  44,  43,  42,  41,  40,  39,  38,  37,
 36,  35,  34,  33,  32,  31,  30,  29,  28,  27,  26,  25,  24,  23,  22,  21,
 20,  19,  18,  17,  16,  15,  14,  13,  12,  11,  10,   9,   8,   7,   6,   5,
  4,   3,   2,   1

Why did I use recursion? Because I was so caught up in implementing cur that I forgot Python has loop constructs :)
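For the record, the same stress test can be written with a loop; a sketch (cur here is the minimal positional-only version developed in the next section, and stress_loop with its collecting lambda are hypothetical):

```python
def cur(func):
    # Minimal positional-only cur, as developed below.
    def g(*myArgs):
        def f(*args):
            if args:
                return g(*(myArgs + args))
            return func(*myArgs)
        return f
    return g

def stress_loop(f, n):
    # Iterative version of stress: bind one argument per iteration,
    # then force the evaluation with ().
    c = f
    for i in range(n, 0, -1):
        c = c(i)
    return c()

collected = []
stress_loop(cur(lambda *a: collected.extend(a)), 5)
print(collected)  # [5, 4, 3, 2, 1]
```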

Beginners shouldn’t have too much of a problem with this example and I’m getting a little bored myself, so let’s proceed!

Wonderful, but how do we implement cur???

Let’s proceed a step at a time.

We might implement a function which remembers the arguments it has received so far:

myArgs = []
def f(*args):
    global myArgs           # we will refer to the global var
    if len(args):           # some more args!
        myArgs += args
    else:                   # time to evaluate...
        print(*myArgs)

f(1,2)
f(3,4)
f()

The last line prints “1 2 3 4”.

Our function asks to receive an arbitrary number of arguments packed in the local tuple args and then there are two cases:

  1. if len(args) is non-zero, we have received some arguments from the caller so we append them to our global list myArgs;
  2. if len(args) is zero, the caller wants to force the evaluation so we act as if we had been called with all the arguments in myArgs from the start.

Any objections? Like

  1. Why aren’t we writing g = f(1,2) but just f(1,2)?
  2. What about f(1,2)(3,4)()?
  3. What if we want two different bindings of f at the same time?

Objection 2 is easy to deal with:

myArgs = []
def f(*args):
    global myArgs           # we will refer to the global var
    if len(args):           # some more args!
        myArgs += args
        return f
    else:                   # time to evaluate...
        print(*myArgs)

f(1, 2)(3, 4)()

We just needed to add return f so now f(1, 2) returns f and so on…

What about objection 3? That is strictly related to the first objection: we don’t need assignments because we are allowing a single global binding per function, and that defeats the purpose of having currying. The solution is easy: let’s give each partially bound f its private myArgs.

We can do that by wrapping myArgs and f inside another function so that myArgs is not shared anymore:

def g(*args):
    myArgs = []
    def f(*args):
        nonlocal myArgs         # now it's non-local: it isn't global
        if len(args):           # some more args!
            myArgs += args
            return f
        else:                   # time to evaluate...
            print(*myArgs)
    return f(*args)

g1 = g(1, 2)
g2 = g(10, 20)
g1(3, 4)()
g2(30, 40)()

The last two lines print “1 2 3 4” and “10 20 30 40”, respectively. We did it!

Let’s try something a little different:

a = g(1)
b = g(2)
a1 = a(1, 1)()
a2 = a(2, 2)()

It prints “1 1 1” and… oops… “1 1 1 2 2”. That’s right: functions a and b behave exactly like our old f with its global myArgs. It’s true that myArgs isn’t global anymore, but while a and b receive two different myArgs, every successive use of a will use that same list. So the last line adds 2, 2 to a’s myArgs, which already contains 1, 1, 1. The second line is just there for symmetry :)

How do we solve this problem? Hmm… more recursion!!! … Nope… that wouldn’t work. How about less recursion?

Let’s look at the code again: “a = g(1)” returns an f-type (so to speak) function, so the next “a(1, 1)” doesn’t work correctly (we’ve just postponed the problem we were facing from the start with f). What if a itself returned “g(1, 1, 1)”? It’d be as if the caller had performed a single binding by calling g. History would be forgotten and totally irrelevant: every binding would look like the first and only binding!

Before proceeding, we should get rid of another flaw. Look at this example:

a = g(1)
b = g(2)
a(1, 1)
a(2, 2)()

This prints “1 1 1 2 2”. Our old f, here referred to by a, still likes to remember things it shouldn’t. We can solve both problems at the same time:

def g(*args):
    myArgs = args
    def f(*args):
        nonlocal myArgs
        if len(args):           # some more args!
            return g(*(myArgs + args))
        else:                   # time to evaluate...
            print(*myArgs)
    return f

a = g(1)
b = g(2)
a(1, 1)
a(2, 2)()

Now, g gives each f its arguments and no f can change them in any way. When f receives more arguments (args), it creates and returns a new version of f which includes all the arguments seen so far. As you can see, the new f is created by calling g one more time inside f.

We can do even better:

def g(*myArgs):
    def f(*args):
        if len(args):           # some more args!
            return g(*(myArgs + args))
        else:                   # time to evaluate...
            print(*myArgs)
    return f

The packed arguments were already a local variable of g (in a sense), so why copy them into another local variable myArgs? Also, we don’t need nonlocal because we won’t modify myArgs: we’ll just read it.

We still miss a step: I hope you won’t code any single function by hand like that! We should implement cur as a curried version of a general function. That’s easy enough:

def cur(func):
    def g(*myArgs):
        def f(*args):
            if len(args):           # some more args!
                return g(*(myArgs + args))
            else:                   # time to evaluate...
                func(*myArgs)
        return f
    return g

def f(a, b, c, d, e):
    print(a, b, c, d, e)
cf = cur(f)

a = cf(1)
b = cf(2)
a(1, 1)
a(2, 2)(3, 4)()

Now we can include keyword arguments:

def cur(func):
    def g(*myArgs, **myKwArgs):
        def f(*args, **kwArgs):
            if len(args) or len(kwArgs):    # some more args!
                newArgs = myArgs + args
                newKwArgs = dict.copy(myKwArgs)
                newKwArgs.update(kwArgs)
                return g(*newArgs, **newKwArgs)
            else:                           # time to evaluate...
                func(*myArgs, **myKwArgs)
        return f
    return g

def f(a, b, c, d, e):
    print(a, b, c, d, e)
cf = cur(f)

a = cf(1, e = 10)
b = cf(2)
a(1, 1)
a(2, d = 9)(3)()

Final Version

Now you should be able to understand the final version, which has some additional features. Here’s the entire code (examples included):

# Coded by Massimiliano Tomassoli, 2012.
#
# - Thanks to b49P23TIvg for suggesting that I should use a set operation
#     instead of repeated membership tests.
# - Thanks to Ian Kelly for pointing out that
#     - "minArgs = None" is better than "minArgs = -1",
#     - "if args" is better than "if len(args)", and
#     - I should use "isdisjoint".
#
def genCur(func, unique = True, minArgs = None):
    """ Generates a 'curried' version of a function. """
    def g(*myArgs, **myKwArgs):
        def f(*args, **kwArgs):
            if args or kwArgs:                  # some more args!
                # Allocates data to assign to the next 'f'.
                newArgs = myArgs + args
                newKwArgs = dict.copy(myKwArgs)

                # If unique is True, we don't want repeated keyword arguments.
                if unique and not kwArgs.keys().isdisjoint(newKwArgs):
                    raise ValueError("Repeated kw arg while unique = True")

                # Adds/updates keyword arguments.
                newKwArgs.update(kwArgs)

                # Checks whether it's time to evaluate func.
                if minArgs is not None and minArgs <= len(newArgs) + len(newKwArgs):
                    return func(*newArgs, **newKwArgs)  # time to evaluate func
                else:
                    return g(*newArgs, **newKwArgs)     # returns a new 'f'
            else:                               # the evaluation was forced
                return func(*myArgs, **myKwArgs)
        return f
    return g

def cur(f, minArgs = None):
    return genCur(f, True, minArgs)

def curr(f, minArgs = None):
    return genCur(f, False, minArgs)

# Simple Function.
def func(a, b, c, d, e, f, g = 100):
    print(a, b, c, d, e, f, g)

# NOTE: '<====' means "this line prints to the screen".

# Example 1.
f = cur(func)                   # f is a "curried" version of func
c1 = f(1)
c2 = c1(2, d = 4)               # Note that c is still unbound
c3 = c2(3)(f = 6)(e = 5)        # now c = 3
c3()                            # () forces the evaluation              <====
                                #   it prints "1 2 3 4 5 6 100"
c4 = c2(30)(f = 60)(e = 50)     # now c = 30
c4()                            # () forces the evaluation              <====
                                #   it prints "1 2 30 4 50 60 100"

print("\n------\n")

# Example 2.
f = curr(func)                  # f is a "curried" version of func
                                # curr = cur with possibly repeated
                                #   keyword args
c1 = f(1, 2)(3, 4)
c2 = c1(e = 5)(f = 6)(e = 10)() # oops... we repeated 'e' because we    <====
                                #   changed our mind about it!
                                #   again, () forces the evaluation
                                #   it prints "1 2 3 4 10 6 100"

print("\n------\n")

# Example 3.
f = cur(func, 6)        # forces the evaluation after 6 arguments
c1 = f(1, 2, 3)         # num args = 3
c2 = c1(4, f = 6)       # num args = 5
c3 = c2(5)              # num args = 6 ==> evaluation                   <====
                        #   it prints "1 2 3 4 5 6 100"
c4 = c2(5, g = -1)      # num args = 7 ==> evaluation                   <====
                        #   we can specify more than 6 arguments, but
                        #   6 are enough to force the evaluation
                        #   it prints "1 2 3 4 5 6 -1"

print("\n------\n")

# Example 4.
def printTree(func, level = None):
    if level is None:
        printTree(cur(func), 0)
    elif level == 6:
        func(g = '')()      # or just func('')()
    else:
        printTree(func(0), level + 1)
        printTree(func(1), level + 1)

printTree(func)

print("\n------\n")

def f2(*args):
    print(", ".join(["%3d"%(x) for x in args]))

def stress(f, n):
    if n: stress(f(n), n - 1)
    else: f()               # enough is enough

stress(cur(f2), 100)

That’s all!

Please feel free to let me know what you think by leaving a comment!

Posted in Python | 5 Comments »

Boot your OS from CD/DVD

Posted by mtomassoli on February 27, 2009

 

Warning

This article is written in Bad English (BE), the most widespread language on the Web.

 

Source code

You won’t find any source code here. I don’t believe that providing source code is always a good thing. My reasons are simple:

  • if you really understand something you’ll be able to implement it by yourself in your language of choice;
  • usually, you just end up copying the code because you are too lazy to reread the explanation provided or too lazy to ask the author for more information;
  • I don’t like learning from code because it contains way too many irrelevant details.

That’s right. Sometimes having too many details is worse than missing some of them. The problem is that whoever writes the code needs to know much more than what can be seen from the code itself. On the other hand, the code is the result of many totally arbitrary choices, usually not explicitly marked as such. Over-commented code would probably do, but I don’t like over-commented code.

Bottom line: no source code here.

 

Booting from CD

I guess you already know how to boot from a hard disk or a floppy disk. Booting from a CD is just a bit more complex, but nothing to worry about.

Sooner or later, you’ll have to read this paper:
specscdrom.pdf.
What is El Torito? You’ll find some information here:
http://en.wikipedia.org/wiki/El_Torito_(CD-ROM_standard).

I am not an expert in these things. I just wanted to make my OS bootable from CD.

 

ISO 9660 & .iso

ISO 9660 is a standard that defines a file system for CD-ROM media. CDs contain less information than hard disks and can’t be modified as easily, so ISO 9660 is by far simpler than, say, NTFS.
User-writable CD-Rs and CD-RWs use the UDF format, which is more complex than ISO 9660, though.

While ISO 9660 is (the specification of) a file system, an .iso file is just a sector-by-sector copy of a CD or DVD. It has NOTHING to do with ISO 9660!

An .iso image contains the so-called cooked 2048-byte sectors of a CD or DVD. They’re cooked and not raw because control data is missing.
Have a look here:
http://en.wikipedia.org/wiki/ISO_9660#CD-ROM_Specifications.
As you can see, in cooked 2048-byte sectors synchronization information and error correction and detection codes are missing. They’re normally automatically created for you, so we don’t need to worry about all that.
We just need to write user-data as a sequence of 2048-byte sectors. If we want to create a 30-sector CD/DVD we’ll need to create an .iso file of 30*2048 bytes. That’s all.
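Creating such an image needs no special tools at all; a hedged sketch (make_blank_iso is a hypothetical helper, not part of any library):

```python
import os
import tempfile

SECTOR_SIZE = 2048  # cooked sector size in an .iso image

def make_blank_iso(path, num_sectors):
    # An .iso is just a flat sequence of 2048-byte sectors:
    # a 30-sector image is exactly 30 * 2048 zero bytes.
    with open(path, "wb") as f:
        f.write(b"\x00" * (SECTOR_SIZE * num_sectors))

# Demo: a 30-sector image, written to a temporary file.
path = os.path.join(tempfile.mkdtemp(), "tiny.iso")
make_blank_iso(path, 30)
size = os.path.getsize(path)
print(size)  # 61440 == 30 * 2048
```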

 

El Torito & ISO 9660

What does it mean that El Torito “format” is an extension to the ISO 9660 format?
It simply means that it’s compliant with ISO 9660, i.e. a CD may be bootable and still follow the ISO 9660 spec., which is a good thing, of course.

But that also means that we don’t have to follow the ISO 9660 format at all! As long as all the structures that the El Torito format requires are present in the .iso, we can add more data by simply writing it in arbitrarily chosen sectors.

For instance, we could write the content of three modules of our OS in the following groups of sectors: 30…39, 40…63 and 64…100. Then our OS could read these three files directly because it would know where they are located in advance.
If this weren’t satisfying, we could even devise a simple file system.
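Placing a module at an arbitrarily chosen sector is then just a seek and a write; a hedged sketch (write_at_sector is a hypothetical helper, shown here on an in-memory image):

```python
import io

SECTOR_SIZE = 2048

def write_at_sector(iso_file, sector, data):
    # Writes data starting at the given 2048-byte sector boundary; the OS
    # can read it back later because the sector numbers are agreed upon
    # in advance.
    iso_file.seek(sector * SECTOR_SIZE)
    iso_file.write(data)

# A 101-sector image with a (hypothetical) module starting at sector 30:
img = io.BytesIO(b"\x00" * (101 * SECTOR_SIZE))
write_at_sector(img, 30, b"module 1")
```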

 

Emulation

El Torito spec. talks about floppy disk images, hard disk images and emulation. Why? When you boot from a floppy disk, the BIOS loads your boot code found in sector 0 (zero) at the physical address 7C00h.
Now your code can access the content of the floppy disk by calling the 13h BIOS services. The boot from hard disk is analogous.
When the BIOS jumps to your code, DL contains the ID of the boot device. Floppy disk drives start from 0 and hard disks from 80h. CD-ROM drives should start from A0h.

What would happen if an old program were booted from a CD-ROM?
First of all, it could complain about DL being A0h and crash. Secondly, how would it be supposed to read from a CD? The old int 13h doesn’t work.
So BIOSes which implement the El Torito extension must virtualize the access to the CD and pretend it’s a normal floppy disk or hard disk! Such BIOSes shall set DL to 0 or 80h (or similar) and transparently extend the old int 13h interface in such a way that the old program thinks it’s reading from the media it was supposed to be booted from.

The important thing to understand is that the compatibility problem that is being solved is software related, not hardware or firmware related. If your BIOS doesn’t know anything about the El Torito extension then there’s nothing you can do about it.

 

Reinventing the Wheel

Many tutorials suggest that you should create a bootable floppy disk and then use some software to convert that to a bootable .iso.
Well, I think that’s not the right way to proceed. Sometimes, “easy” means “the only way I know how to make it work”. That doesn’t mean it’s actually easy; it only means you didn’t find information about other methods.
Reinventing the wheel is bad, you will often be told. I think reinventing the wheel is a good thing as long as you can tell good and bad wheels apart.
If you create something horrible, you have to be aware of that.

Moreover, you really shouldn’t tell an OS developer that reinventing the wheel is bad. He’s developing an OS!!!

Reinventing the wheel is also extremely instructive: you can’t build by yourself what you don’t understand. Secondly, if you do something yourself you’re free to customize it, and you’ll be more independent, especially when something doesn’t work and you have to figure out what’s wrong with it.

 

Native Mode & int 13h extensions

With my project, I went directly for booting in Native Mode. When you boot from a CD/DVD in Native Mode, you can make the BIOS load up to 32 MB of code for you. The best thing to do, however, is to use the BIOS only initially. Afterwards you should write your own code.
AFAIK, int 13h can only be called in real mode, so you’ll have to run a v8086 task from protected mode or set up some kind of virtualization yourself (I’m assuming your OS will run in protected mode).
Since I’m targeting recent platforms, I’ll assume that every BIOS supports the int 13h extensions (ah > 40h). See http://www.t10.org/t13/docs2004/d1572r3-EDD3.pdf.

From real mode, you can read 2048-byte sectors through the int 13h services 41h-48h. The good news is that you’ll be using LBA addressing mode, i.e. the sectors will be linearly numbered starting from 0. Keep in mind that we’re dealing with two types of sectors:

  • 512-byte sectors and
  • 2048-byte sectors

When you tell the BIOS how many sectors it should load from the CD, you specify the number of 512-byte sectors. Since this number, as we shall see, is a WORD, you can ask for

  512*65536 bytes = 2^(9+16) bytes = 2^5*2^20 bytes = 32MB

at most. On the other hand, when you use the functions 41h-48h in your code, you’ll be referring to 2048-byte sectors.
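To make the interface concrete, here is what the Disk Address Packet used by the extended read service (int 13h, ah = 42h) looks like when written as a C struct. This is just a reference sketch: the typedefs and the name DiskAddressPacket are mine, and real boot code would of course build this packet in assembly and point DS:SI at it before issuing the interrupt.

```c
#include <stdint.h>

typedef uint8_t  BYTE;
typedef uint16_t WORD;
typedef uint64_t QWORD;

/* Disk Address Packet for int 13h, ah = 42h (extended read),
   as described in the EDD specification. */
#pragma pack(push, 1)
typedef struct {
    BYTE  packetSize;   /* size of this packet: 16 */
    BYTE  reserved;     /* must be 0 */
    WORD  numSectors;   /* sectors to transfer (2048-byte sectors on CD) */
    WORD  bufOffset;    /* real-mode offset of the transfer buffer */
    WORD  bufSegment;   /* real-mode segment of the transfer buffer */
    QWORD startLBA;     /* first sector to read, LBA, counted from 0 */
} DiskAddressPacket;
#pragma pack(pop)
```

Note that the buffer address is stored offset first, segment second, as is usual for real-mode far pointers in memory.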

 

Let’s get started

First of all, we don’t need any specific software. Sometimes people use very powerful software to do very simple things and the result isn’t even completely satisfying. Well, here we’ll create the .iso all by ourselves. What we need is just a way to tell the BIOS to

  1. read N 512-byte sectors
  2. starting from the 2048-byte sector S
  3. and copy them to the physical address P in RAM.

But the El Torito spec. aims at being ISO 9660 compliant (remember?), so we’ll have to do more than that. We’ll need to write the following structures:

  1. Boot Record Volume Descriptor (BVD)
  2. Boot Catalog (BC)

where the BC consists of

  1. Validation Entry (VE)
  2. Initial/Default Entry
  3. Section Header
  4. Section Entry

There may be many section headers and many sections. You’ll have to refer to the documentation for the details. I’ll just guide you through the creation of a minimalistic .iso. The idea is this:

  1. we create a BVD at sector 17 (i.e. the 18th sector)
  2. we create a BC at sector 18
  3. we write our code (boot image) to sector 19-20-…

The BVD shall point to the BC which shall point to our boot image.

 

BVD

A BVD has the following format:

struct BootVolDesc
{
    BYTE bootRecInd;
    BYTE specId[5];
    BYTE descVer;
    BYTE specStr[32];
    BYTE reserved[32];
    DWORD bootCatSec;       // absolute sector number when
                            // the boot catalog starts
    BYTE reserved2[1973];
};

The documentation is clear:

  1. bootRecInd must be set to 0
  2. specId to "CD001"
  3. descVer to 1
  4. specStr to "EL TORITO SPECIFICATION"
  5. reserved must be filled with 0
  6. reserved2 must be filled with 0

And, finally, we’ll set bootCatSec to 18 because that is where we’ll put our BC. Note that the BVD fills the entire 2048-byte sector 17, so no padding is needed.
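Putting the six rules together, this is a minimal sketch of how the BVD could be filled in C. The helper name initBVD and the typedefs are mine; the trick is simply to zero the whole structure first so both reserved areas are covered, then set the few non-zero fields.

```c
#include <stdint.h>
#include <string.h>

typedef uint8_t  BYTE;
typedef uint32_t DWORD;

#pragma pack(push, 1)
struct BootVolDesc
{
    BYTE  bootRecInd;
    BYTE  specId[5];
    BYTE  descVer;
    BYTE  specStr[32];
    BYTE  reserved[32];
    DWORD bootCatSec;       /* absolute sector where the BC starts */
    BYTE  reserved2[1973];
};
#pragma pack(pop)

/* Hypothetical helper: fill the BVD exactly as listed above. */
static void initBVD(struct BootVolDesc *bvd)
{
    memset(bvd, 0, sizeof *bvd);        /* zeroes both reserved areas */
    bvd->bootRecInd = 0;
    memcpy(bvd->specId, "CD001", 5);    /* 5 bytes, no NUL terminator */
    bvd->descVer = 1;
    memcpy(bvd->specStr, "EL TORITO SPECIFICATION",
           sizeof "EL TORITO SPECIFICATION" - 1);  /* rest stays 0 */
    bvd->bootCatSec = 18;               /* the BC goes in sector 18 */
}
```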

 

Validation Entry

The BC starts with the Validation Entry:

struct ValidationEntry
{
    BYTE headerId;
    BYTE platformId;    // 0 = 80x86, 1 = Power PC, 2 = Mac
    WORD reserved;
    BYTE devName[24];   // developer or manufacturer of the ISO
    WORD checkSum;      // such that the sum of all the words
                        // gives 0x0000
    WORD magicWord;
};

We do as the documentation says:

  1. we set headerId to 1
  2. platformId to 0 (it shouldn’t be too hard to figure out why)
  3. reserved to 0
  4. magicWord to 0xAA55

Please note that, in little-endian, bytes are written to memory from the least significant to the most significant one, therefore the word 0xAA55 is written byte by byte as

  0x55 0xAA.

We don’t care about the order in which the single bits are read or written because that detail is not architectural (it’s not exposed by the operations).
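You can see the byte order for yourself with a couple of lines of C. This assumes a little-endian host (like the 80x86 machines a BIOS runs on); the helper name magicBytes is mine.

```c
#include <stdint.h>
#include <string.h>

/* On a little-endian machine, the WORD 0xAA55 is laid out in
   memory as 0x55 followed by 0xAA. */
static void magicBytes(uint8_t out[2])
{
    uint16_t magic = 0xAA55;
    memcpy(out, &magic, 2);   /* copy the in-memory representation */
}
```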

Finally, we write some 24-byte-long ASCII string. This is my (provisional) string, but don’t you dare use it for yourself! :-)

  "Virtual Debugger"

Now you have to choose checkSum in such a way that the sum of all the WORDs in this validation entry is zero. You can proceed as follows:

  1. you set checkSum to 0
  2. you read the entire structure as if it were an array of WORDs
  3. you compute the sum SUM of all the WORDs
  4. you set checkSum to -SUM

Remember that, in a two’s complement representation of N bits, -X is nothing more than a N-bit number such that (-X)+X = 2^N.

For instance, with WORDs, we have

  FFFF + 0001 = 10000,

where 10000 (hexadecimal) is 2^16. Since the most significant bit is lost, we are left with 0000.
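The four checksum steps above can be sketched in C like this. The function name fixChecksum is mine; note the intermediate copy into a WORD array, which sidesteps strict-aliasing issues when reinterpreting the structure.

```c
#include <stdint.h>
#include <string.h>

typedef uint8_t  BYTE;
typedef uint16_t WORD;

#pragma pack(push, 1)
struct ValidationEntry
{
    BYTE headerId;
    BYTE platformId;
    WORD reserved;
    BYTE devName[24];
    WORD checkSum;
    WORD magicWord;
};
#pragma pack(pop)

/* Steps 1-4: zero checkSum, sum the entry as an array of WORDs,
   then store the two's-complement negation so the total wraps to 0. */
static void fixChecksum(struct ValidationEntry *ve)
{
    ve->checkSum = 0;
    WORD words[sizeof *ve / 2];
    memcpy(words, ve, sizeof *ve);
    WORD sum = 0;
    for (size_t i = 0; i < sizeof *ve / 2; ++i)
        sum = (WORD)(sum + words[i]);   /* 16-bit wrap-around */
    ve->checkSum = (WORD)-sum;
}
```

After the call, re-summing all 16 WORDs of the entry gives 0x0000, exactly as the spec requires.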

 

Initial/Default Entry

Now we need to write the I/D Entry right after the VE. The format is as follows:

struct SectionEntry
{
    BYTE bootable;          // 0x88 = bootable image present,
                            // 0 = non-bootable image present
    BYTE bootMediaType;     // 0 = no emulation,
                            // 1 = 1.2MB diskette,
                            // 2 = 1.44MB diskette,
                            // 3 = 2.88MB diskette,
                            // 4 = hard drive,
                            // 5-0xff = reserved
    WORD entryCodeSeg;      // usually 0x7c0
    BYTE systemType;
    BYTE reserved;
    WORD numSecToLoad;      // number of 512-byte sectors to load
                            // (usually 1 in emulation mode)
    DWORD startingSec;      // absolute address of the first sector
                            // of the image to load

    // The following data must be 0 if this is the "Initial/Default
    // Entry".
    BYTE selCriteria;       // 0 = no selection criteria,
                            // 1 = language and revision information
                            //     (IBM format),
                            // 2-0xff = reserved
    BYTE selCriteria2[19];  // selection criteria
};

Here we go:

  1. we set bootable to 0x88 (obviously)
  2. *I* set bootMediaType to 0 (no emulation)
  3. *I* set entryCodeSeg to 0x7c0 (It’s an old friend)
  4. we set systemType to 0
  5. reserved to 0
  6. selCriteria to 0
  7. we fill selCriteria2 with 0

In my case, startingSec is 19 because I decided to put my code (boot image) in sector 19 (and the following sectors…). numSecToLoad must be set to the number of 512-byte sectors you want the BIOS to load for you. You may or may not pad the image with 0 so that its length is a multiple of 512, but I don’t think it’s required.

If your boot image is x bytes long, you should set numSecToLoad to

  ceiling(x/512).

You can compute that value as

  (x+511) div 512,

where div is the integer division.

Let k be a non-negative integer and r a positive integer less than 512.

If x = 512*k, then (x+511) div 512 = k.

If x = 512*k + r, then (x+511) div 512 = k+1.

That’s exactly what we wanted.
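In C, the ceiling computation above is a one-liner (sectorsFor is just an illustrative name):

```c
#include <stdint.h>

/* (x + 511) div 512, i.e. the ceiling of x/512 for x >= 0. */
static uint32_t sectorsFor(uint32_t imageBytes)
{
    return (imageBytes + 511) / 512;
}
```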

 

Extra Section

I think there must be at least one extra Section Entry in the .iso. The documentation is not particularly clear about it.

Here’s the structure for the Section Header:

struct SectionHeader
{
    BYTE id;               // 0x90 = other sections will follow this one,
                           // 0x91 = this is the final section
    BYTE platformId;       // 0 = 80x86, 1 = Power PC, 2 = Mac
    WORD numSecEntries;    // number of section entries
    BYTE sectionName[28];
};

Note that one version of the documentation reports 0x90 and 0x91 while the other reports 90 and 91. I’m currently using 0x90 and 0x91 and everything seems to work fine. We don’t really need another section because our boot image is already pointed to by the I/D Entry (which is itself a section). For this reason, we set:

  1. id to 0x91, to indicate that this is the LAST section.

    This is also what makes me think we need at least one section: if the only way to say that there are no other sections (besides those introduced by this header, of course) is to set id to 0x91 in the last header, and each header must be followed by at least one section entry, how could we do it without having at least one section?
  2. platformId to 0
  3. numSecEntries to 1
  4. sectionName to an arbitrary string padded with 0.

Now we have to write the actual section. Its structure is identical to that of the I/D Entry because, as I already said, the I/D Entry is itself a section. For the more or less subtle differences, please refer to the documentation. We should initialize this last section as we did before, but with a couple of exceptions:

  1. we set numSecToLoad to 1
  2. we set bootable to 0

Here I’m being a little defensive, i.e. I prefer to do something potentially superfluous rather than miss something mandatory. I was lucky: my .iso worked on the first try.

For this reason, this section has a valid entryCodeSeg and startingSec as well (they are the same as before). At this point, the current sector, i.e. sector 18, contains:

  1. Validation Entry
  2. Initial/Default Entry
  3. Section Header
  4. Section Entry

What you have to do now is pad the current (2048-byte) sector with 0.
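Since each Boot Catalog entry is 32 bytes, it’s easy to work out how much padding that is. A tiny sketch (the names are mine):

```c
/* The four BC entries above occupy 4*32 = 128 bytes; the rest of
   the 2048-byte sector 18 is zero padding. */
enum { BC_ENTRY_SIZE = 32, SECTOR_SIZE = 2048, BC_ENTRIES = 4 };

static unsigned bcPadding(void)
{
    return SECTOR_SIZE - BC_ENTRIES * BC_ENTRY_SIZE;
}
```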

 

Boot Image

Finally, we write our boot image and pad the .iso file with 0 so that its size is a multiple of 2048.

Please remember that sector 17 is the 18th sector (numbering starts from 0), so you may want to start by writing 2048*17 zero bytes to your .iso and then write all the structures we talked about.
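As a recap, here is the whole layout as offset arithmetic, assuming the sector numbers used throughout this post (the helper names are mine):

```c
#include <stdint.h>

enum {
    SECTOR_SIZE = 2048,
    BVD_SECTOR  = 17,   /* Boot Record Volume Descriptor */
    BC_SECTOR   = 18,   /* Boot Catalog */
    IMG_SECTOR  = 19    /* first sector of the boot image */
};

/* Byte offset in the .iso file of a given 2048-byte sector. */
static uint32_t sectorOffset(uint32_t sector)
{
    return sector * SECTOR_SIZE;
}

/* Final size of the .iso: the end of the boot image rounded up
   to a multiple of the sector size. */
static uint32_t isoSize(uint32_t imageBytes)
{
    uint32_t end = IMG_SECTOR * SECTOR_SIZE + imageBytes;
    return (end + SECTOR_SIZE - 1) / SECTOR_SIZE * SECTOR_SIZE;
}
```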

And please also remember that you can’t read from the CD with the old int 13h functions (i.e. int 13h, ah = 2). Have a look at the documentation.

By the way, you don’t need that

db 510 - ($ - $$) dup (0)
dw 0AA55h

or

times 510 - ($ - $$) db 0
dw 0AA55h

anymore.

That’s all. Happy coding!

Posted in Boot

 