How to get the file names of the images I feed into a PyTorch DataLoader

Problem description

I load images with PyTorch like this:

inf_data = InfDataLoader(img_folder=args.imgs_folder,target_size=args.img_size)
inf_DataLoader = DataLoader(inf_data,batch_size=1,shuffle=True,num_workers=2)

Then:

    with torch.no_grad():
        for batch_idx,(img_np,img_tor) in enumerate(inf_DataLoader,start=1):

            img_tor = img_tor.to(device)
            pred_masks,_ = model(img_tor)

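But I also want to get the file name of each image. Can anyone help me with this? Thanks a lot!

Solution

The DataLoader itself does not know the file names; it only batches whatever the Dataset returns. In the Dataset (InfDataLoader in the question above), however, you can return the file name together with the tensors.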

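import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# pad_resize_image is assumed to be defined elsewhere in the project (it is not shown in this post).

class InfDataLoader(Dataset):
    """
    Dataset for inference.
    """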
    def __init__(self,img_folder,target_size=256):
        self.imgs_folder = img_folder

        # Sort file names numerically, so that 2.png comes before 10.png.
        # This assumes every file name is an integer followed by an extension.
        img_list = sorted(os.listdir(self.imgs_folder),
                          key=lambda x: int(os.path.splitext(x)[0]))
        self.img_paths = [os.path.join(self.imgs_folder,f) for f in img_list]

        # self.img_paths = sorted(glob.glob(self.imgs_folder + '/*'))

        print(self.img_paths)


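        self.target_size = target_size
        self.normalize = transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])

    def __len__(self):
        # DataLoader needs the number of samples, e.g. for shuffling
        return len(self.img_paths)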

    def __getitem__(self,idx):
        """
        __getitem__ for inference
        :param idx: Index of the image
        :return: img_np is a numpy RGB-image of shape H x W x C with pixel values in range 0-255.
        img_tor is a torch tensor, RGB, C x H x W in shape and normalized.
        name is the path of the image file.
        """
        img = cv2.imread(self.img_paths[idx])
        name = self.img_paths[idx]

        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)

        # Pad images to target size
        img_np = pad_resize_image(img,None,self.target_size)
        img_tor = img_np.astype(np.float32)
        img_tor = img_tor / 255.0
        img_tor = np.transpose(img_tor,axes=(2,0,1))  # H x W x C -> C x H x W
        img_tor = torch.from_numpy(img_tor).float()
        img_tor = self.normalize(img_tor)

        return img_np,img_tor,name

Here I added the line name = self.img_paths[idx] and return name along with the other values.

So:

    with torch.no_grad():
        for batch_idx,(img_np,img_tor,name) in enumerate(inf_DataLoader,start=1):
            img_tor = img_tor.to(device)
            pred_masks,_ = model(img_tor)

This way I can get the name of each image.
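
One thing to keep in mind: with the default collate function, name does not come back as a single string but as a sequence of strings, one per sample in the batch, so with batch_size=1 you would use name[0]. Below is a minimal sketch of that behavior. ToyPathDataset and the paths are made up for illustration and are not part of the original code; the dummy tensors stand in for real images.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class ToyPathDataset(Dataset):
    """Hypothetical stand-in for the inference dataset: returns (tensor, path)."""

    def __init__(self, paths):
        self.paths = list(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # A dummy 3 x 4 x 4 "image"; a real dataset would load the file here.
        img_tor = torch.full((3, 4, 4), float(idx))
        return img_tor, self.paths[idx]


paths = ["imgs/1.png", "imgs/2.png", "imgs/10.png"]  # made-up example paths
loader = DataLoader(ToyPathDataset(paths), batch_size=2, shuffle=False)

seen = []
for img_tor, names in loader:
    # names holds one path string per sample in the batch
    for tensor, path in zip(img_tor, names):
        # Strip the folder to keep only the file name
        seen.append(os.path.basename(path))
```

Pairing each tensor with its path via zip keeps the two aligned even when shuffle=True, since both come from the same batch.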