Python list comprehension with lambdas [duplicate]

I'm running Python 3.4.2, and I'm confused by the behavior of my code. I'm trying to create a list of callable polynomial functions of increasing degree:

bases = [lambda x: x**i for i in range(3)]

But for some reason it does this:

print([b(5) for b in bases])
# [25, 25, 25]

Why does bases seem to be a list of the last lambda expression in the list comprehension, repeated?


Solution 1:

The problem, which is a classic "gotcha", is that the i referenced in the lambda functions is not looked up until the lambda function is called. At that time, the value of i is the last value it was bound to when the for-loop ended, i.e. 2.

If you bind i to a default value in the definition of each lambda (lambda x, i=i: ...), then i becomes a parameter, i.e. a local variable of the lambda, and its default value is evaluated and stored on the function when the lambda is defined rather than when it is called.

Thus, when the lambda is called, i is now looked up in the local scope, and its default value is used:

In [177]: bases = [lambda x, i=i: x**i for i in range(3)]

In [178]: print([b(5) for b in bases])
[1, 5, 25]
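To see the difference between the two versions, you can inspect the function objects themselves. This is a small sketch that relies only on the standard function attributes `__code__.co_freevars` and `__defaults__`; the names `late` and `early` are just illustrative.

```python
# Late binding: i is a free variable of each lambda, looked up at call time.
late = [lambda x: x**i for i in range(3)]

# Early binding: the default i=i is evaluated once, when each lambda is created.
early = [lambda x, i=i: x**i for i in range(3)]

print([f(5) for f in late])    # [25, 25, 25]
print([f(5) for f in early])   # [1, 5, 25]

# In the late version, i is a free (non-local) name of the lambda...
print(late[0].__code__.co_freevars)     # ('i',)
# ...while in the early version it is a parameter with a stored default.
print(early[0].__code__.co_freevars)    # ()
print([f.__defaults__ for f in early])  # [(0,), (1,), (2,)]
```
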

For reference:

  • Python scopes and namespaces

Solution 2:

A more Pythonic approach is to use nested functions (a function factory):

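def polyGen(degree):
    # Each call to polyGen creates a new scope, so each degPolynom
    # captures its own value of degree, which never changes afterwards.
    def degPolynom(n):
        return n**degree
    return degPolynom

polynoms = [polyGen(i) for i in range(5)]
print([pol(5) for pol in polynoms])

Output:

[1, 5, 25, 125, 625]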
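One way to confirm that each function returned by `polyGen` carries its own `degree` is to look at its closure cell. This short sketch uses the standard `__closure__` attribute and the cell's `cell_contents`, and repeats the factory so it runs on its own.

```python
def polyGen(degree):
    def degPolynom(n):
        return n**degree
    return degPolynom

polynoms = [polyGen(i) for i in range(5)]

# Every call to polyGen created a fresh scope, so each closure
# holds a separate cell whose contents stay fixed.
print([pol.__closure__[0].cell_contents for pol in polynoms])  # [0, 1, 2, 3, 4]
print(polynoms[0].__closure__[0] is polynoms[1].__closure__[0])  # False
```
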

Solution 3:

As an alternative, you could bind i with an immediately called outer lambda, which is a hand-rolled form of partial function application:

>>> bases = [(lambda i: lambda x: x**i)(i) for i in range(3)]
>>> print([b(5) for b in bases])
[1, 5, 25]

The only advantage of this construction over the classic default-argument solution given by @unutbu is that you cannot introduce sneaky bugs by calling your function with the wrong number of arguments:

>>> print([b(5, 8) for b in bases])
#             ^^^
#             oops
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 1, in <listcomp>
TypeError: <lambda>() takes 1 positional argument but 2 were given
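For contrast, here is what the same mistaken call does with the default-argument version: the extra argument silently overrides the captured i, so no error is raised and the results are simply wrong. This is a sketch; `bases_default` is just an illustrative name.

```python
bases_default = [lambda x, i=i: x**i for i in range(3)]

# Intended call: uses the captured defaults 0, 1, 2.
print([b(5) for b in bases_default])     # [1, 5, 25]

# Mistaken call: the second argument replaces the default i,
# so every function computes 5**8 and no exception is raised.
print([b(5, 8) for b in bases_default])  # [390625, 390625, 390625]
```
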

As suggested by Adam Smith in a comment below, instead of a nested lambda you could use functools.partial, with the same benefit:

>>> import functools
>>> bases = [functools.partial(lambda i, x: x**i, i) for i in range(3)]
>>> print([b(5) for b in bases])
[1, 5, 25]
>>> print([b(5, 8) for b in bases])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 1, in <listcomp>
TypeError: <lambda>() takes 2 positional arguments but 3 were given
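The same idea works with a named function instead of a lambda, so error messages name the function rather than `<lambda>`. A minimal sketch; `power` is a hypothetical helper name. `functools.partial` objects expose the wrapped callable and the frozen positional arguments as `.func` and `.args`.

```python
import functools

def power(i, x):
    # i is frozen by partial; x is supplied at call time.
    return x**i

bases = [functools.partial(power, i) for i in range(3)]
print([b(5) for b in bases])    # [1, 5, 25]

# Each partial object records the function and the arguments it froze.
print(bases[2].func is power)   # True
print([b.args for b in bases])  # [(0,), (1,), (2,)]
```
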

Solution 4:

I don't think the "why this happens" aspect of the question has been answered yet.

The reason that non-local names in a function are not treated as constants is so that they behave the same way as global names. That is, changes made to a global name after a function is created are observed when the function is called.

For example:

# global context
n = 1
def f():
    return n
n = 2
assert f() == 2

# non-local context
def f():
    n = 1
    def g():
        return n
    n = 2
    assert g() == 2
    return g
assert f()() == 2

In both the global and the non-local context, if the value of a name is changed, that change is reflected in later calls to any function that references the name. If globals and non-locals were treated differently, that would be confusing, so the behaviour is made consistent.

If you need the current value of a name to be made constant for a new function, the idiomatic way is to delegate the creation of the function to another function. The new function is created in the creating function's scope (where nothing changes after it returns), and thus the value of the name will not change.

For example:

def create_constant_getter(constant):
    def constant_getter():
        return constant
    return constant_getter

getters = [create_constant_getter(n) for n in range(5)]
constants = [f() for f in getters]
assert constants == [0, 1, 2, 3, 4]
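Applied to the original question, the comprehension behaves roughly like the explicit loop below, where i is a non-local name of every lambda rather than a constant. In this sketch (`make_powers` and `make_power` are hypothetical names), rebinding i after the loop is observed by all the functions, just as in the f/g example above, and a factory function fixes it.

```python
def make_powers():
    funcs = []
    for i in range(3):
        # i is not copied into the lambda; the lambda refers to
        # make_powers' variable i, which keeps changing.
        funcs.append(lambda x: x**i)
    i = 10  # rebinding i after the loop is seen by every lambda
    return funcs

print([f(5) for f in make_powers()])  # [9765625, 9765625, 9765625]

# Fix: delegate creation to a factory, as with create_constant_getter.
def make_power(i):
    return lambda x: x**i

print([make_power(i)(5) for i in range(3)])  # [1, 5, 25]
```
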

Finally, as an addendum, functions can modify non-local names (if the name is declared nonlocal) just as they can modify global names (if the name is declared global). For example:

def f():
    n = 0
    def increment():
        nonlocal n
        n += 1
        return n
    return increment
g = f()
assert g() + 1 == g()
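Without the nonlocal declaration, the assignment in `n += 1` makes n local to increment, so reading it fails. A small sketch (`f_without_nonlocal` is an illustrative name); it also shows that each call to the outer function gets its own independent counter.

```python
def f_without_nonlocal():
    n = 0
    def increment():
        n += 1  # assignment makes n local to increment, so this read fails
        return n
    return increment

try:
    f_without_nonlocal()()
except UnboundLocalError as exc:
    print("UnboundLocalError:", exc)

# With nonlocal, every call to the factory creates a separate counter.
def f():
    n = 0
    def increment():
        nonlocal n
        n += 1
        return n
    return increment

a, b = f(), f()
a()
a()
print(a(), b())  # 3 1
```
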