如何让 PyTorch 张量 (B, C, H, W) 的分块与混合代码更简单、更高效?

问题描述

我几个月前写了下面的代码,它运行良好。不过,我一直没能想出如何让它更简洁、更高效。

下面的函数将形状为 (B, C, H, W) 的图像张量分割成大小相等的图块(每个图块同样是 (B, C, H, W)),这样就可以对各个图块分别进行处理以节省内存。之后,在用图块重建张量时,代码使用蒙版来确保各图块无缝地融合在一起。当最右一列或最底一行的图块无法使用与其他图块相同的重叠量时,由蒙版函数中的“特殊蒙版”来处理。这意味着右边缘和底部的图块有时可能几乎看不到其内容。这样做是为了确保无论原始图像/张量的大小如何,图块始终严格等于指定的大小(这对可视化/DeepDream、神经风格迁移等很重要)。与边缘行/列相邻的行/列,在与边缘行/列重叠的位置也使用特殊蒙版。

每个图块有 8 种可能的蒙版,其中最多 4 种可以同时使用。蒙版按方向分为左、右、上、下四种,每种都有一个特殊版本。

# Improved version of: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
import torch


# Apply blend masks to tiles
def mask_tile(tile, overlap, side='bottom'):
    """Multiply a tile by linear blend ramps on the requested sides.

    Args:
        tile: 4-D tensor indexed as (batch, channel, height, width).
        overlap: sequence of four ints ordered (top, bottom, right, left);
            each is the length in pixels of the ramp on that side.
        side: string naming the masks to apply. Names are matched by
            substring, so they can be joined with commas, e.g.
            'bottom,top,right-special'. Recognised names are 'left',
            'right', 'top', 'bottom' and the same four with a '-special'
            suffix. A '-special' name replaces the plain mask for that side.

    Returns:
        A new tensor equal to ``tile`` multiplied by the combined mask;
        ``tile`` itself is not modified.

    Note:
        The 'left-special' and 'top-special' masks allocate
        ``size - 2 * overlap`` leading zeros, so they require the overlap
        to be at most half of the tile's width/height.
    """
    h, w = tile.size(2), tile.size(3)
    top_overlap, bottom_overlap, right_overlap, left_overlap = (
        overlap[0], overlap[1], overlap[2], overlap[3])
    device = tile.device

    def ramp(start, end, steps):
        # 1-D linear ramp; broadcasting spreads it over batch/channel dims.
        return torch.linspace(start, end, steps, device=device)

    base_mask = torch.ones_like(tile)

    # Plain masks: fade only the strip of `overlap` pixels on that side.
    if 'left' in side and 'left-special' not in side:
        base_mask[:, :, :, :left_overlap] = (
            base_mask[:, :, :, :left_overlap] * ramp(0, 1, left_overlap))
    if 'right' in side and 'right-special' not in side:
        base_mask[:, :, :, w - right_overlap:] = (
            base_mask[:, :, :, w - right_overlap:] * ramp(1, 0, right_overlap))
    if 'top' in side and 'top-special' not in side:
        # unsqueeze(1) turns the ramp into a column so it varies along height.
        base_mask[:, :, :top_overlap, :] = (
            base_mask[:, :, :top_overlap, :]
            * ramp(0, 1, top_overlap).unsqueeze(1))
    if 'bottom' in side and 'bottom-special' not in side:
        base_mask[:, :, h - bottom_overlap:, :] = (
            base_mask[:, :, h - bottom_overlap:, :]
            * ramp(1, 0, bottom_overlap).unsqueeze(1))

    # Special masks: built over the full width/height of the tile.
    if 'left-special' in side:
        # zeros | ramp 0->1 | ones, left to right.
        row = torch.cat([
            torch.zeros(w - (left_overlap * 2), device=device),
            ramp(0, 1, left_overlap),
            torch.ones(left_overlap, device=device),
        ], 0)
        base_mask = base_mask * row
    if 'right-special' in side:
        # ones | ramp 1->0, left to right.
        row = torch.cat([
            torch.ones(w - right_overlap, device=device),
            ramp(1, 0, right_overlap),
        ], 0)
        base_mask = base_mask * row
    if 'top-special' in side:
        # zeros | ramp 0->1 | ones, top to bottom.
        col = torch.cat([
            torch.zeros(h - (top_overlap * 2), device=device),
            ramp(0, 1, top_overlap),
            torch.ones(top_overlap, device=device),
        ], 0)
        base_mask = base_mask * col.unsqueeze(1)
    if 'bottom-special' in side:
        # ones | ramp 1->0, top to bottom.
        col = torch.cat([
            torch.ones(h - bottom_overlap, device=device),
            ramp(1, 0, bottom_overlap),
        ], 0)
        base_mask = base_mask * col.unsqueeze(1)

    # Apply mask to tile and return masked tile
    return tile * base_mask


