问题描述
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingNet(nn.Module):
    """Small CNN that maps a batch of square images to 2-D embeddings.

    Layers: Conv(5x5) -> PReLU -> MaxPool(2) -> Conv(5x5, stride 2) -> PReLU
    -> flatten -> Linear(256) -> PReLU -> Linear(2).

    Args:
        in_channels: Number of input image channels (1 for grayscale,
            3 for RGB). Defaults to 1, as in the original model.
        input_size: Height/width of the square input images. Defaults to 28,
            which gives the original flattened size of 64 * 4 * 4.

    Raises:
        ValueError: If ``input_size`` is too small for the conv stack.
    """

    def __init__(self, in_channels=1, input_size=28):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            # kernel_size was missing here (TypeError). 5x5 with stride 2 maps
            # the 12x12 feature map of a 28x28 input onto the 4x4 that fc expects.
            nn.Conv2d(32, 64, 5, stride=2),
            nn.PReLU(),
        )
        side = self._feature_side(input_size)
        self.fc = nn.Sequential(
            nn.Linear(64 * side * side, 256),
            # Without a nonlinearity the two Linear layers collapse into one.
            nn.PReLU(),
            nn.Linear(256, 2),
        )

    @staticmethod
    def _feature_side(input_size):
        """Return the side length of the square map produced by ``convnet``."""
        side = (input_size - 4) // 2  # 5x5 conv (no padding), then 2x2 max-pool
        side = (side - 5) // 2 + 1  # 5x5 conv with stride 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for EmbeddingNet; need >= 14"
            )
        return side

    def forward(self, x):
        """Embed ``x`` into a ``(batch, 2)`` tensor.

        Args:
            x: Image batch, expected as ``(batch, in_channels, input_size,
                input_size)`` to match the sizes given to the constructor.
        """
        features = self.convnet(x)
        # flatten(1) keeps the batch dimension and, unlike view(), also
        # works on non-contiguous tensors.
        return self.fc(features.flatten(1))

    def get_embedding(self, x):
        """Return the embedding of ``x``; identical to calling ``forward``."""
        return self.forward(x)
class EmbeddingNetL2(EmbeddingNet):
    """EmbeddingNet whose output rows are scaled to unit L2 norm.

    The constructor and ``get_embedding`` are inherited from ``EmbeddingNet``.
    ``get_embedding`` calls ``self.forward``, so it returns normalized output too.
    """

    def forward(self, x):
        """Return the parent embedding divided by its L2 norm along dim 1."""
        output = super(EmbeddingNetL2, self).forward(x)
        # Out-of-place on purpose: the previous `output /= ...` modified a
        # tensor that autograd saved for pow()'s backward, so .backward()
        # raised a RuntimeError. It also divided by zero for all-zero rows.
        # F.normalize clamps the norm with a small eps instead.
        return F.normalize(output, p=2, dim=1)
解决方法
错误原因很简单：报错信息表明你输入的是 3 通道（RGB）图像，而第一个卷积层 nn.Conv2d(1, 32, 5) 只接受 1 通道输入。
需要修改的是下面这个代码块：
class EmbeddingNet(nn.Module):
    """Constructor of EmbeddingNet adapted to 3-channel (RGB) input images.

    Only ``__init__`` is changed; ``forward`` and ``get_embedding`` stay as in
    the question's ``EmbeddingNet``.
    """

    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv2d(3, 32, 5),  # 3 input channels instead of 1 for RGB images
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            # The third positional argument of Conv2d is kernel_size; passing
            # nn.PReLU() there raises a TypeError. The activation is its own layer.
            nn.Conv2d(32, 64, 5, stride=2),
            nn.PReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 256),
            nn.PReLU(),
            nn.Linear(256, 2),
        )
接下来修复下一个错误，它出在全连接层：第一个 nn.Linear 的输入维度必须等于卷积层输出展平后的大小。将 self.fc 改成下面这样：
self.fc = nn.Sequential(
    # in_features must equal 64 * H * W of convnet's flattened output; 61 x 61
    # is assumed to be the feature-map size for the asker's images
    # (TODO confirm for your input size).
    nn.Linear(64 * 61 * 61, 256),
    nn.PReLU(),  # activation is a separate layer, not an argument of Linear
    nn.Linear(256, 2),
)