如何以支持自动渐变的方式围绕其中心旋转PyTorch图像张量?

问题描述

我想绕其中心随机旋转图像张量(B,C,H,W)(我认为是二维旋转)。我想避免使用NumPy和Kornia,因此我基本上只需要从Torch模块中导入即可。我也没有使用torchvision.transforms,因为我需要它与autograd兼容。基本上,我正在尝试为DeepDream等可视化技术创建torchvision.transforms.RandomRotation()的自动分级兼容版本(因此,我需要尽可能避免工件)。

import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name,image_size):
    Loader = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor,output_name):
    output_tensor.clamp_(0,1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around it's center
def rotate_tensor(tensor,radians):
    ...
    return rotated_tensor

# Get a random angle within a specified range 
r_degrees = 5
angle_range = list(range(-r_degrees,r_degrees))
n = random.randint(angle_range[0],angle_range[len(angle_range)-1])

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file',(512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor,ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor,'rotated_image.jpg')

我想要完成的一些示例输出:

First example of rotated image

Second example of rotated image

解决方法

因此,网格生成器和采样器是空间变压器的子模块(JADERBERG,Max等)。这些子模块是不可训练的,它们使您可以应用可学习的以及不可学习的空间转换。 在这里,我将使用这两个子模块,并使用它们使用PyTorch的函数thetaF.affine_grid(分别是生成器和采样器的实现)来F.affine_sample旋转图像:>

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def get_rot_mat(theta):
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta),-torch.sin(theta),0],[torch.sin(theta),torch.cos(theta),0]])


def rot_img(x,theta,dtype):
    rot_mat = get_rot_mat(theta)[None,...].type(dtype).repeat(x.shape[0],1,1)
    grid = F.affine_grid(rot_mat,x.size()).type(dtype)
    x = F.grid_sample(x,grid)
    return x


#Test:
dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype,range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im,np.pi/2,dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,0)/255)

在上面的示例中,假设我们将图像im当作裙子上的跳舞猫: enter image description here

rotated_im将是一条90度逆时针旋转的裙子跳舞猫:

enter image description here

这就是我们用rot_img等式theta来调用np.pi/4的结果: enter image description here

最棒的是,它与输入是可区分的,并具有自动分级支持!哇!

,

有一个pytorch函数:

x = torch.tensor([[0,1],[2,3]])

x = torch.rot90(x,[0,1])
>> tensor([[1,3],2]])

以下是文档:https://pytorch.org/docs/stable/generated/torch.rot90.html

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...