def _axis_mask_sides(index, count, last_step, lead, trail):
    """Choose the blend masks for one tile along one axis.

    Args:
        index: position of the tile along the axis (0-based).
        count: number of tiles along the axis.
        last_step: coordinate difference between the last two tiles.
        lead: mask name for the side facing the previous tile
            ('top' or 'left').
        trail: mask name for the side facing the next tile
            ('bottom' or 'right').

    Returns:
        (names, lead_overlap, trail_overlap). ``names`` is a list of mask
        names; each overlap is ``last_step`` when the corresponding
        '-special' mask was chosen and ``None`` otherwise.
    """
    names, lead_overlap, trail_overlap = [], None, None
    if count < 2:
        # A lone tile has no neighbour to blend with on this axis.
        return names, lead_overlap, trail_overlap

    special = last_step > 0
    if index > 0:
        # Only the last tile blends into its predecessor with a special mask.
        if index == count - 1 and special:
            names.append(lead + '-special')
            lead_overlap = last_step
        else:
            names.append(lead)
    if index < count - 1:
        # Only the second-to-last tile blends into the last one specially.
        if index == count - 2 and special:
            names.append(trail + '-special')
            trail_overlap = last_step
        else:
            names.append(trail)
    return names, lead_overlap, trail_overlap


def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    """Mask every tile and accumulate it into ``base_img``.

    Args:
        tiles: sequence of 4-D tile tensors in row-major order (all tiles of
            the first row, then the second row, ...).
        base_img: 4-D tensor that receives the tiles; modified in place.
        tile_coords: pair ``(y_coords, x_coords)`` of tile start offsets.
        tile_size: pair ``(tile_height, tile_width)``.
        overlap: four ints ordered (top, bottom, right, left), as expected
            by ``mask_tile``.

    Returns:
        ``base_img`` after all tiles have been added.
    """
    y_coords, x_coords = tile_coords[0], tile_coords[1]
    n_rows, n_cols = len(y_coords), len(x_coords)

    # The last tile on each axis is shifted back to stay inside the image,
    # so the step to it can differ from the regular stride.
    row_step = y_coords[-1] - y_coords[-2] if n_rows > 1 else 0
    col_step = x_coords[-1] - x_coords[-2] if n_cols > 1 else 0

    t = 0
    for row, y in enumerate(y_coords):
        v_names, top_ovlp, bottom_ovlp = _axis_mask_sides(
            row, n_rows, row_step, 'top', 'bottom')
        for column, x in enumerate(x_coords):
            h_names, left_ovlp, right_ovlp = _axis_mask_sides(
                column, n_cols, col_step, 'left', 'right')

            # Per-tile copy so overrides never leak into the shared list.
            c_overlap = list(overlap)
            if top_ovlp is not None:
                c_overlap[0] = top_ovlp
            if bottom_ovlp is not None:
                c_overlap[1] = bottom_ovlp
            if right_ovlp is not None:
                c_overlap[2] = right_ovlp
            if left_ovlp is not None:
                c_overlap[3] = left_ovlp

            tile = mask_tile(tiles[t], c_overlap,
                             side=','.join(v_names + h_names))
            # Index height and width (dims 2 and 3) on both sides.
            base_img[:, :, y:y + tile_size[0], x:x + tile_size[1]] = (
                base_img[:, :, y:y + tile_size[0], x:x + tile_size[1]] + tile)
            t += 1
    return base_img


# Calculate the coordinates for tiles
def get_tile_coords(d, tile_dim, overlap=0):
    """Return the start offsets of tiles covering a dimension of length ``d``.

    Tiles are placed every ``int(tile_dim * (1 - overlap))`` pixels starting
    at 0. The last tile is moved back to ``d - tile_dim`` so it ends exactly
    at the edge.

    Args:
        d: length of the dimension being tiled.
        tile_dim: length of one tile along that dimension.
        overlap: fraction of a tile shared with the next one.

    Returns:
        List of integer start offsets; ``[0]`` when one tile already covers
        the whole dimension.

    Raises:
        ValueError: if more than one tile is needed but the computed step is
            smaller than 1 (the tiles would never advance).
    """
    move = int(tile_dim * (1 - overlap))
    coords = [0]
    if tile_dim >= d:
        return coords
    if move < 1:
        raise ValueError(
            'tile step must be at least 1 pixel; got %d for tile_dim=%r, '
            'overlap=%r' % (move, tile_dim, overlap))

    tile_start = move
    while tile_start + tile_dim < d:
        coords.append(tile_start)
        tile_start += move
    # Final tile is clamped so it stays fully inside the dimension.
    coords.append(d - tile_dim)
    return coords


