MENU

pytorch 自定义 transforms

October 30, 2018 • Read: 5042 • pytorch阅读设置

背景

上一篇博客介绍了 pytorchtransforms 模块,有需要的移步 研究一下 pytorch 的 transforms 模块。现在,我想像 T.Normalise() 一样,对每个图片做一下对比度增强,而其他的转换方法保持不变,依旧采用随机处理。

实现方式是 采用 __call__ 机制,具体如下:

from PIL import ImageEnhance

class Contrast(object):
    def __init__(self,degree):
        self.degree = degree
    def __call__(self,img):
        return contrast(img,self.degree)
def contrast(img,degree):
    enh_contrast = ImageEnhance.Contrast(img)
    enh_contrast.enhance(degree)
    return img

上面定义了新的数据转换方法:Contrast,使用方法如下,以开源的代码为例 class ChaojieDataset(Dataset):


#2.define dataset
class ChaojieDataset(Dataset):
    def __init__(self,label_list,transforms=None,train=True,test=False):
        self.test = test 
        self.train = train 
        imgs = []
        if self.test:
            for index,row in label_list.iterrows():
                imgs.append((row["filename"]))
            self.imgs = imgs 
        else:
            for index,row in label_list.iterrows():
                imgs.append((row["filename"],row["label"]))
            self.imgs = imgs
        if transforms is None:
            if self.test or not train:
                self.transforms = T.Compose([
                    T.Resize((config.img_weight,config.img_height)),
                    T.ToTensor(),
                    T.Normalize(mean = [0.485,0.456,0.406],
                                std = [0.229,0.224,0.225])])
            else:
                self.transforms  = T.Compose([
                    T.Resize((config.img_weight,config.img_height)),
                    T.RandomRotation(30),
                    T.RandomHorizontalFlip(),
                    T.RandomVerticalFlip(),
                    T.RandomAffine(45),
                    Contrast(1.8), ## 在此添加 ##
                    T.ToTensor(),
                    T.Normalize(mean = [0.485,0.456,0.406],
                                std = [0.229,0.224,0.225])])
        else:
            self.transforms = transforms
    def __getitem__(self,index):
        if self.test:
            filename = self.imgs[index]
            img = Image.open(filename)
            img = self.transforms(img)
            return img,filename
        else:
            filename,label = self.imgs[index] 
            img = Image.open(filename)
            img = self.transforms(img)
            return img,label
    def __len__(self):
        return len(self.imgs)

这样,每一张图片除了进行归一化和转变成张量外,都做了对比度增强,当然也可以设置一个随机数,每次随机选择要增强的对比度。

Last Modified: January 7, 2019
Archives Tip
QR Code for this page
Tipping QR Code