import random
from enum import Enum
from typing import Dict, Optional, Sequence, Tuple, Union, Any
import cv2
import numpy as np
from ...core.bbox_utils import denormalize_bbox, normalize_bbox
from ... import random_utils
from ...core.transforms_interface import (
BoxInternalType,
DualTransform,
ImageColorType,
KeypointInternalType,
ScaleFloatType,
DicomType,
to_tuple,
INTER_LINEAR,
INTER_NEAREST,
)
from ..functional import bbox_from_mask
from . import functional as F
from ..dicom import functional as Fdicom
__all__ = [
"ShiftScaleRotate",
"VerticalFlip",
"HorizontalFlip",
"SliceFlip",
"Flip",
"Transpose",
"PadIfNeeded",
]
[docs]
class ShiftScaleRotate(DualTransform):
"""Randomly apply affine transforms: translate, scale and rotate the input.
Args:
shift_limit ((float, float) or float): shift factor range for height, width, and depth. If shift_limit
is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and
upper bounds should lie in range [0, 1]. Default: (-0.0625, 0.0625).
scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
Default: (-0.1, 0.1).
rotate_limit ((int, int) or int): rotation range. If rotate_limit is a single int value, the
range will be (-rotate_limit, rotate_limit). Default: (-45, 45).
axes (str, list of str): Defines the axis of rotation. Must be one of `{'xy','yz','xz'}` or a list of them.
If a single str is passed, then all rotations will occur on that axis
If a list is passed, then one axis of rotation will be chosen at random for each call of the transformation.
Default: "xy"
interpolation: scipy interpolation method (e.g. dicaugment.INTER_NEAREST). default: dicaugment.INTER_LINEAR
border_mode (str): Scipy parameter to determine how the input image is extended during convolution to maintain image shape. Must be one of the following:
* `reflect` (d c b a | a b c d | d c b a): The input is extended by reflecting about the edge of the last pixel. This mode is also sometimes referred to as half-sample symmetric.
* `constant` (k k k k | a b c d | k k k k): The input is extended by filling all values beyond the edge with the same constant value, defined by the cval parameter.
* `nearest` (a a a a | a b c d | d d d d): The input is extended by replicating the last pixel.
* `mirror` (d c b | a b c d | c b a): The input is extended by reflecting about the center of the last pixel. This mode is also sometimes referred to as whole-sample symmetric.
* `wrap` (a b c d | a b c d | a b c d): The input is extended by wrapping around to the opposite edge.
See https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.median_filter.html
Default: `constant`.
value (int or float): padding value if border_mode is "constant".
mask_value (int or float): padding value if border_mode is "constant" applied for masks.
crop_to_border (bool): If True, then the image is padded or cropped to fit the entire rotation. If False, then original image shape is
maintained and some portions of the image may be cropped away. Note that any translations are applied after the image is reshaped.
Default: False
shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value
instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value,
the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in
the range [0, 1]. Default: None.
shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value
instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value,
the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie
in the range [0, 1]. Default: None.
shift_limit_z ((float, float) or float): shift factor range for depth. If it is set then this value
instead of shift_limit will be used for shifting depth. If shift_limit_z is a single float value,
the range will be (-shift_limit_z, shift_limit_z). Absolute values for lower and upper bounds should lie
in the range [0, 1]. Default: None.
rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
Default: "largest_box"
always_apply (bool): whether to always apply the transformation. Default: False
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, keypoints
Image types:
uint8, uint16, int16, float32
"""
def __init__(
self,
shift_limit: float = 0.0625,
scale_limit: float = 0.1,
rotate_limit: Union[int, float] = 45,
axes: str = "xy",
interpolation: int = INTER_LINEAR,
border_mode: str = "constant",
crop_to_border: bool = False,
value: Union[int, float] = 0,
mask_value: Union[int, float] = 0,
shift_limit_x: Optional[Union[Tuple[float, float], float]] = None,
shift_limit_y: Optional[Union[Tuple[float, float], float]] = None,
shift_limit_z: Optional[Union[Tuple[float, float], float]] = None,
rotate_method: str = "largest_box",
always_apply=False,
p=0.5,
):
super(ShiftScaleRotate, self).__init__(always_apply, p)
self.shift_limit_x = to_tuple(
shift_limit_x if shift_limit_x is not None else shift_limit
)
self.shift_limit_y = to_tuple(
shift_limit_y if shift_limit_y is not None else shift_limit
)
self.shift_limit_z = to_tuple(
shift_limit_z if shift_limit_z is not None else shift_limit
)
self.scale_limit = to_tuple(scale_limit, bias=1.0)
self.rotate_limit = to_tuple(rotate_limit)
self.interpolation = interpolation
self.crop_to_border = crop_to_border
self.border_mode = border_mode
self.value = value
self.axes = (axes,) if isinstance(axes, str) else axes
self.mask_value = mask_value
self.rotate_method = rotate_method
if self.rotate_method not in ["largest_box", "ellipse"]:
raise ValueError(f"Rotation method {self.rotate_method} is not valid.")
if len(set(self.axes).difference({"xy", "yz", "xz"})) != 0:
raise ValueError(
"Parameter axes contains one or more elements that are not allowed. Got {}".format(
set(self.axes).difference({"xy", "yz", "xz"})
)
)
[docs]
def apply(
self,
img: np.ndarray,
angle: float = 0,
axes: str = "xy",
scale: float = 0,
dx: float = 0,
dy: float = 0,
dz: float = 0,
interpolation: int = INTER_LINEAR,
**params,
) -> np.ndarray:
"""Applies the transformation to the image"""
return F.shift_scale_rotate(
img,
angle,
scale,
dx,
dy,
dz,
axes,
self.crop_to_border,
interpolation,
self.border_mode,
self.value,
)
[docs]
def apply_to_mask(
self,
img: np.ndarray,
angle: float = 0,
axes: str = "xy",
scale: float = 0,
dx: float = 0,
dy: float = 0,
dz: float = 0,
**params,
) -> np.ndarray:
"""Applies the transformation to a mask and forces INTER_NEAREST interpolation"""
return F.shift_scale_rotate(
img,
angle,
scale,
dx,
dy,
dz,
axes,
self.crop_to_border,
INTER_NEAREST,
self.border_mode,
self.mask_value,
)
[docs]
def apply_to_keypoint(
self,
keypoint: KeypointInternalType,
angle: float = 0,
axes: str = "xy",
scale: float = 0,
dx: float = 0,
dy: float = 0,
dz: float = 0,
rows: int = 0,
cols: int = 0,
slices: int = 0,
**params,
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_shift_scale_rotate(
keypoint,
angle,
scale,
dx,
dy,
dz,
axes,
self.crop_to_border,
rows,
cols,
slices,
)
[docs]
def apply_to_dicom(self, dicom: DicomType, scale: float = 1, **params) -> DicomType:
"""Applies the augmentation to a dicom type"""
return Fdicom.dicom_scale(dicom, scale, scale)
[docs]
def get_params(self) -> Dict[str, Any]:
"""Returns parameters needed for the `apply` methods"""
return {
"angle": random.uniform(self.rotate_limit[0], self.rotate_limit[1]),
"scale": random.uniform(self.scale_limit[0], self.scale_limit[1]),
"dx": random.uniform(self.shift_limit_x[0], self.shift_limit_x[1]),
"dy": random.uniform(self.shift_limit_y[0], self.shift_limit_y[1]),
"dz": random.uniform(self.shift_limit_z[0], self.shift_limit_z[1]),
"axes": random.choice(self.axes),
}
[docs]
def apply_to_bbox(
self,
bbox: BoxInternalType,
angle: float = 0,
axes: str = "xy",
scale: float = 0,
dx: float = 0,
dy: float = 0,
dz: float = 0,
**params,
) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_shift_scale_rotate(
bbox,
angle,
scale,
dx,
dy,
dz,
axes,
self.crop_to_border,
self.rotate_method,
**params,
)
[docs]
class PadIfNeeded(DualTransform):
"""Pad side of the image / max if side is less than desired number.
Args:
min_height (int): minimal result image height.
min_width (int): minimal result image width.
pad_height_divisor (int): if not None, ensures image height is dividable by value of this argument.
pad_width_divisor (int): if not None, ensures image width is dividable by value of this argument.
position (Union[str, PositionType]): Position of the image. should be PositionType.CENTER or
PositionType.TOP_LEFT or PositionType.TOP_RIGHT or PositionType.BOTTOM_LEFT or PositionType.BOTTOM_RIGHT.
or PositionType.RANDOM. Default: PositionType.CENTER.
border_mode (OpenCV flag): OpenCV border mode.
value (int, float, list of int, list of float): padding value if border_mode is "constant".
mask_value (int, float, list of int, list of float): padding value for mask if border_mode is "constant".
always_apply (bool): whether to always apply the transformation. Default: False
p (float): probability of applying the transform. Default: 1.0.
Targets:
image, mask, bbox, keypoints
Image types:
uint8, float32
"""
[docs]
class PositionType(Enum):
CENTER = "center"
FRONT_TOP_LEFT = "front_top_left"
FRONT_TOP_RIGHT = "front_top_right"
FRONT_BOTTOM_LEFT = "front_bottom_left"
FRONT_BOTTOM_RIGHT = "front_bottom_right"
BACK_TOP_LEFT = "back_top_left"
BACK_TOP_RIGHT = "back_top_right"
BACK_BOTTOM_LEFT = "back_bottom_left"
BACK_BOTTOM_RIGHT = "back_bottom_right"
RANDOM = "random"
def __init__(
self,
min_height: Optional[int] = 1024,
min_width: Optional[int] = 1024,
min_depth: Optional[int] = 1024,
pad_height_divisor: Optional[int] = None,
pad_width_divisor: Optional[int] = None,
pad_depth_divisor: Optional[int] = None,
position: Union[PositionType, str] = PositionType.CENTER,
border_mode: int = "constant",
value: Union[float, int] = 0,
mask_value: Union[float, int] = 0,
always_apply: bool = False,
p: float = 1.0,
):
if (min_height is None) == (pad_height_divisor is None):
raise ValueError(
"Only one of 'min_height' and 'pad_height_divisor' parameters must be set"
)
if (min_width is None) == (pad_width_divisor is None):
raise ValueError(
"Only one of 'min_width' and 'pad_width_divisor' parameters must be set"
)
if (min_depth is None) == (pad_depth_divisor is None):
raise ValueError(
"Only one of 'min_depth' and 'pad_depth_divisor' parameters must be set"
)
super(PadIfNeeded, self).__init__(always_apply, p)
self.min_height = min_height
self.min_width = min_width
self.min_depth = min_depth
self.pad_height_divisor = pad_height_divisor
self.pad_width_divisor = pad_width_divisor
self.pad_depth_divisor = pad_depth_divisor
self.position = PadIfNeeded.PositionType(position)
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value
[docs]
def update_params(self, params, **kwargs):
params = super(PadIfNeeded, self).update_params(params, **kwargs)
rows = params["rows"]
cols = params["cols"]
slices = params["slices"]
if self.min_height is not None:
if rows < self.min_height:
h_pad_top = int((self.min_height - rows) / 2.0)
h_pad_bottom = self.min_height - rows - h_pad_top
else:
h_pad_top = 0
h_pad_bottom = 0
else:
pad_remained = rows % self.pad_height_divisor
pad_rows = self.pad_height_divisor - pad_remained if pad_remained > 0 else 0
h_pad_top = pad_rows // 2
h_pad_bottom = pad_rows - h_pad_top
if self.min_width is not None:
if cols < self.min_width:
w_pad_left = int((self.min_width - cols) / 2.0)
w_pad_right = self.min_width - cols - w_pad_left
else:
w_pad_left = 0
w_pad_right = 0
else:
pad_remainder = cols % self.pad_width_divisor
pad_cols = (
self.pad_width_divisor - pad_remainder if pad_remainder > 0 else 0
)
w_pad_left = pad_cols // 2
w_pad_right = pad_cols - w_pad_left
if self.min_depth is not None:
if slices < self.min_depth:
d_pad_front = int((self.min_depth - slices) / 2.0)
d_pad_back = self.min_depth - slices - d_pad_front
else:
d_pad_front = 0
d_pad_back = 0
else:
pad_remained = slices % self.pad_depth_divisor
pad_slices = (
self.pad_depth_divisor - pad_remained if pad_remained > 0 else 0
)
d_pad_front = pad_slices // 2
d_pad_back = pad_slices - d_pad_front
(
h_pad_top,
h_pad_bottom,
w_pad_left,
w_pad_right,
d_pad_front,
d_pad_back,
) = self.__update_position_params(
h_top=h_pad_top,
h_bottom=h_pad_bottom,
w_left=w_pad_left,
w_right=w_pad_right,
d_front=d_pad_front,
d_back=d_pad_back,
)
params.update(
{
"pad_top": h_pad_top,
"pad_bottom": h_pad_bottom,
"pad_left": w_pad_left,
"pad_right": w_pad_right,
"pad_front": d_pad_front,
"pad_back": d_pad_back,
}
)
return params
[docs]
def apply(
self,
img: np.ndarray,
pad_top: int = 0,
pad_bottom: int = 0,
pad_left: int = 0,
pad_right: int = 0,
pad_front: int = 0,
pad_back: int = 0,
**params,
) -> np.ndarray:
"""Applies the transformation to the image"""
return F.pad_with_params(
img,
pad_top,
pad_bottom,
pad_left,
pad_right,
pad_front,
pad_back,
border_mode=self.border_mode,
value=self.value,
)
[docs]
def apply_to_mask(
self,
img: np.ndarray,
pad_top: int = 0,
pad_bottom: int = 0,
pad_left: int = 0,
pad_right: int = 0,
pad_front: int = 0,
pad_back: int = 0,
**params,
) -> np.ndarray:
"""Applies the transformation to a mask"""
return F.pad_with_params(
img,
pad_top,
pad_bottom,
pad_left,
pad_right,
pad_front,
pad_back,
border_mode=self.border_mode,
value=self.mask_value,
)
[docs]
def apply_to_bbox(
self,
bbox: BoxInternalType,
pad_top: int = 0,
pad_bottom: int = 0,
pad_left: int = 0,
pad_right: int = 0,
pad_front: int = 0,
pad_back: int = 0,
rows: int = 0,
cols: int = 0,
slices: int = 0,
**params,
) -> BoxInternalType:
"""Applies the transformation to a bbox"""
x_min, y_min, z_min, x_max, y_max, z_max = denormalize_bbox(
bbox, rows, cols, slices
)[:6]
bbox = (
x_min + pad_left,
y_min + pad_top,
z_min + pad_front,
x_max + pad_left,
y_max + pad_top,
z_max + pad_front,
)
return normalize_bbox(
bbox,
rows + pad_top + pad_bottom,
cols + pad_left + pad_right,
slices + pad_front + pad_back,
)
[docs]
def apply_to_keypoint(
self,
keypoint: KeypointInternalType,
pad_top: int = 0,
pad_bottom: int = 0,
pad_left: int = 0,
pad_right: int = 0,
pad_front: int = 0,
pad_back: int = 0,
**params,
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
x, y, z, angle, scale = keypoint[:5]
return x + pad_left, y + pad_top, z + pad_front, angle, scale
def __update_position_params(
self,
h_top: int,
h_bottom: int,
w_left: int,
w_right: int,
d_front: int,
d_back: int,
) -> Tuple[int, int, int, int]:
if self.position == PadIfNeeded.PositionType.FRONT_TOP_LEFT:
h_bottom += h_top
w_right += w_left
d_back += d_front
h_top = 0
w_left = 0
d_front = 0
elif self.position == PadIfNeeded.PositionType.BACK_TOP_LEFT:
h_bottom += h_top
w_right += w_left
d_front += d_back
h_top = 0
w_left = 0
d_back = 0
elif self.position == PadIfNeeded.PositionType.FRONT_TOP_RIGHT:
h_bottom += h_top
w_left += w_right
d_back += d_front
h_top = 0
w_right = 0
d_front = 0
elif self.position == PadIfNeeded.PositionType.BACK_TOP_RIGHT:
h_bottom += h_top
w_left += w_right
d_front += d_back
h_top = 0
w_right = 0
d_back = 0
elif self.position == PadIfNeeded.PositionType.FRONT_BOTTOM_LEFT:
h_top += h_bottom
w_right += w_left
d_back += d_front
h_bottom = 0
w_left = 0
d_front = 0
elif self.position == PadIfNeeded.PositionType.BACK_BOTTOM_LEFT:
h_top += h_bottom
w_right += w_left
d_front += d_back
h_bottom = 0
w_left = 0
d_back = 0
elif self.position == PadIfNeeded.PositionType.FRONT_BOTTOM_RIGHT:
h_top += h_bottom
w_left += w_right
d_back += d_front
h_bottom = 0
w_right = 0
d_front = 0
elif self.position == PadIfNeeded.PositionType.BACK_BOTTOM_RIGHT:
h_top += h_bottom
w_left += w_right
d_front += d_back
h_bottom = 0
w_right = 0
d_back = 0
elif self.position == PadIfNeeded.PositionType.RANDOM:
h_pad = h_top + h_bottom
w_pad = w_left + w_right
d_pad = d_front + d_back
h_top = random.randint(0, h_pad)
h_bottom = h_pad - h_top
w_left = random.randint(0, w_pad)
w_right = w_pad - w_left
d_front = random.randint(0, d_pad)
d_back = d_pad - d_front
return h_top, h_bottom, w_left, w_right, d_front, d_back
[docs]
class VerticalFlip(DualTransform):
"""Flip the input vertically around the x-axis.
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""
[docs]
def apply(self, img: np.ndarray, **params) -> np.ndarray:
"""Applies the transformation to the image"""
return F.vflip(img)
[docs]
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_vflip(bbox, **params)
[docs]
def apply_to_keypoint(
self, keypoint: KeypointInternalType, **params
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_vflip(keypoint, **params)
[docs]
class HorizontalFlip(DualTransform):
"""Flip the input horizontally around the y-axis.
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""
[docs]
def apply(self, img: np.ndarray, **params) -> np.ndarray:
"""Applies the transformation to the image"""
return F.hflip(img)
[docs]
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_hflip(bbox, **params)
[docs]
def apply_to_keypoint(
self, keypoint: KeypointInternalType, **params
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_hflip(keypoint, **params)
[docs]
class SliceFlip(DualTransform):
"""Flip the input along the slice dimension.
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""
[docs]
def apply(self, img: np.ndarray, **params) -> np.ndarray:
"""Applies the transformation to the image"""
return F.zflip(img)
[docs]
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_zflip(bbox, **params)
[docs]
def apply_to_keypoint(
self, keypoint: KeypointInternalType, **params
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_zflip(keypoint, **params)
[docs]
class Flip(DualTransform):
"""Flip the input either horizontally, vertically, along the z-axis, or all.
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""
[docs]
def apply(self, img: np.ndarray, d: int = 0, **params) -> np.ndarray:
"""
Applies the transformation to the image
Args:
img (np.ndarray): an image
d (int): code that specifies how to flip the input. 0 for vertical flipping, 1 for horizontal flipping,
2 for z-axis flip, or -1 for vertical, horizontal, and z-axis flipping
"""
return F.random_flip(img, d)
[docs]
def get_params(self):
"""Returns parameters needed for the `apply` methods"""
# Random int in the range [-1, 1]
return {"d": random.randint(-1, 2)}
[docs]
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_flip(bbox, **params)
[docs]
def apply_to_keypoint(
self, keypoint: KeypointInternalType, **params
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_flip(keypoint, **params)
[docs]
class Transpose(DualTransform):
"""Transpose the input by swapping rows and columns. Slice dimension remains unaffected
Args:
p (float): probability of applying the transform. Default: 0.5.
Targets:
image, mask, bboxes, keypoints
Image types:
uint8, float32
"""
[docs]
def apply(self, img: np.ndarray, **params) -> np.ndarray:
"""Applies the transformation to the image"""
return F.transpose(img)
[docs]
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
"""Applies the transformation to a bbox"""
return F.bbox_transpose(bbox, 0, **params)
[docs]
def apply_to_keypoint(
self, keypoint: KeypointInternalType, **params
) -> KeypointInternalType:
"""Applies the transformation to a keypoint"""
return F.keypoint_transpose(keypoint)
[docs]
def apply_to_dicom(self, dicom: DicomType, **params) -> DicomType:
"""Applies the augmentation to a dicom type"""
return Fdicom.transpose_dicom(dicom)