5.3 PyTorch修改模型

5.3 PyTorch修改模型

5.3.2 添加外部输入#

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加的输入和后续层之间的连接关系,从而完成模型的修改。

我们以torchvision的resnet50模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量add_variable来辅助预测。具体实现如下:

class Model(nn.Module):

def __init__(self, net):

super(Model, self).__init__()

self.net = net

self.relu = nn.ReLU()

self.dropout = nn.Dropout(0.5)

self.fc_add = nn.Linear(1001, 10, bias=True)

self.output = nn.Softmax(dim=1)

def forward(self, x, add_variable):

x = self.net(x)

x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)),1)

x = self.fc_add(x)

x = self.output(x)

return x

这里的实现要点是通过torch.cat实现了tensor的拼接。torchvision中的resnet50输出是一个1000维的tensor,我们通过修改forward函数(配套定义一些层),先将1000维的tensor通过激活函数层和dropout层,再和外部输入变量"add_variable"拼接,最后通过全连接层映射到指定的输出维度10。

另外这里对外部输入变量"add_variable"进行unsqueeze操作是为了和net输出的tensor保持维度一致,常用于add_variable是单一数值 (scalar) 的情况,此时add_variable的维度是 (batch_size, ),需要在第二维补充维数1,从而可以和tensor进行torch.cat操作。对于unsqueeze操作可以复习下2.1节的内容和配套代码。

之后对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()

model = Model(net).cuda()

另外别忘了,训练中在输入数据的时候要给两个inputs:

outputs = model(inputs, add_var)

相关推荐

ECCO哪个国家海淘最便宜?ECCO美国价格便宜吗?
365bet亚洲唯一官网

ECCO哪个国家海淘最便宜?ECCO美国价格便宜吗?

📅 09-12 👁️ 3262
香煎鲷鱼片
365bet亚洲唯一官网

香煎鲷鱼片

📅 09-05 👁️ 7701