imgs = [img for img, _ in dataset]
stack_imgs(imgs[:4]).resize((800, 120))
display_batches(data_loader)
display_batches(data_loader_test)
When creating the dataset loaders, we must pass two functions: one returning the transformations applied to the image and one returning the transformations applied to tensors.
data_loader, data_loader_val = get_dataset("segmentation", batch_size=4)
display_batches(data_loader, n_batches=3)
data_loader_test = get_test_dataset("segmentation", batch_size=4)
display_batches(data_loader_test, n_batches=3)
To reduce overfitting, which is likely when the dataset is small, we apply random transformations to the training images (data augmentation), so that the model sees slightly different versions of each image in every epoch. One such transformation, RandomHorizontalFlip, is implemented in the torchvision library, and we will implement MyColorJitter, which is essentially a wrapper around the torchvision.transforms.ColorJitter class. We cannot use that class directly, without a wrapper, because our transformations receive both the image and its target, and some transformations must modify the target as well as the image. For example, if we were to implement RandomCrop, we would also need to crop the segmentation masks and readjust the bounding boxes.
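To make the RandomCrop case concrete, here is a minimal sketch of a target-aware crop. MyRandomCrop is a hypothetical name, not a torchvision class, and the sketch assumes it runs after ToTensor (so the image is a C x H x W tensor) and that the target is a dict with "boxes" (N x 4, [x1, y1, x2, y2] in pixels), "labels" and "masks" (N x H x W). It crops the image and masks with the same offsets, shifts and clips the boxes, and drops instances that end up outside the crop.

```python
import random

import torch


class MyRandomCrop:
    """Hypothetical target-aware random crop (a sketch, not a torchvision class).

    Assumes `image` is a C x H x W tensor (i.e. applied after ToTensor) and
    `target` is a dict with "boxes" (N x 4, [x1, y1, x2, y2] in pixels),
    "labels" (N,) and "masks" (N x H x W).
    """

    def __init__(self, size):
        self.crop_h, self.crop_w = size

    def __call__(self, image, target):
        _, h, w = image.shape
        # pick a random top-left corner so that the crop fits inside the image
        top = random.randint(0, max(h - self.crop_h, 0))
        left = random.randint(0, max(w - self.crop_w, 0))
        image = image[:, top:top + self.crop_h, left:left + self.crop_w]
        new_h, new_w = image.shape[1:]

        target = dict(target)  # shallow copy: do not mutate the caller's dict
        # masks are pixel-aligned with the image, so crop them identically
        target["masks"] = target["masks"][:, top:top + self.crop_h, left:left + self.crop_w]

        # move boxes into the crop's coordinate frame and clip them to its borders
        boxes = target["boxes"].clone()
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] - left).clamp(0, new_w)
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] - top).clamp(0, new_h)

        # drop instances whose box no longer overlaps the crop (zero width or height)
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        target["boxes"] = boxes[keep]
        target["labels"] = target["labels"][keep]
        target["masks"] = target["masks"][keep]
        # other per-instance fields (e.g. "area") would need the same treatment
        return image, target
```

Such a transform could be appended to the training branch of a transform factory like get_my_tensor_transforms, e.g. transforms.append(MyRandomCrop((256, 256))), where the crop size is only illustrative. The two randint calls could also be replaced by torchvision.transforms.RandomCrop.get_params(image, output_size), which returns the (top, left, height, width) of a random crop.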
The following example shows how to define a function to pass as the get_tensor_transforms argument of get_dataset when constructing a data_loader object:
import torchvision

# ToTensor, RandomHorizontalFlip and Compose are assumed to be the (image, target)
# transforms already in scope in this notebook, not their torchvision.transforms namesakes.


class MyColorJitter:
    """Applies torchvision's ColorJitter to the image only; the target is unchanged."""

    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5):
        self.torch_color_jitter = torchvision.transforms.ColorJitter(
            brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
        )

    def __call__(self, image, target):
        # color changes do not move pixels, so masks and boxes need no update
        image = self.torch_color_jitter(image)
        return image, target


def get_my_tensor_transforms(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(RandomHorizontalFlip(0.5))
        transforms.append(MyColorJitter())
        # TODO: add additional transforms: e.g. random crop
    return Compose(transforms)
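Every transform in get_my_tensor_transforms follows the same contract: it is called as t(image, target) and returns a new (image, target) pair, and Compose chains them. The sketch below illustrates that contract, assuming Compose behaves like the helper in torchvision's detection reference scripts. PairCompose and ImageOnly are hypothetical names, not library classes; ImageOnly generalizes the MyColorJitter pattern to any image-only callable.

```python
class PairCompose:
    """Hypothetical stand-in for Compose: applies (image, target) transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ImageOnly:
    """Hypothetical generalization of MyColorJitter: lifts an image-only callable
    into the (image, target) interface and passes the target through unchanged."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, image, target):
        return self.fn(image), target


# toy example with numbers standing in for images: (3 * 2) + 1
pipeline = PairCompose([ImageOnly(lambda x: x * 2), ImageOnly(lambda x: x + 1)])
image, target = pipeline(3, {"labels": [1]})  # image == 7, target unchanged
```

Under these assumptions, ImageOnly(torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)) would behave like MyColorJitter(). A generic wrapper like this avoids writing one class per photometric transform, while geometric transforms such as flips and crops still need dedicated code because they must update the target.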
data_loader, data_loader_val = get_dataset("segmentation", batch_size=2, get_tensor_transforms=get_my_tensor_transforms)
display_batches(data_loader, n_batches=2)
More transformations, such as ColorJitter and RandomCrop, are listed in the torchvision documentation (https://pytorch.org/docs/stable/torchvision/transforms.html).