Fix Python – What does model.train() do in PyTorch?

Question

Asked By – aerin

Does it call forward() in nn.Module? I thought that when we call the model, the forward method is used.
Why do we need to call train()?

Answer

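model.train() tells your model that you are training it. This matters for layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. In training mode, Dropout randomly zeroes elements with probability p and scales the remaining ones by 1/(1 - p), and BatchNorm normalizes each batch with that batch's own statistics while updating its running mean and variance. In evaluation mode, Dropout does nothing, and BatchNorm normalizes with the stored running statistics, which are no longer updated.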
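The sketch below illustrates both behaviours on a toy nn.Sequential model. It assumes a working PyTorch install; the layer sizes, the seed and the variable names are arbitrary choices made for illustration, not part of the original answer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model containing both mode-dependent layers.
model = nn.Sequential(nn.BatchNorm1d(3), nn.Dropout(p=0.5))
bn = model[0]
x = torch.randn(8, 3)

model.train()                    # training mode (also the default after construction)
out_train = model(x)             # Dropout zeroes roughly half the elements, scales the rest by 2
mean_after_train = bn.running_mean.clone()  # updated from this batch's statistics

model.eval()                     # evaluation mode
out_eval1 = model(x)
out_eval2 = model(x)             # Dropout is a no-op; BatchNorm uses the stored running stats
mean_after_eval = bn.running_mean.clone()

print(torch.equal(out_eval1, out_eval2))               # True: eval output is deterministic
print(torch.equal(mean_after_train, mean_after_eval))  # True: eval did not touch running_mean
```

Running the same input twice in training mode would generally give different outputs (because of Dropout) and would keep moving running_mean; in evaluation mode neither happens.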

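More details:
model.train() sets the mode to training (see the nn.Module source code): it sets the training attribute to True on the module and, recursively, on all of its submodules. To tell the model that you are testing, call either model.eval() or model.train(mode=False); the two are equivalent.
It is somewhat intuitive to expect a function called train() to train the model, but it does not do that, and it does not call forward() either. It only sets the mode; forward() still runs only when you call the model on an input.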
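A quick check makes this concrete. In this minimal sketch, Probe is a hypothetical module written only for illustration: it counts its own forward() calls and wraps a Dropout layer, so you can see that train() flips the training flag on the module and its children, returns the module itself, and never runs forward().

```python
import torch
import torch.nn as nn


class Probe(nn.Module):
    """Hypothetical toy module that counts how often forward() runs."""

    def __init__(self):
        super().__init__()
        self.inner = nn.Dropout(p=0.5)
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return self.inner(x)


model = Probe()
returned = model.train()                     # only sets flags; returns the module itself
print(returned is model)                     # True
print(model.forward_calls)                   # 0: train() did not call forward()
print(model.training, model.inner.training)  # True True: flag set recursively

model.train(mode=False)                      # same effect as model.eval()
print(model.training, model.inner.training)  # False False

model(torch.ones(2))                         # calling the model is what runs forward()
print(model.forward_calls)                   # 1
```

Because train() and eval() return the module, they can also be chained, e.g. model.eval()(x), though calling them on their own line is the more common style.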

This question was answered by Umang Gupta.

This answer was collected from Stack Overflow, reviewed by FixPython community admins, and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.