Pytorch(16)模型GPU训练

[PyTorch 学习笔记] 模型应用

  • Neural Network Malware Binary Classification:https://github.com/jaketae/deep-malware-detection

下面的代码是使用 Generator 来生成人脸图像,Generator 已经训练好保存在 pkl 文件中,只需要加载参数即可。由于模型是在多 GPU 的机器上训练的,因此加载参数后需要使用remove_module()函数来修改state_dict中的key

1
2
3
4
5
6
7
8
9
10
11
# 多 GPU 的机器上训练模型参数修改
def remove_module(state_dict_g):
# remove module.
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in state_dict_g.items():
namekey = k[7:] if k.startswith('module.') else k
new_state_dict[namekey] = v

return new_state_dict

在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里==不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异==。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。