TCO in Python via bytecode manipulation
Optimizing tail recursion in Python using bytecode manipulations.
Allison Kaptur, Paul Tagliamonte, Liuda Nikolaeva (all errors are my own)
Problem:
Python has a limit on recursion depth:
def factorial(n, accum):
    if n <= 1:
        return accum
    else:
        return factorial(n-1, accum*n)
>>> factorial(1000, 1)
RuntimeError: maximum recursion depth exceeded
Challenge:
• Optimize recursive function calls so that they don’t create new frames, thus avoiding stack overflow.
• What we want: eliminate the recursive call; instead, reset the variables and jump to the beginning of the function.
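As a rough illustration of the target behavior (written by hand here, not produced by the optimizer), the tail-recursive factorial behaves like a loop that rebinds its arguments and starts over. The name factorial_loop is made up for this sketch.

```python
def factorial(n, accum):
    # Tail-recursive original: every call adds a new frame.
    if n <= 1:
        return accum
    else:
        return factorial(n - 1, accum * n)


def factorial_loop(n, accum):
    # Hand-written equivalent of "reset the variables and jump to the beginning".
    while True:
        if n <= 1:
            return accum
        # Compute both new argument values first, then rebind them together,
        # the same way a call evaluates all its arguments before entering the function.
        n, accum = n - 1, accum * n
```

Because no new frames are created, factorial_loop(5000, 1) completes where the recursive version would exceed the default recursion limit.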
Problem:
How do you change the insides of a function?
Solution:
Bytecode! (obviously)
Quick intro to bytecode.
def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)
>>> f.__code__.co_code
'|\x00\x00d\x01\x00k\x01\x00r\x10\x00|\x01\x00St\x00\x00|\x00\x00d\x01\x00\x18|\x01\x00|\x00\x00\x14\x83\x02\x00Sd\x00\x00S'
>>> print [ord(b) for b in f.__code__.co_code]
[124, 0, 0, 100, 1, 0, 107, 1, 0, 114, 16, 0, 124, 1, 0, 83, 116, 0, 0, 124, 0, 0, 100, 1, 0, 24, 124, 1, 0, 124, 0, 0, 20, 131, 2, 0, 83, 100, 0, 0, 83]
def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)
>>> import dis
>>> dis.dis(f)
  2           0 LOAD_FAST                0 (n)
              3 LOAD_CONST               1 (1)
              6 COMPARE_OP               1 (<=)
              9 POP_JUMP_IF_FALSE       16

  3          12 LOAD_FAST                1 (accum)
             15 RETURN_VALUE

  5     >>   16 LOAD_GLOBAL              0 (f)
             19 LOAD_FAST                0 (n)
             22 LOAD_CONST               1 (1)
             25 BINARY_SUBTRACT
             26 LOAD_FAST                1 (accum)
             29 LOAD_FAST                0 (n)
             32 BINARY_MULTIPLY
             33 CALL_FUNCTION            2
             36 RETURN_VALUE
             37 LOAD_CONST               0 (None)
             40 RETURN_VALUE
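The output shown here comes from Python 2. On Python 3.4 or later, the same information can be read programmatically with dis.get_instructions. Opcode names and offsets differ between interpreter versions (newer releases no longer use CALL_FUNCTION, for example), so this sketch only prints whatever the running interpreter produces.

```python
import dis


def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n - 1, accum * n)


# Each Instruction carries the offset, the opcode name, the raw argument
# and a human-readable form of the argument.
instructions = list(dis.get_instructions(f))
for instr in instructions:
    print(instr.offset, instr.opname, instr.arg, instr.argrepr)

# Names fetched with LOAD_GLOBAL; the recursive reference to f shows up here.
global_loads = [i.argval for i in instructions if i.opname == "LOAD_GLOBAL"]
print(global_loads)
```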
Before optimization:
      0 LOAD_FAST                0 (n)
      3 LOAD_CONST               1 (1)
      6 COMPARE_OP               1 (<=)
      9 POP_JUMP_IF_FALSE       16
     12 LOAD_FAST                1 (accum)
     15 RETURN_VALUE
>>   16 LOAD_GLOBAL              0 (f)
     19 LOAD_FAST                0 (n)
     22 LOAD_CONST               1 (1)
     25 BINARY_SUBTRACT
     26 LOAD_FAST                1 (accum)
     29 LOAD_FAST                0 (n)
     32 BINARY_MULTIPLY
     33 CALL_FUNCTION            2
     36 RETURN_VALUE
After optimization:
>>    0 LOAD_FAST                0 (n)
      3 LOAD_CONST               1 (1)
      6 COMPARE_OP               1 (<=)
      9 POP_JUMP_IF_FALSE       16
     12 LOAD_FAST                1 (accum)
     15 RETURN_VALUE
>>   16 LOAD_FAST                0 (n)
     19 LOAD_CONST               1 (1)
     22 BINARY_SUBTRACT
     23 LOAD_FAST                1 (accum)
     26 LOAD_FAST                0 (n)
     29 BINARY_MULTIPLY
     30 STORE_FAST               1 (accum)
     33 STORE_FAST               0 (n)
     36 JUMP_ABSOLUTE            0
     39 RETURN_VALUE
Simplified algorithm (pseudocode).
def recursion_optimizer(f):
    new_bytecode = ''
    for byte in f.__code__.co_code:
        if instruction[byte] == 'LOAD_GLOBAL f':
            continue  # get rid of this instruction
        elif instruction[byte] == 'CALL_FUNCTION':
            # replace it with resetting variables and jumping to 0;
            # the last argument is on top of the stack, so store in reverse order
            for arg in reversed(args):
                new_bytecode.add_instr(store_new_val(arg))
            new_bytecode.add_instr(jump_to_0)
        else:  # regular byte
            new_bytecode.add(byte)
    # co_code is read-only, so real code has to build a new code object here
    f.__code__.co_code = new_bytecode
    return f
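The final assignment in the pseudocode cannot be done literally, because code objects are immutable. A minimal sketch of the rebinding step, assuming Python 3.8 or later (where code objects have a replace method), could look like this. To stay independent of the interpreter version it swaps a constant instead of the bytecode; replace also accepts co_code. The function greet is invented for the example.

```python
def greet():
    return "hello"


# Code objects are immutable: assigning to co_code fails.
try:
    greet.__code__.co_code = b""
    assignment_failed = False
except (AttributeError, TypeError):
    assignment_failed = True

# Instead, build a modified copy of the code object and rebind __code__.
old_code = greet.__code__
new_consts = tuple("goodbye" if c == "hello" else c for c in old_code.co_consts)
greet.__code__ = old_code.replace(co_consts=new_consts)

print(greet())
```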
Not only does it work, it is also slightly faster than the original function:
• Timed 10000 calls to fact(450).
Original fact: 1.7009999752
Optimized fact: 1.6970000267
• And faster than other ways of optimizing this.
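A comparable measurement could be set up with timeit along these lines. This is only a sketch: fact_loop is a hand-written loop standing in for the bytecode-optimized function, which is not reproduced here, so the numbers it prints are not comparable to the ones on the slide.

```python
import timeit


def fact(n, accum):
    # Tail-recursive version, as in the earlier slides.
    if n <= 1:
        return accum
    else:
        return fact(n - 1, accum * n)


def fact_loop(n, accum):
    # Stand-in for the optimized function (assumption: hand-written loop).
    while n > 1:
        n, accum = n - 1, accum * n
    return accum


calls = 1000  # the slide used 10000 calls; fewer here to keep the run short
t_rec = timeit.timeit(lambda: fact(450, 1), number=calls)
t_loop = timeit.timeit(lambda: fact_loop(450, 1), number=calls)
print("recursive:", t_rec)
print("loop:     ", t_loop)
```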
Here is the most interesting so far:
If our function calls another function…
def sq(x): return x*x
@tailbytes_v1
def sum_squares(n, accum):
    if n < 1: return accum
    else: return sum_squares(n-1, accum+sq(n))
• Our initial algorithm removed every function call, not only the recursive ones, so this would break.
How do you battle this?
• We need to keep track of function calls and remove only the recursive calls.
• Unfortunately, bytecode doesn’t know which function it’s calling: it just calls whatever is on the stack:
29 CALL_FUNCTION 2
So we just need to keep track of the stack…
• When we hit ‘LOAD_GLOBAL self’, we start keeping track of the stack size (stack_size = 0).
• Now, with every byte, we update the stack size.
• Once we hit stack_size = 0, it means this byte was the recursive call, so we remove it.
• This way we keep the calls to other functions (e.g., identity).
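The stack-tracking idea can be sketched on Python 3.4 or later with dis.get_instructions and dis.stack_effect. This is an illustration under assumptions, not the code from the talk: the helper name find_recursive_calls is hypothetical, the depth is counted from just before the LOAD_GLOBAL (so the matching call leaves a depth of 1, the return value, rather than 0), and only straight-line code between the load and the call is handled.

```python
import dis


def sq(x):
    return x * x


def sum_squares(n, accum):
    if n < 1:
        return accum
    else:
        return sum_squares(n - 1, accum + sq(n))


def find_recursive_calls(func):
    """Return (indices of call instructions whose callee is func, all instructions)."""
    instructions = list(dis.get_instructions(func))
    found = []
    for start, instr in enumerate(instructions):
        # Start tracking when the function loads a reference to itself.
        if instr.opname != "LOAD_GLOBAL" or instr.argval != func.__name__:
            continue
        depth = 0
        for i in range(start, len(instructions)):
            cur = instructions[i]
            # Update the stack depth with every instruction.
            depth += dis.stack_effect(cur.opcode, cur.arg)
            # The call that consumes the loaded function and all of its
            # arguments leaves only the return value on the stack.
            if cur.opname.startswith("CALL") and depth == 1:
                found.append(i)
                break
    return found, instructions


recursive_calls, all_instructions = find_recursive_calls(sum_squares)
for index in recursive_calls:
    print(all_instructions[index].offset, all_instructions[index].opname)
```

On the sum_squares example only the outer call is reported; when the call to sq executes, the arguments of the pending recursive call are still on the stack, so the depth is greater than 1 and that call is left alone.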
Road ahead:
• Make it harder to break.
• Translate “normal” (non-tail) recursion into tail-recursion (possibly with ASTs)
• Handle mutual recursion
…And some crazy ideas:
https://github.com/lohmataja/recursion
Or: http://tinyurl.com/tailbytes
Liuda Nikolaeva