from typing import Dict, List, Optional, Tuple, Union
import torch
import torchvision
from torch import Tensor, nn
from torchvision import ops
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
def _flip_coco_person_keypoints(kps, width):
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
flipped_data = kps[:, flip_inds]
flipped_data[..., 0] = width - flipped_data[..., 0]
# Maintain COCO convention that if visibility == 0, then x, y = 0
inds = flipped_data[..., 2] == 0
flipped_data[inds] = 0
return flipped_data
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
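# Usage sketch (illustrative, not part of the original file): a detection
# training pipeline typically chains the transforms defined below, e.g.
#
#   transforms = Compose([
#       RandomPhotometricDistort(p=0.5),
#       RandomZoomOut(fill=[123.0, 117.0, 104.0]),
#       RandomIoUCrop(),
#       RandomHorizontalFlip(p=0.5),
#       PILToTensor(),
#       ToDtype(torch.float, scale=True),
#   ])
#   image, target = transforms(pil_image, target)  # `pil_image` and `target` are placeholders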
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
_, _, width = F.get_dimensions(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target
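# Note (illustrative): for a box [x1, y1, x2, y2] in xyxy format on an image of
# width W, the flip above maps it to [W - x2, y1, W - x1, y2], so e.g. with
# W = 10 the box [2, 0, 5, 3] becomes [5, 0, 8, 3] and x1 < x2 is preserved.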
class PILToTensor(nn.Module):
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.pil_to_tensor(image)
return image, target
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target
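# Usage sketch (illustrative): with scale=True the conversion goes through
# convert_image_dtype, so an integer image is rescaled to the float range, e.g.
#
#   to_float = ToDtype(torch.float32, scale=True)
#   img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
#   img, _ = to_float(img)  # dtype is torch.float32, values lie in [0, 1]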
class RandomIoUCrop(nn.Module):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
if sampler_options is None:
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
self.options = sampler_options
self.trials = trials
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None:
raise ValueError("The targets can't be None for this transform.")
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(
f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions."
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
_, orig_h, orig_w = F.get_dimensions(image)
while True:
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1, )))
min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value of 1.0 or larger encodes the leave as-is option
return image, target
for _ in range(self.trials):
# check the aspect ratio limitations
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
continue
# check for 0 area crops
r = torch.rand(2)
left = int((orig_w - new_w) * r[0])
top = int((orig_h - new_h) * r[1])
right = left + new_w
bottom = top + new_h
if left == right or top == bottom:
continue
# check for any valid boxes with centers within the crop area
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
if not is_within_crop_area.any():
continue
                # check that at least one box satisfies the jaccard overlap constraint
boxes = target["boxes"][is_within_crop_area]
                ious = torchvision.ops.boxes.box_iou(
                    boxes,
                    torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device),
                )
if ious.max() < min_jaccard_overlap:
continue
# keep only valid boxes and perform cropping
target["boxes"] = boxes
target["labels"] = target["labels"][is_within_crop_area]
target["boxes"][:, 0::2] -= left
target["boxes"][:, 1::2] -= top
target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
image = F.crop(image, top, left, new_h, new_w)
return image, target
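# Usage sketch (illustrative): this transform requires a non-None target with
# "boxes" in xyxy format and "labels"; it resamples until the crop keeps at
# least one box centre and satisfies the sampled IoU threshold, e.g.
#
#   crop = RandomIoUCrop()
#   img = torch.randint(0, 256, (3, 300, 300), dtype=torch.uint8)
#   tgt = {"boxes": torch.tensor([[50.0, 50.0, 150.0, 150.0]]), "labels": torch.tensor([1])}
#   img, tgt = crop(img, tgt)  # surviving boxes are shifted and clamped to crop coordinates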
class RandomZoomOut(nn.Module):
def __init__(self,
fill: Optional[List[float]] = None,
side_range: Tuple[float, float] = (1.0, 4.0),
p: float = 0.5):
super().__init__()
if fill is None:
fill = [0.0, 0.0, 0.0]
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(
f"Invalid canvas side range provided {side_range}.")
self.p = p
@torch.jit.unused
def _get_fill_value(self, is_pil):
# type: (bool) -> int
# We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(
f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions."
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
if torch.rand(1) >= self.p:
return image, target
_, orig_h, orig_w = F.get_dimensions(image)
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
if torch.jit.is_scripting():
fill = 0
else:
fill = self._get_fill_value(F._is_pil_image(image))
image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor):
            # PyTorch's pad supports only an integer fill, so we overwrite the padded
            # regions with the requested colour afterwards
            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
            image[..., :top, :] = v
            image[..., :, :left] = v
            image[..., (top + orig_h):, :] = v
            image[..., :, (left + orig_w):] = v
if target is not None:
target["boxes"][:, 0::2] += left
target["boxes"][:, 1::2] += top
return image, target
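# Note (illustrative): with a sampled ratio r = 2.0 on a 100x100 image the
# canvas becomes 200x200, the image is pasted at a random (left, top) offset
# with `fill` everywhere else, and a box [10, 10, 20, 20] placed at offset
# (left=30, top=40) becomes [40, 50, 50, 60].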
class RandomPhotometricDistort(nn.Module):
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast)
self._hue = T.ColorJitter(hue=hue)
self._saturation = T.ColorJitter(saturation=saturation)
self.p = p
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(
f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions."
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
r = torch.rand(7)
if r[0] < self.p:
image = self._brightness(image)
contrast_before = r[1] < 0.5
if contrast_before:
if r[2] < self.p:
image = self._contrast(image)
if r[3] < self.p:
image = self._saturation(image)
if r[4] < self.p:
image = self._hue(image)
if not contrast_before:
if r[5] < self.p:
image = self._contrast(image)
if r[6] < self.p:
channels, _, _ = F.get_dimensions(image)
permutation = torch.randperm(channels)
is_pil = F._is_pil_image(image)
if is_pil:
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)
return image, target
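# Note (illustrative): each jitter above (brightness, contrast, saturation,
# hue, channel permutation) fires independently with probability p, and
# contrast is placed either before or after saturation/hue with equal chance;
# the target is returned unchanged.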
class ScaleJitter(nn.Module):
"""Randomly resizes the image and its bounding boxes within the specified scale range.
The class implements the Scale Jitter augmentation as described in the paper
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation"
<https://arxiv.org/abs/2012.07177>`_.
Args:
        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
        scale_range (tuple of floats): scaling factor interval, e.g. (a, b); the scale is randomly sampled from the
            range a <= scale <= b.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
"""
def __init__(
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias=True,
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError(
f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions."
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
_, orig_height, orig_width = F.get_dimensions(image)
        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width],
interpolation=self.interpolation,
antialias=self.antialias)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"],
[new_height, new_width],
interpolation=InterpolationMode.NEAREST,
antialias=self.antialias,
)
return image, target
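# Usage sketch (illustrative): for Large Scale Jitter style training the
# transform is typically configured with a square target size, e.g.
#
#   lsj = ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0))
#   img, tgt = lsj(img, tgt)  # boxes (and masks, if present) are rescaled with the image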
class FixedSizeCrop(nn.Module):
def __init__(self, size, fill=0, padding_mode="constant"):
super().__init__()
size = tuple(
T._setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size."
))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.padding_mode = padding_mode
def _pad(self, img, target, padding):
# Taken from the functional_tensor.py pad
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
padding = [pad_left, pad_top, pad_right, pad_bottom]
img = F.pad(img, padding, self.fill, self.padding_mode)
if target is not None:
target["boxes"][:, 0::2] += pad_left
target["boxes"][:, 1::2] += pad_top
if "masks" in target:
target["masks"] = F.pad(target["masks"], padding, 0,
"constant")
return img, target
def _crop(self, img, target, top, left, height, width):
img = F.crop(img, top, left, height, width)
if target is not None:
boxes = target["boxes"]
boxes[:, 0::2] -= left
boxes[:, 1::2] -= top
boxes[:, 0::2].clamp_(min=0, max=width)
boxes[:, 1::2].clamp_(min=0, max=height)
            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
target["boxes"] = boxes[is_valid]
target["labels"] = target["labels"][is_valid]
if "masks" in target:
target["masks"] = F.crop(target["masks"][is_valid], top, left,
height, width)
return img, target
def forward(self, img, target=None):
_, height, width = F.get_dimensions(img)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
if new_height != height or new_width != width:
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
img, target = self._crop(img, target, top, left, new_height,
new_width)
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
if pad_bottom != 0 or pad_right != 0:
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
return img, target
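# Note (illustrative): forward first crops when the input is larger than
# (crop_height, crop_width) and then pads on the right/bottom when it is
# smaller, so the output spatial size is always exactly crop_height x crop_width;
# it is commonly paired with ScaleJitter to produce fixed-size LSJ inputs.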
class RandomShortestSize(nn.Module):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
def forward(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
_, orig_height, orig_width = F.get_dimensions(image)
        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width],
interpolation=self.interpolation)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width],
interpolation=InterpolationMode.NEAREST)
return image, target
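# Note (illustrative): the ratio r rescales the shorter side towards the
# sampled min_size while capping the longer side at max_size, e.g. for a
# 480x640 image with min_size=800 and max_size=1333,
# r = min(800 / 480, 1333 / 640) ~= 1.667, giving a roughly 800x1066 output.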
def _copy_paste(
image: torch.Tensor,
target: Dict[str, Tensor],
paste_image: torch.Tensor,
paste_target: Dict[str, Tensor],
blending: bool = True,
resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
# Random paste targets selection:
num_masks = len(paste_target["masks"])
if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
# Let's just return (image, target)
return image, target
    # We have to appease TorchScript by explicitly specifying dtype as torch.long
random_selection = torch.randint(0,
num_masks, (num_masks, ),
device=paste_image.device)
random_selection = torch.unique(random_selection).to(torch.long)
paste_masks = paste_target["masks"][random_selection]
paste_boxes = paste_target["boxes"][random_selection]
paste_labels = paste_target["labels"][random_selection]
masks = target["masks"]
# We resize source and paste data if they have different sizes
    # This is something we introduced here, as the original algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image,
size1,
interpolation=resize_interpolation)
paste_masks = F.resize(paste_masks,
size1,
interpolation=F.InterpolationMode.NEAREST)
# resize bboxes:
ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]),
device=paste_boxes.device)
paste_boxes = paste_boxes.view(-1, 2,
2).mul(ratios).view(paste_boxes.shape)
paste_alpha_mask = paste_masks.sum(dim=0) > 0
if blending:
paste_alpha_mask = F.gaussian_blur(
paste_alpha_mask.unsqueeze(0),
kernel_size=(5, 5),
sigma=[
2.0,
],
)
# Copy-paste images:
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
# Copy-paste masks:
masks = masks * (~paste_alpha_mask)
non_all_zero_masks = masks.sum((-1, -2)) > 0
masks = masks[non_all_zero_masks]
# Do a shallow copy of the target dict
out_target = {k: v for k, v in target.items()}
out_target["masks"] = torch.cat([masks, paste_masks])
# Copy-paste boxes and labels
boxes = ops.masks_to_boxes(masks)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
labels = target["labels"][non_all_zero_masks]
out_target["labels"] = torch.cat([labels, paste_labels])
    # Update additional optional keys: area and iscrowd, if they exist
if "area" in target:
out_target["area"] = out_target["masks"].sum(
(-1, -2)).to(torch.float32)
if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can differ from the mask size (non_all_zero_masks)
        # For example, if previous transforms geometrically modify masks/boxes/labels but
        # do not update "iscrowd"
if len(target["iscrowd"]) == len(non_all_zero_masks):
iscrowd = target["iscrowd"][non_all_zero_masks]
paste_iscrowd = paste_target["iscrowd"][random_selection]
out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
# Check for degenerated boxes and remove them
boxes = out_target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
valid_targets = ~degenerate_boxes.any(dim=1)
out_target["boxes"] = boxes[valid_targets]
out_target["masks"] = out_target["masks"][valid_targets]
out_target["labels"] = out_target["labels"][valid_targets]
if "area" in out_target:
out_target["area"] = out_target["area"][valid_targets]
if "iscrowd" in out_target and len(
out_target["iscrowd"]) == len(valid_targets):
out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
return image, out_target
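# Note (illustrative): the selection above draws num_masks indices with
# replacement and deduplicates them, so on average roughly 1 - 1/e (~63%) of
# the paste instances are used per call; boxes that end up degenerate after
# pasting are filtered out together with their masks and labels.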
class SimpleCopyPaste(torch.nn.Module):
def __init__(self,
blending=True,
resize_interpolation=F.InterpolationMode.BILINEAR):
super().__init__()
self.resize_interpolation = resize_interpolation
self.blending = blending
def forward(
self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
torch._assert(
isinstance(images, (list, tuple))
and all([isinstance(v, torch.Tensor) for v in images]),
"images should be a list of tensors",
)
torch._assert(
isinstance(targets, (list, tuple)) and len(images) == len(targets),
"targets should be a list of the same size as images",
)
for target in targets:
            # Cannot check for instance type dict inside torch.jit.script
# torch._assert(isinstance(target, dict), "targets item should be a dict")
for k in ["masks", "boxes", "labels"]:
torch._assert(k in target,
f"Key {k} should be present in targets")
torch._assert(isinstance(target[k], torch.Tensor),
f"Value for the key {k} should be a tensor")
# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
# paste_images = [t2, t3, ..., tN, t1]
# FYI: in TF they mix data on the dataset level
images_rolled = images[-1:] + images[:-1]
targets_rolled = targets[-1:] + targets[:-1]
output_images: List[torch.Tensor] = []
output_targets: List[Dict[str, Tensor]] = []
for image, target, paste_image, paste_target in zip(
images, targets, images_rolled, targets_rolled):
output_image, output_data = _copy_paste(
image,
target,
paste_image,
paste_target,
blending=self.blending,
resize_interpolation=self.resize_interpolation,
)
output_images.append(output_image)
output_targets.append(output_data)
return output_images, output_targets
def __repr__(self) -> str:
s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
return s
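# Usage sketch (illustrative): SimpleCopyPaste mixes whole batches, so it is
# typically applied inside the DataLoader collate step rather than per sample;
# `detection_collate_fn` below is a hypothetical helper that turns a batch into
# (images, targets) lists.
#
#   copypaste = SimpleCopyPaste(blending=True, resize_interpolation=F.InterpolationMode.BILINEAR)
#
#   def copypaste_collate_fn(batch):
#       return copypaste(*detection_collate_fn(batch))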