[PyTorch 学习笔记]
模型创建步骤 与 nn.Module
这篇文章来看下 PyTorch
中网络模型的创建步骤。网络模型的内容如下,包括模型创建和权值初始化,这些内容都在nn.Module
中有实现。
一、网络模型的创建步骤
创建模型有 2
个要素:构建子模块 和拼接子模块 。如
LeNet
里包含很多卷积层、池化层、全连接层,当我们构建好所有的子模块之后,按照一定的顺序拼接起来。
这里以上一篇文章中 lenet.py
的 LeNet
为例,继承nn.Module
,必须实现__init__()
方法和forward()
方法。其中__init__()
方法里创建子模块,在forward()
方法里拼接子模块。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 class LeNet (nn.Module): def __init__ (self, classes ): super (LeNet, self).__init__() self.conv1 = nn.Conv2d(3 , 6 , 5 ) self.conv2 = nn.Conv2d(6 , 16 , 5 ) self.fc1 = nn.Linear(16 *5 *5 , 120 ) self.fc2 = nn.Linear(120 , 84 ) self.fc3 = nn.Linear(84 , classes) def forward (self, x ): out = F.relu(self.conv1(x)) out = F.max_pool2d(out, 2 ) out = F.relu(self.conv2(out)) out = F.max_pool2d(out, 2 ) out = out.view(out.size(0 ), -1 ) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) out = self.fc3(out) return out
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def __call__ (self, *input , **kwargs ): for hook in self._forward_pre_hooks.values(): result = hook(self, input ) if result is not None : if not isinstance (result, tuple ): result = (result,) input = result if torch._C._get_tracing_state(): result = self._slow_forward(*input , **kwargs) else : result = self.forward(*input , **kwargs) ... ... ...
最终会调用result = self.forward(*input, **kwargs)
函数,该函数会进入模型的forward()
函数中,进行前向传播。
在 torch.nn
中包含 4 个模块,如下图所示。
二、nn.Module
nn.Module
有 8
个属性,都是OrderDict
(有序字典)。在 LeNet
的__init__()
方法中会调用父类nn.Module
的__init__()
方法,创建这
8 个属性。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def __init__ (self ): """ Initializes internal Module state, shared by both nn.Module and ScriptModule. """ torch._C._log_api_usage_once("python.nn_module" ) self.training = True self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict()
**_parameters 属性**:存储管理 nn.Parameter 类型的参数
**_modules 属性**:存储管理 nn.Module 类型的参数
_buffers 属性:存储管理缓冲属性,如 BN 层中的 running_mean
5 个 ***_hooks 属性:存储管理钩子函数