Data Processing
During model training, we may run into overfitting. One remedy is data augmentation: transforming the training images in specific ways, such as cropping, flipping, or adjusting brightness, increases the diversity of the samples and thereby improves the model's ability to generalize.
1. Introduction to tensorlayerx.vision.transforms
The TensorLayerX framework provides dozens of built-in image processing methods in tensorlayerx.vision.transforms. You can list them with the following code:
import tensorlayerx
print('Image data processing methods: ', tensorlayerx.vision.transforms.__all__)
Image data processing methods: ['Crop', 'CentralCrop', 'HsvToRgb', 'AdjustBrightness', 'AdjustContrast', 'AdjustHue', 'AdjustSaturation', 'FlipHorizontal', 'FlipVertical', 'RgbToGray', 'PadToBoundingbox', 'Pad', 'Normalize', 'StandardizePerImage', 'RandomBrightness', 'RandomContrast', 'RandomHue', 'RandomSaturation', 'RandomCrop', 'Resize', 'RgbToHsv', 'Transpose', 'Rotation', 'RandomRotation', 'RandomShift', 'RandomShear', 'RandomZoom', 'RandomFlipVertical', 'RandomFlipHorizontal', 'HWC2CHW', 'CHW2HWC', 'ToTensor', 'Compose', 'RandomResizedCrop', 'RandomAffine', 'ColorJitter', 'Rotation']
These include common operations such as random cropping, rotation, and brightness and contrast adjustment. Each operation is described in the API documentation.
The built-in data preprocessing methods in TensorLayerX can be used individually or combined. Usage is as follows:
Use individually
from tensorlayerx.vision.transforms import Resize
transform = Resize(size=(100, 100), interpolation='bilinear')
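To make interpolation='bilinear' concrete, here is a minimal NumPy sketch of bilinear resizing for an HWC array. It is not TensorLayerX's implementation: the half-pixel-center coordinate mapping (similar to what OpenCV's INTER_LINEAR does), the border clamping, and the helper name bilinear_resize are assumptions made for illustration.

```python
import numpy as np


def bilinear_resize(image, size):
    """Resize an H x W (x C) array to size=(out_h, out_w) by bilinear interpolation.

    Illustrative sketch only: output pixel centers are mapped back to the input
    with half-pixel offsets, and coordinates are clamped at the borders.
    """
    img = np.asarray(image, dtype=np.float64)
    squeeze = img.ndim == 2
    if squeeze:
        img = img[:, :, None]  # treat a grayscale image as one channel
    in_h, in_w = img.shape[:2]
    out_h, out_w = size
    # Fractional source coordinates of every output row / column.
    ys = np.clip((np.arange(out_h) + 0.5) * in_h / out_h - 0.5, 0, in_h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) * in_w / out_w - 0.5, 0, in_w - 1)
    # Integer neighbours and interpolation weights.
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None, None]
    wx = (xs - x0)[None, :, None]
    # Interpolate along x on the two neighbouring rows, then along y.
    top = img[y0][:, x0] * (1 - wx) + img[y0][:, x1] * wx
    bottom = img[y1][:, x0] * (1 - wx) + img[y1][:, x1] * wx
    out = top * (1 - wy) + bottom * wy
    return out[:, :, 0] if squeeze else out


# Usage: upscale a 28x28 RGB array to 100x100, matching size=(100, 100) above.
small = np.random.default_rng(0).uniform(0, 255, size=(28, 28, 3))
big = bilinear_resize(small, (100, 100))
print(big.shape)  # (100, 100, 3)
```

Each output pixel is a weighted average of its four nearest input pixels, so results always stay within the input's value range.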
Use multiple methods in combination
In this case, first define each data processing method, then combine them with Compose, which applies them in the order listed.
from tensorlayerx.vision.transforms import (
    Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop
)
transforms = Compose(
    [
        RandomCrop(size=[24, 24]),
        RandomFlipHorizontal(),
        RandomBrightness(brightness_factor=(0.5, 1.5)),
        RandomContrast(contrast_factor=(0.5, 1.5)),
        StandardizePerImage()
    ]
)
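Compose feeds each transform the output of the previous one. The NumPy sketch below mimics that chaining with a hypothetical SimpleCompose class and two hypothetical helpers, random_crop and standardize_per_image. The standardization formula (x - mean) / max(std, 1/sqrt(N)) is an assumption modeled on TensorFlow's per-image standardization; consult the StandardizePerImage API documentation for TensorLayerX's exact behavior.

```python
import numpy as np


class SimpleCompose:
    """Hypothetical stand-in for Compose: call each transform in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, image):
        for t in self.transforms:
            image = t(image)  # the output of one step feeds the next
        return image


def random_crop(size, rng):
    """Return a transform that cuts a random (h, w) window from an HWC image."""
    out_h, out_w = size

    def _crop(image):
        h, w = image.shape[:2]
        top = rng.integers(0, h - out_h + 1)  # upper bound is exclusive
        left = rng.integers(0, w - out_w + 1)
        return image[top:top + out_h, left:left + out_w]

    return _crop


def standardize_per_image(image):
    """Zero mean, unit variance per image; std is floored to avoid division by 0."""
    image = image.astype(np.float64)
    adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std


rng = np.random.default_rng(42)
pipeline = SimpleCompose([
    random_crop((24, 24), rng),
    standardize_per_image,
])
image = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
out = pipeline(image)
print(out.shape)  # (24, 24, 3)
```

Order matters: cropping before standardizing means the mean and standard deviation are computed on the 24x24 crop, not on the full image.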
2. Apply data preprocessing operations in a dataset
Once defined, data processing methods can be applied directly in a Dataset. The following shows how to apply data preprocessing in a custom dataset.
For a custom dataset, pass the data processing method into the dataset's __init__ function and store it as an attribute of the class, then apply it to each image in __getitem__, as shown in the following code:
import tensorlayerx as tlx
from tensorlayerx.dataflow import Dataset
from tensorlayerx.vision.transforms import Compose, Normalize
# TensorLayerX automatically downloads and loads the MNIST dataset
print('download training data and load training data')
X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
X_train = X_train * 255  # rescale pixel values from [0, 1] to [0, 255]
print('load finished')
class MNISTDataset(Dataset):
    """
    Step 1: Inherit the tensorlayerx.dataflow.Dataset class
    """
    def __init__(self, data=X_train, label=y_train, transform=None):
        """
        Step 2: Implement the __init__ function to store the samples, the labels and the transform
        """
        self.data = data
        self.label = label
        self.transform = transform
    def __getitem__(self, index):
        """
        Step 3: Implement the __getitem__ function to define how to get data at the specified index and return a single data (sample data, corresponding label)
        """
        data = self.data[index].astype('float32')
        if self.transform is not None:
            data = self.transform(data)  # apply the preprocessing to this sample
        label = self.label[index].astype('int64')
        return data, label
    def __len__(self):
        """
        Step 4: Implement the __len__ function to return the total number of samples in the dataset
        """
        return len(self.data)
# Normalize pixel values from [0, 255] to [-1, 1]
transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='HWC')])
train_dataset = MNISTDataset(data=X_train, label=y_train, transform=transform)
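With mean = std = 127.5, Normalize computes (x - 127.5) / 127.5, which maps 0 to -1 and 255 to 1. The NumPy sketch below reproduces that arithmetic per channel and pairs it with a hypothetical ToyDataset following the same __init__/__getitem__/__len__ pattern as MNISTDataset, plus a naive iterate_batches helper standing in for a data loader. None of these names are TensorLayerX APIs; random arrays replace MNIST so the sketch runs without downloads.

```python
import numpy as np


def normalize_hwc(image, mean, std):
    """Per-channel (x - mean) / std for an HWC array (channel axis last)."""
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (image - mean) / std  # broadcasts over H and W


class ToyDataset:
    """Same pattern as MNISTDataset: store the transform, apply it per sample."""

    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __getitem__(self, index):
        data = self.data[index].astype('float32')
        if self.transform is not None:
            data = self.transform(data)
        return data, self.label[index].astype('int64')

    def __len__(self):
        return len(self.data)


def iterate_batches(dataset, batch_size):
    """Naive loader: fetch samples by index and stack them into batches."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        xs, ys = zip(*[dataset[i] for i in range(start, stop)])
        yield np.stack(xs), np.stack(ys)


rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(10, 28, 28, 1)).astype('float32')
fake_labels = rng.integers(0, 10, size=(10,))
dataset = ToyDataset(fake_images, fake_labels,
                     transform=lambda x: normalize_hwc(x, mean=[127.5], std=[127.5]))
for xb, yb in iterate_batches(dataset, batch_size=4):
    print(xb.shape, yb.shape)
```

Because the transform runs inside __getitem__, a random transform would produce a different result each time the same index is fetched, which is what makes on-the-fly augmentation work.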
3. Introduction to several data preprocessing methods
The effects of TensorLayerX's built-in data processing methods are easy to compare visually. The following examples show the original and transformed images side by side for several methods.
First, download the example image:
# Download the example image into the images/ directory, where the code below reads it
wget -P images https://paddle-imagenet-models-name.bj.bcebos.com/data/demo_images/flower_demo.png
CentralCrop
Crops a region of the given size from the center of the input image, so the center point of the image stays unchanged.
import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import CentralCrop
transform = CentralCrop(size=(224, 224))
# cv2.imread returns an HWC uint8 array in BGR channel order
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)
plt.subplot(1, 2, 1)
plt.title('origin image')
plt.imshow(image[:, :, ::-1])  # reverse channels (BGR -> RGB) for display
plt.subplot(1, 2, 2)
plt.title('CentralCrop image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()
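Cropping around the center amounts to leaving equal margins on each side. A minimal NumPy sketch, assuming the crop size does not exceed the image and that an odd leftover pixel goes to the bottom/right margin (libraries differ on this rounding, so it is an assumption rather than TensorLayerX's exact rule); central_crop is a hypothetical helper name.

```python
import numpy as np


def central_crop(image, size):
    """Cut an (out_h, out_w) window centered in an HWC (or HW) image."""
    h, w = image.shape[:2]
    out_h, out_w = size
    if out_h > h or out_w > w:
        raise ValueError("crop size must not exceed the image size")
    top = (h - out_h) // 2   # floor: an odd leftover row goes to the bottom
    left = (w - out_w) // 2  # floor: an odd leftover column goes to the right
    return image[top:top + out_h, left:left + out_w]


grid = np.arange(25).reshape(5, 5)
print(central_crop(grid, (3, 3)))
# [[ 6  7  8]
#  [11 12 13]
#  [16 17 18]]
```

The center element of the 5x5 grid (12) is also the center of the 3x3 crop, which is the property the CentralCrop description refers to.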
RandomFlipHorizontal
Flips the image horizontally with a given probability (0.5 in the example below).
import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import RandomFlipHorizontal
# Flip with probability 0.5, so each run may or may not flip the image
transform = RandomFlipHorizontal(0.5)
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)
plt.subplot(1, 2, 1)
plt.title('origin image')
plt.imshow(image[:, :, ::-1])  # reverse channels (BGR -> RGB) for display
plt.subplot(1, 2, 2)
plt.title('RandomFlipHorizontal image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()
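A random horizontal flip reverses the width axis whenever a uniform random draw falls below the flip probability. A NumPy sketch of that logic, with random_flip_horizontal as a hypothetical name and an explicit random generator so results are reproducible:

```python
import numpy as np


def random_flip_horizontal(image, prob, rng):
    """Reverse the width axis of an HWC image with probability `prob`."""
    if rng.random() < prob:  # rng.random() is uniform on [0, 1)
        return image[:, ::-1]  # axis 1 is width in HWC layout
    return image


rng = np.random.default_rng(0)
image = np.arange(6).reshape(1, 3, 2)  # H=1, W=3, C=2
flips = sum(
    not np.array_equal(random_flip_horizontal(image, 0.5, rng), image)
    for _ in range(1000)
)
print(flips)  # expected to be near 500 (binomial, n=1000, p=0.5)
```

Only the column order changes; channel values stay attached to their pixels, so a flip never alters colors.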
ColorJitter
Randomly adjusts the brightness, contrast, saturation, and hue of the image.
import cv2
from matplotlib import pyplot as plt
from tensorlayerx.vision.transforms import ColorJitter
transform = ColorJitter(brightness=(1, 5), contrast=(1, 5), saturation=(1, 5), hue=(-0.2, 0.2))
image = cv2.imread('images/flower_demo.png')
image_after_transform = transform(image)
plt.subplot(1, 2, 1)
plt.title('origin image')
plt.imshow(image[:, :, ::-1])  # reverse channels (BGR -> RGB) for display
plt.subplot(1, 2, 2)
plt.title('ColorJitter image')
plt.imshow(image_after_transform[:, :, ::-1])
plt.show()
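One common way to implement brightness, contrast, and saturation jitter (the approach torchvision uses; assumed here, not taken from TensorLayerX) is to blend the image with a reference: black for brightness, the mean gray level for contrast, and the grayscale image for saturation, where factor 1 leaves the image unchanged. The sketch below draws each factor uniformly from its range, applies the three operations in random order, and clips to [0, 255]. Hue jitter needs an RGB to HSV round trip and is omitted. The sketch assumes RGB channel order (convert a cv2 image with image[:, :, ::-1] first), and all function names are hypothetical.

```python
import numpy as np


def _blend(img, other, factor):
    # factor=1 returns img, factor=0 returns other; clip to the 8-bit range
    return np.clip(factor * img + (1.0 - factor) * other, 0, 255)


def _grayscale(img):
    # ITU-R 601 luma weights; assumes channel order R, G, B
    gray = img[..., 0] * 0.299 + img[..., 1] * 0.587 + img[..., 2] * 0.114
    return np.repeat(gray[..., None], 3, axis=-1)


def adjust_brightness(img, factor):
    return _blend(img, np.zeros_like(img), factor)


def adjust_contrast(img, factor):
    mean_gray = _grayscale(img).mean()
    return _blend(img, np.full_like(img, mean_gray), factor)


def adjust_saturation(img, factor):
    return _blend(img, _grayscale(img), factor)


def color_jitter(img, brightness, contrast, saturation, rng):
    """Apply the three adjustments in random order with random factors."""
    img = img.astype(np.float64)
    ops = [
        (adjust_brightness, brightness),
        (adjust_contrast, contrast),
        (adjust_saturation, saturation),
    ]
    for idx in rng.permutation(len(ops)):
        fn, (lo, hi) = ops[idx]
        img = fn(img, rng.uniform(lo, hi))
    return img


rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(8, 8, 3)).astype(np.float64)
jittered = color_jitter(rgb, brightness=(1, 5), contrast=(1, 5), saturation=(1, 5), rng=rng)
print(jittered.min() >= 0, jittered.max() <= 255)  # True True
```

With ranges as wide as (1, 5), many pixels saturate at 255 after clipping, which is why the ColorJitter example above looks strongly washed out.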
For more data processing methods, refer to the tensorlayerx.vision.transforms API documentation.