PyTorch RuntimeError: Given groups=1, weight of size [32, 3, 5, 5], expected input[1, 32, 3, 784] to have 3 channels, but got 32 channels instead

Posted by g0czyy6m 6 months ago in Other

RuntimeError: Given groups=1, weight of size [32, 3, 5, 5], expected input[1, 32, 3, 784] to have 3 channels, but got 32 channels instead. Here is my code:

```python
import torch
import torch.nn as nn

class Conv(nn.Module):
  def __init__(self):
    super().__init__()
    #self.flatten=nn.Flatten()
    self.conv1=nn.Conv2d(in_channels=3,out_channels=32,kernel_size=5,padding=0,stride=1)
    self.relu1=nn.ReLU()
    self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)
    self.conv2=nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,padding=0,stride=1)
    self.relu2=nn.ReLU()
    self.pool2=nn.MaxPool2d(kernel_size=2,stride=2)
    self.flatten=nn.Flatten()
    self.fc1=nn.Linear(in_features=32*4*4,out_features=128)
    self.relu3=nn.ReLU()
    self.fc2=nn.Linear(in_features=128,out_features=64)
    self.relu4=nn.ReLU()
    self.fc3=nn.Linear(in_features=64,out_features=7)
    self.logSoftmax=nn.LogSoftmax(dim=1)
  def forward(self,x):
    x=self.conv1(x)
    x=self.relu1(x)
    x=self.pool1(x)
    x=self.conv2(x)
    x=self.relu2(x)
    x=self.pool2(x)
    x=self.flatten(x)
    x=self.fc1(x)
    x=self.relu3(x)
    x=self.fc2(x)
    x=self.relu4(x)
    x=self.fc3(x)
    out=self.logSoftmax(x)
    return out
```

Error: (screenshot not available)
Here is the shape of the data: (screenshot not available)
Please help me solve this. Thanks!
I also tried using 1 channel, but I still get the error.

Answer #1 by wmvff8tz

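The error message indicates a mismatch between the number of channels in the input and the number of input channels that the first convolutional layer's weights expect: a weight of size [32, 3, 5, 5] requires 3 input channels, but the channel dimension of the input is 32.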
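To see where the numbers in the message come from, here is a pure-Python sketch that mimics the channel check. `check_conv2d_channels` is a hypothetical helper, not a PyTorch API. It assumes, consistent with the `[1, 32, 3, 784]` in the message, that the model received a 3-D tensor of shape `[32, 3, 784]`; `nn.Conv2d` accepts 3-D input as a single unbatched `(C, H, W)` image, so the 32 ends up being read as the channel count.

```python
def check_conv2d_channels(input_shape, weight_shape, groups=1):
    """Illustrative re-implementation of the Conv2d channel check.

    weight_shape is (out_channels, in_channels // groups, kH, kW).
    A 3-D input is treated as one unbatched (C, H, W) image, so a
    batch dimension of 1 is prepended before the check.
    Returns an error string on mismatch, or None if the shapes agree.
    """
    shape = tuple(input_shape)
    if len(shape) == 3:
        shape = (1,) + shape  # unbatched input: add a batch dimension of 1
    if len(shape) != 4:
        raise ValueError(f"expected 3-D or 4-D input, got {len(shape)}-D")
    expected = weight_shape[1] * groups  # channels the weight expects
    got = shape[1]                       # channels the input actually has
    if got != expected:
        return (f"Given groups={groups}, weight of size {list(weight_shape)}, "
                f"expected input{list(shape)} to have {expected} channels, "
                f"but got {got} channels instead")
    return None


# A batch of 32 samples, each 3 channels x 784 flattened pixels
print(check_conv2d_channels((32, 3, 784), (32, 3, 5, 5)))
# The same batch reshaped to [batch, channels, height, width]
print(check_conv2d_channels((32, 3, 28, 28), (32, 3, 5, 5)))
```

Under this assumption, passing `(32, 3, 784)` reproduces the wording of the message, while `(32, 3, 28, 28)` passes. It also suggests why switching to 1 channel does not help: with a weight of `[32, 1, 5, 5]` the check still sees 32 channels, just against an expected 1.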
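If your input data has shape [3, 784] (as in your screenshot), each of the three channels is stored as a flat vector of 784 elements instead of as a 2-D image. In that case you need to reshape the input into a proper image layout with 3 channels, i.e. [batch_size, 3, 28, 28], since 784 = 28 × 28. Here is an updated version that fixes the problem: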

```python
import torch
import torch.nn as nn

class Conv(nn.Module):
    def __init__(self):
        super(Conv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=0, stride=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=0, stride=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=64, out_features=7)
        self.logSoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Reshape the input to [batch_size, channels, height, width].
        # 784 = 28 * 28; -1 lets PyTorch infer the batch size.
        x = x.view(-1, 3, 28, 28)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        out = self.logSoftmax(x)
        return out

model = Conv()

# Dummy input: one sample with 3 channels of 784 flattened pixels each
input_data = torch.randn((3, 784))
print(input_data.shape)

# view(-1, 3, 28, 28) turns this into [1, 3, 28, 28]; output has shape [1, 7]
output = model(input_data)
print(output)
```
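The reshape to 28×28 also has to agree with `fc1 = nn.Linear(in_features=32 * 4 * 4, ...)`: a 28×28 image shrinks to 4×4 after the two conv/pool stages. The pure-Python sketch below checks that arithmetic with the usual output-size formula for `Conv2d`/`MaxPool2d` with dilation 1 and floor rounding, out = floor((in + 2 * padding - kernel_size) / stride) + 1. The helper names `out_size` and `flattened_features` are hypothetical, and the layer list assumes the kernel sizes and strides from the code above.

```python
def out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of Conv2d / MaxPool2d (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(height, width, channels=32):
    """Trace the spatial size through conv1 -> pool1 -> conv2 -> pool2
    and return the number of features nn.Flatten produces per sample."""
    layers = [(5, 1), (2, 2), (5, 1), (2, 2)]  # (kernel_size, stride) per layer
    for kernel_size, stride in layers:
        height = out_size(height, kernel_size, stride)
        width = out_size(width, kernel_size, stride)
    return channels * height * width


# 28 -> 24 (conv1) -> 12 (pool1) -> 8 (conv2) -> 4 (pool2)
print(flattened_features(28, 28))  # 512 == 32 * 4 * 4
```

If the images had a different size, the same function gives the value to use for `fc1`'s `in_features`; for example, 32×32 inputs would give 800.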
