Add(3)(5) and nn.Sequential: how does it work?
Solution 1:
You understood it right: the Sequential class, in a nutshell, just calls the provided modules one by one. Here's the code of its forward method:
def forward(self, input):
    for module in self:
        input = module(input)
    return input
Here, for module in self just iterates through the modules provided in the constructor (the Sequential.__iter__ method is in charge of that). A Sequential module calls this forward method when you call the module itself using the () syntax:
calculator = nn.Sequential(
    Add(3),
    Add(2),
    Add(5),
)
output = calculator(torch.tensor([1]))
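To make the iteration visible, here's a rough manual equivalent of calling the calculator above (just a sketch that reuses the calculator and the Add module from your example):

x = torch.tensor([1])
for module in calculator:   # Sequential.__iter__ yields Add(3), Add(2), Add(5) in order
    x = module(x)           # same thing Sequential.forward does internally
print(x)                    # tensor([11]), same as calculator(torch.tensor([1]))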
But how does it work? In Python, you can make objects of a class callable by adding a __call__ method to the class definition. If the class does not contain this method explicitly, it was possibly inherited from a superclass. In the case of Add and Sequential, it's the Module class that implements the __call__ method, and __call__ in turn calls the 'public' forward method defined by the user.
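For reference, here's a minimal sketch of what an Add module like yours might look like; the exact definition isn't shown here, so treat the details as an assumption:

import torch
import torch.nn as nn

class Add(nn.Module):
    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, x):
        # Module.__call__ ends up calling this method
        return x + self.value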
It can be confusing that Python uses the same syntax for object instantiation and for function or method calls. To make the difference visible to the reader, Python uses naming conventions: classes should be named in CamelCase starting with a capital letter, and objects in snake_case (it's not obligatory, but it's a rule that is better to follow).
Just like in your example, Add is a class and add is a callable object of this class:
add = Add(torch.tensor([1]))
So, you can call add just like you called the calculator in your example.
>>> add = Add(torch.tensor([1]))
>>> add(2)
Out: tensor([3])
But that won't work:
>>> add = Add(torch.tensor([1]))
>>> add(2)(1)
Out:
----> 3 add(2)(1)
TypeError: 'Tensor' object is not callable
That means that add(2) returns a Tensor object, which does not implement the __call__ method.
Compare this code with
>>> Add(torch.tensor([1]))(2)
Out: tensor([3])
This code is the same as the first example, but rearranged a little bit.
--
To avoid confusion, I usually name objects differently, like add_obj = Add(1). It helps me to highlight the difference.
If you are not sure what you're working with, use the type and isinstance functions. They will help you find out what's going on. For example, if you check the add object, you can see that it's a callable object (i.e., it implements __call__):
>>> from typing import Callable
>>> isinstance(add, Callable)
True
And for a tensor:
>>> from typing import Callable
>>> isinstance(torch.tensor(1), Callable)
False
Hence, it will raise TypeError: 'Tensor' object is not callable if you try to call it.
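type is handy in the same way; the output below assumes Add is defined in your own script or notebook:
>>> type(add)
<class '__main__.Add'>
>>> type(add(2))
<class 'torch.Tensor'>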
If you'd like to understand how Python double-underscore methods like __init__ or __call__ work, you can read the page that describes the Python data model (it can be a bit tedious, so you might prefer to read something like Fluent Python or another book).
Solution 2:
The core thing to remember is that when you instantiate a Module
class, you are creating a callable object, i.e. something that can behave like a function.
In plain English and step by step:
- When you write something like add5 = Add(5), what you are doing is assigning an "instance" of the PyTorch model Add to add5.
- More specifically, you are passing 5 to the Add class's __init__ method, and so its value attribute is set to 5.
- PyTorch Modules are "callable", meaning you can call them like functions. The function you call when you use an instance of a Module is that instance's forward method. So concretely, with our add5 object, if we pass a value, x = 10, by writing something like add5(10), it is like we ran x + add5.value, which equals 10 + 5 = 15 (see the short sketch after this list).
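Putting that into code, as a small sketch that assumes the Add definition from your question:

add5 = Add(5)     # __init__ runs and sets add5.value to 5
out = add5(10)    # Module.__call__ -> add5.forward(10) -> 10 + 5
print(out)        # 15 (a Tensor if you passed Tensors, a plain int here)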
Now, putting this together, we should view the Sequential interface for building neural network models that don't have branching structures as just sequentially invoking each of the instantiated Modules' forward methods.
Omitting the definition of Add and focusing just on calculator as a series of computations, we have the following (I've added comments to show what you should think of at each step):
calculator = nn.Sequential(
    # given some input tensor x...
    Add(3),  # run: x = x + self.value with self.value = 3
    Add(2),  # run: x = x + self.value with self.value = 2
    Add(5),  # run: x = x + self.value with self.value = 5
)
Now we can see that it's reasonable to expect that if we pass the value 1 (albeit wrapped up as a PyTorch Tensor), we are just doing 1 + 3 + 2 + 5, which of course equals 11. PyTorch returns the value back to us, still as a Tensor object.
x = torch.tensor([1])
output = calculator(x)
print(output) # tensor([11])
Finally, Add(3)(5)* works for exactly this same reason! With Add(3) we are getting an instance of the Add class with the value to add being 3. We then use it immediately, with the somewhat unintuitive syntax Add(3)(5), to return the value 3 + 5 = 8.
*I think you intended the capitalised class name, not an instance of the class