Add(3)(5) and nn.Sequential: how does it work?

Solution 1:

You've understood it correctly: the Sequential class, in a nutshell, just calls the provided modules one by one. Here's the code of its forward method:

    def forward(self, input):
        for module in self:
            input = module(input)
        return input

Here, for module in self just iterates through the modules provided in the constructor (the Sequential.__iter__ method is in charge of that).
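
This answer assumes an Add module along the lines of the one in the question (its definition isn't reproduced here); here's a minimal sketch of what it presumably looks like:

    import torch
    import torch.nn as nn

    class Add(nn.Module):
        def __init__(self, value):
            super().__init__()
            self.value = value      # the constant this module adds

        def forward(self, x):
            return x + self.value   # invoked for you by Module.__call__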

Sequential calls its forward method when you call the Sequential object itself using the () syntax:

calculator = nn.Sequential(
    Add(3),
    Add(2),
    Add(5),
)
output = calculator(torch.tensor([1]))
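
Since Sequential is iterable, you can reproduce what forward does by hand and check that it matches (assuming the Add sketch above):

    x = torch.tensor([1])

    out = x
    for module in calculator:   # the same iteration forward performs
        out = module(out)       # Add(3), then Add(2), then Add(5)

    print(out)                  # tensor([11]), same as calculator(x)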

But how does it work? In Python, you can make objects of a class callable by adding a __call__ method to the class definition. If the class does not contain this method explicitly, it may be inherited from a superclass. In the case of Add and Sequential, it's the Module class that implements __call__. And __call__ calls the 'public' forward method defined by the user.
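
Conceptually, the mechanism looks something like this stripped-down sketch (PyTorch's real Module.__call__ also handles hooks and other bookkeeping):

    class ToyModule:
        # toy stand-in for nn.Module: calling the object runs its forward method
        def __call__(self, *args, **kwargs):
            return self.forward(*args, **kwargs)

    class ToyAdd(ToyModule):
        def __init__(self, value):
            self.value = value

        def forward(self, x):
            return x + self.value

    print(ToyAdd(3)(5))  # 8: ToyAdd(3) builds the object, (5) triggers __call__ -> forward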

It can be confusing that Python uses the same syntax for object instantiation and for function or method calls. To make the difference visible to the reader, Python relies on naming conventions: classes should be named in CamelCase with a leading capital letter, and objects in snake_case (it's not obligatory, but it's a rule that is better to follow).

Just like in your example, Add is a class and add is a callable object of that class:

add = Add(torch.tensor([1]))

So you can call add just like you called calculator in your example.

>>> add = Add(torch.tensor([1]))
>>> add(2)
Out: tensor([3]) 

But that won't work:

>>> add = Add(torch.tensor([1]))
>>> add(2)(1)
Out: 
----> 3 add(2)(1)
TypeError: 'Tensor' object is not callable

That means add(2) returns a Tensor object, which does not implement the __call__ method.

Compare this code with

>>> Add(torch.tensor([1]))(2)
Out:
tensor([3])  

This code is the same as the first example, but rearranged a little bit.

--

To avoid confusion, I usually name objects differently, e.g. add_obj = Add(1). It helps me highlight the difference.

If you are not sure what you're working with, use the built-in functions type and isinstance. They will help you find out what's going on.

For example, if you check the add object, you can see that it's a callable object (i.e., it implements __call__):

>>> from typing import Callable
>>> isinstance(add, Callable)
True

And for a tensor:

>>> from typing import Callable
>>> isinstance(torch.tensor(1), Callable)
False

Hence, it will raise TypeError: 'Tensor' object is not callable if you try to call it.
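
Similarly, type makes the distinction explicit (using the add object defined above; the exact class repr depends on where Add is defined):

    print(type(add))     # e.g. <class '__main__.Add'> -- a Module subclass, hence callable
    print(type(add(2)))  # <class 'torch.Tensor'> -- a tensor, not callable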

If you'd like to understand how Python double-underscore ('dunder') methods like __init__ or __call__ work, you can read the 'Data model' page of the Python language reference.

(It can be a bit tedious, so you might prefer to read something like Fluent Python or another book.)

Solution 2:

The core thing to remember is that when you instantiate a Module class, you are creating a callable object, i.e. something that can behave like a function.

In plain English and step by step:

  • When you write something like add5 = Add(5), what you are doing is assigning an "instance" of the PyTorch module Add to add5.
  • More specifically, you are passing 5 to the Add class's __init__ method, and so its value attribute is set to 5.
  • PyTorch Modules are "callable", meaning you can call them like functions. The function you call when you use an instance of a Module is that instance's forward method. So concretely, with our add5 object, if we pass a value x = 10 by writing add5(10), it is as if we ran x + add5.value, which equals 10 + 5 = 15 (see the sketch after this list).
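
To make these steps concrete, here is a minimal sketch of such an Add module together with the add5 example (the question's actual definition may differ slightly):

    import torch.nn as nn

    class Add(nn.Module):
        def __init__(self, value):
            super().__init__()
            self.value = value     # set by __init__, e.g. 5

        def forward(self, x):
            return x + self.value  # what runs when you call the instance

    add5 = Add(5)      # an instance of the Module, i.e. a callable object
    print(add5(10))    # __call__ -> forward(10) -> 10 + 5 = 15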

Putting this together, we should view the Sequential interface (for building neural network models that don't have branching structures) as simply invoking each of the instantiated Modules' forward methods in sequence.

Omitting the definition of Add and focusing just on calculator as a series of computations, we have the following (I've added comments to show what you should think of at each step):

calculator = nn.Sequential(
            # given some input tensor x...
    Add(3), # run: x = x + self.value with self.value = 3
    Add(2), # run: x = x + self.value with self.value = 2 
    Add(5), # run: x = x + self.value with self.value = 5
)

Now we can see that it's reasonable to expect that if we pass the value 1 (albeit wrapped up as a PyTorch Tensor), we are just computing 1 + 3 + 2 + 5, which of course equals 11. PyTorch returns the result to us still as a Tensor object.

x = torch.tensor([1])
output = calculator(x)
print(output) # tensor([11])

Finally, Add(3)(5)* works for exactly the same reason! With Add(3) we are getting an instance of the Add class with the value to add being 3. We then use it immediately with the somewhat unintuitive syntax Add(3)(5) to return the value 3 + 5 = 8.
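
Using the kind of minimal Add sketched above, you can check this directly:

    print(Add(3)(5))   # 8: Add(3) constructs the module, (5) invokes its forward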

*I think you intended the capitalised class name, not an instance of the class