MENU

ToTensor 到底做了什么?

November 29, 2018 • Read: 3354 • pytorch阅读设置

背景

最近在做一些图片读取上的研究,有单独保存好的三通道图片,现在需要将其整合起来,一开始选择是先创建数组,比如 np.zeros(shape=(512,512,4)),然后每个通道使用 PIL.Image.open() 读取,最后替换创建的原始数组。另外一种方式使用opencv读取,然后使用 np.satck() 整合。

以下是验证结果:

1. PIL.Image.open 和 cv2.imread()读取到的图片数据一致吗?

前提:只读取单通道的图片,所以 opencv 使用 cv2.IMREAD_GRAYSCALE 读取格式,其实就是0

代码及结果:


import numpy as np
import cv2
from PIL import Image

filename = "100_red.png"
flags = cv2.IMREAD_GRAYSCALE

pil_image = np.array(Image.open(filename))  #将 pil 格式的图片转换成 numpy
cv_image = cv2.imread(filename,flags)

print(pil_image == cv_image)

# 结果如下

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

实验证明,读取单通道下,读取结果确实一致,且数据类型都是 np.uint8

那么三通道的图片呢?

因为 opencv 读取的数据通道数在前面,需要做一下转换,验证结果也是一致的,附上结果:

array([[[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        ...,
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],
       ...,

       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        ...,
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]],

       
       [[ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        ...,
        [ True,  True,  True],
        [ True,  True,  True],
        [ True,  True,  True]]])

2. 如果在读取的过程中直接将图片除以 255 报错?

pytorch 如果要利用默认的转变方式,需要将 numpy 格式的数据转换成 PIL ,但是只支持 np.uint8 不支持 np.float 类型。因此,在传入 torchvision.transforms 中的图片时,不能提前做这一步。

那么问题来了,如果不除以 255,图片数据就没办法转变到 [0,1] 之间,而一般使用预训练模型时都需要进行 normalize 操作,究竟是怎么处理的?肯定不能在原值上进行,这个问题也是纠结了我好久,直到我发现了一个问题,normalize 都是在 T.ToTensor() 这一步之后进行的,答案应该就在 ToTensor

3. torchvision.ToTensor() 以及 Normalize()

源码是这么写的:

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

大概的意思就是将 PIL 或者 numpy.ndarray 转换成 tensor 形式,那我们是不是可以直接使用opencv中的数据变换方式,然后使用 torch.tensor(image)
torch.from_numpy(image) 传入网络呢?

可以是可以这么做,只不过这个时候需要单独增加维度,比较麻烦,所以还是建议直接使用 ToTensor() 来做转换。

看看到底做了哪些操作

源码在 F.to_tensor(pic)

说明:源码中不是一步到位操作的,只给出了大概操作类型,详情移步源码。

4. 总结

经过一些列的验证,总算排除了数据方面的不同处理可能带给模型的影响,这下就可以放心去魔改模型了,过段时间想整理一下如何魔改原有模型,以及如何初始化和注意事项。

Last Modified: January 7, 2019
Archives Tip
QR Code for this page
Tipping QR Code