# Calculates info required for tiling
def tile_setup(tile_size, overlap_percent, base_size):
    """Compute tile coordinates and pixel overlaps for an image size.

    Args:
        tile_size: a single value used for both height and width, or a
            tuple/list ``(tile_height, tile_width)``.
        overlap_percent: a single fraction used for both axes, or a
            tuple/list ``(y_fraction, x_fraction)``.
        base_size: pair ``(height, width)`` of the image being tiled.

    Returns:
        ``((y_coords, x_coords), [y_ovlp, y_ovlp, x_ovlp, x_ovlp])`` where
        the coordinate lists come from ``get_tile_coords`` and the overlaps
        are ``int(tile_dim * fraction)`` for each axis.
    """
    # isinstance also accepts tuple/list subclasses such as torch.Size.
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    if not isinstance(overlap_percent, (tuple, list)):
        overlap_percent = (overlap_percent, overlap_percent)

    y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_ovlp = int(tile_size[0] * overlap_percent[0])
    x_ovlp = int(tile_size[1] * overlap_percent[1])
    return (y_coords, x_coords), [y_ovlp, y_ovlp, x_ovlp, x_ovlp]


# Split tensor into tiles
def tile_image(img, info_only=False, *, tile_size, overlap_percent):
    """Cut a 4-D tensor into equally sized, overlapping tiles.

    Args:
        img: tensor indexed as (batch, channel, height, width).
        info_only: accepted for compatibility; this function does not
            consult it.
        tile_size: a single value for square tiles, or a tuple/list
            ``(tile_height, tile_width)``. Keyword-only.
        overlap_percent: overlap fraction passed to ``tile_setup``.
            Keyword-only.

    Returns:
        List of tile tensors in row-major order. Each tile is a slice of
        ``img`` over its height and width dimensions.
    """
    # Normalise here because the slices below index tile_size per axis.
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    tile_coords, _ = tile_setup(
        tile_size, overlap_percent, (img.size(2), img.size(3)))

    # Cut out tiles
    return [
        img[:, :, y:y + tile_size[0], x:x + tile_size[1]]
        for y in tile_coords[0]
        for x in tile_coords[1]
    ]


# Put tiles back into the original tensor
def rebuild_image(tiles, image_size, overlap_percent, tile_size=None):
    """Blend a list of tiles back into a single tensor.

    Args:
        tiles: sequence of 4-D tile tensors in row-major order.
        image_size: size of the tensor to rebuild, as accepted by
            ``torch.zeros`` (e.g. the ``.size()`` of the original tensor).
        overlap_percent: overlap fraction passed to ``tile_setup``.
        tile_size: optional single value or ``(tile_height, tile_width)``.
            When omitted, the height and width of ``tiles[0]`` are used.

    Returns:
        A new tensor of size ``image_size`` on the device of ``tiles[0]``,
        holding the sum of the masked tiles.
    """
    base_img = torch.zeros(image_size, device=tiles[0].device)
    if tile_size is None:
        # Derive the tile size from the data instead of a module global.
        tile_size = (tiles[0].size(2), tiles[0].size(3))
    elif not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    tile_coords, overlap = tile_setup(
        tile_size, overlap_percent, (base_img.size(2), base_img.size(3)))
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)

上面的代码可以和下面的代码一起测试:

import torchvision.transforms as transforms
from PIL import Image
import random

# Load image
def preprocess_simple(image_name,image_size):
    """Open an image file and return it as a tensor with a leading batch dim.

    The image is converted to RGB, resized with ``transforms.Resize`` to
    ``image_size`` and converted with ``transforms.ToTensor``.
    """
    rgb_image = Image.open(image_name).convert('RGB')
    pipeline = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    tensor = pipeline(rgb_image)
    # Prepend a dimension of size 1 at position 0.
    return tensor.unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor,output_name):
    """Clamp a tensor to [0, 1] and save it as an image file.

    The clamp is done in place, so the caller's tensor is modified. The
    dimension at position 0 is squeezed before conversion to a PIL image.
    """
    output_tensor.clamp_(0, 1)
    to_pil = transforms.ToPILImage()
    to_pil(output_tensor.squeeze(0)).save(output_name)

# Demo: load an image, split it into tiles, then blend the tiles back together.
# Load the test image, resized to 1024x1024.
test_input = preprocess_simple('tubingen.jpg',(1024,1024))
# Tile size (a single value) and the overlap fraction between tiles.
tile_size=260
overlap_percent=0.5

# Split the input tensor into tiles.
img_tiles = tile_image(test_input,tile_size=tile_size,overlap_percent=overlap_percent)

random.shuffle(img_tiles) # Comment this out to not randomize tile positions

# Rebuild a tensor of the input's size from the tiles and save it.
output_tensor = rebuild_image(img_tiles,test_input.size(),overlap_percent=overlap_percent)
deprocess_simple(output_tensor,'tiled_image.jpg')

我在下面提供了一个示例(上图是原始图像;下图是把图块随机打乱后再放回去的结果,用来展示混合效果):

Original Image

Tiled Image with random tile placement

解决方法

我已经修复了所有错误并简化了代码,见:https://github.com/ProGamerGov/dream-creator/blob/master/utils/tile_utils.py

只有两种情况需要特殊蒙版;另外,rebuild_tensor 中原本存在一些我必须修复的错误。重叠百分比应小于或等于 50%。