Source code for dicaugment.augmentations.utils

from functools import wraps
from typing import Callable, Union

import cv2
import numpy as np
import os
import pydicom as pdm
from typing_extensions import Concatenate, ParamSpec

from dicaugment.core.keypoints_utils import angle_to_2pi_range
from dicaugment.core.transforms_interface import KeypointInternalType

# Names exported by ``from dicaugment.augmentations.utils import *``.
# NOTE(review): the original list contained "NPDTYPE_TO_OPENCV_DTYPE" twice;
# the second entry may have been meant to be "SCIPY_MODE_TO_NUMPY_MODE" -
# confirm before exporting that name.
__all__ = [
    "read_dcm_image",
    "MAX_VALUES_BY_DTYPE",
    "MIN_VALUES_BY_DTYPE",
    "NPDTYPE_TO_OPENCV_DTYPE",
    "clipped",
    "angle_2pi_range",
    "clip",
    "preserve_shape",
    "preserve_channel_dim",
    "ensure_contiguous",
    "is_rgb_image",
    "is_grayscale_image",
    "is_multispectral_image",
    "get_num_channels",
    "non_rgb_warning",
    "_maybe_process_in_chunks",
    "_maybe_process_by_channel",
]

# Parameter specification used in the decorator annotations of this module to
# describe the extra positional/keyword arguments forwarded to the wrapped
# function after its leading image/keypoint argument.
P = ParamSpec("P")

# Upper clipping bound for each supported dtype. Integer dtypes use their full
# representable range; float32 is capped at 1.0, while float64 uses the
# largest finite float64 value.
MAX_VALUES_BY_DTYPE = {
    np.dtype("uint8"): np.iinfo(np.uint8).max,
    np.dtype("uint16"): np.iinfo(np.uint16).max,
    np.dtype("uint32"): np.iinfo(np.uint32).max,
    np.dtype("float32"): 1.0,
    np.dtype("int16"): np.iinfo(np.int16).max,
    np.dtype("int32"): np.iinfo(np.int32).max,
    np.dtype("float64"): np.finfo(np.float64).max,
}

# Lower clipping bound for each supported dtype. Integer dtypes use their full
# representable range; float32 is floored at 0.0, while float64 uses the most
# negative finite float64 value.
MIN_VALUES_BY_DTYPE = {
    np.dtype("uint8"): np.iinfo(np.uint8).min,
    np.dtype("uint16"): np.iinfo(np.uint16).min,
    np.dtype("uint32"): np.iinfo(np.uint32).min,
    np.dtype("float32"): 0.0,
    np.dtype("int16"): np.iinfo(np.int16).min,
    np.dtype("int32"): np.iinfo(np.int32).min,
    np.dtype("float64"): np.finfo(np.float64).min,
}


# Maps NumPy element types to OpenCV depth constants. Every entry is
# registered twice - once under the NumPy scalar type (e.g. ``np.uint8``) and
# once under the matching ``np.dtype`` instance - so lookups work with either
# kind of key. Scalar-type keys come first, then dtype keys, in the same order.
NPDTYPE_TO_OPENCV_DTYPE = {
    as_key(scalar_type): opencv_depth
    for as_key in (lambda scalar_type: scalar_type, np.dtype)
    for scalar_type, opencv_depth in (
        (np.uint8, cv2.CV_8U),
        (np.uint16, cv2.CV_16U),
        (np.int32, cv2.CV_32S),
        (np.float32, cv2.CV_32F),
        (np.float64, cv2.CV_64F),
    )
}

# Lookup table from a SciPy-style boundary mode name (key) to the NumPy-style
# padding mode name (value) it is paired with in this module.
SCIPY_MODE_TO_NUMPY_MODE = dict(
    reflect="symmetric",
    constant="constant",
    nearest="edge",
    mirror="reflect",
    wrap="wrap",
)


def read_dcm_image(path: str, include_header: bool = True, ends_with: str = ""):
    """
    Reads in an alphabetically sorted series of dcm file types stored in a directory as a `np.ndarray`
    and optionally a dicom header in a `dict` format.

    Args:
        path (str): The filepath to the directory that stores the dcm files.
        include_header (bool): Whether to return the dicom header metadata associated with the scan.
            Default: True
        ends_with (str): If empty string, then all files in directory will be processed. If multiple
            file types are within the directory, you may filter the results by setting `ends_with=".dcm"`
            Default: ""

    Returns:
        The image array, or the tuple `(image, dicom)` when `include_header` is True. Each file
        contributes one slice; slices are joined along axis 2 and cast to `int16`. The header is
        taken from the first file read. If no file in the directory matches `ends_with`, the image
        (and the header, when requested) is `None`.

    Raises:
        OSError: If `path` is not an existing directory.

    Note:
        `DICOM` object types are dictionaries with the following keys:

        `PixelSpacing` (tuple)
            The space in mm between pixels for both height and width of a slice, respectively
        `RescaleIntercept` (float)
            The value to add to each pixel of the scan after scaling with `RescaleSlope` to turn the
            pixel values of the scan into Hounsfield Units (HU)
        `RescaleSlope` (float)
            The value to multiply each pixel of the scan by before adding `RescaleIntercept` to turn
            the pixel values of the scan into Hounsfield Units (HU)
        `ConvolutionKernel` (str)
            A label describing the convolution kernel or algorithm used to reconstruct the data
        `XRayTubeCurrent` (int)
            X-Ray Tube Current in mA.

        See example below:

        .. code-block:: python

            dicom = {
                "PixelSpacing" : (0.5, 0.5),
                "RescaleIntercept" : -1024.0,
                "RescaleSlope" : 1.0,
                "ConvolutionKernel" : 'STANDARD',
                "XRayTubeCurrent" : 160
            }
    """
    if not os.path.isdir(path):
        raise OSError(f"{path} is not a valid directory")

    slices = []
    # Stays None when nothing matches, so the header return path never hits an
    # unbound local.
    dicom = None

    for file in sorted(os.listdir(path)):
        if not file.endswith(ends_with):
            continue

        obj = pdm.dcmread(os.path.join(path, file))
        # NOTE(review): pixel data is cast to int16 unconditionally; values
        # outside the int16 range would not survive the cast - confirm inputs.
        slices.append(np.expand_dims(obj.pixel_array, axis=2).astype(np.int16))

        if dicom is None:
            # Header metadata is taken from the first file in sorted order.
            dicom = {
                "PixelSpacing": tuple(map(float, obj.PixelSpacing)),
                "RescaleIntercept": float(obj.RescaleIntercept),
                "RescaleSlope": float(obj.RescaleSlope),
                "ConvolutionKernel": obj.ConvolutionKernel,
                "XRayTubeCurrent": int(obj.XRayTubeCurrent),
            }

    # Join all slices once instead of re-copying the growing volume per file.
    img = np.concatenate(slices, axis=2) if slices else None

    if include_header:
        return img, dicom
    return img
def clipped(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Decorator that clips the wrapped function's result to the input dtype's bounds.

    The bounds are looked up in `MIN_VALUES_BY_DTYPE` / `MAX_VALUES_BY_DTYPE`
    using the dtype of the incoming image; dtypes missing from those tables
    fall back to the range [0.0, 1.0]. The result is cast back to the input dtype.
    """

    @wraps(func)
    def wrapped_function(
        img: np.ndarray, *args: P.args, **kwargs: P.kwargs
    ) -> np.ndarray:
        # Capture the dtype before calling func, since the result may differ.
        source_dtype = img.dtype
        upper = MAX_VALUES_BY_DTYPE.get(source_dtype, 1.0)
        lower = MIN_VALUES_BY_DTYPE.get(source_dtype, 0.0)
        transformed = func(img, *args, **kwargs)
        return clip(transformed, source_dtype, lower, upper)

    return wrapped_function
def clip(img: np.ndarray, dtype: np.dtype, minval: float, maxval: float) -> np.ndarray:
    """Limit the values of `img` to [`minval`, `maxval`] and cast the result to `dtype`."""
    bounded = np.clip(img, minval, maxval)
    return bounded.astype(dtype)
def angle_2pi_range(
    func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]
) -> Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]:
    """Decorator that normalizes the angle of the keypoint returned by `func`.

    The first five elements of the returned keypoint are read as
    ``(x, y, z, angle, scale)``; the angle is passed through
    `angle_to_2pi_range` and a new 5-tuple is returned. Any elements beyond
    the fifth are not carried over.
    """

    @wraps(func)
    def wrapped_function(
        keypoint: KeypointInternalType, *args: P.args, **kwargs: P.kwargs
    ) -> KeypointInternalType:
        transformed = func(keypoint, *args, **kwargs)
        x, y, z, angle, scale = transformed[:5]
        return (x, y, z, angle_to_2pi_range(angle), scale)

    return wrapped_function
def preserve_shape(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Decorator that reshapes the wrapped function's result to the input image's shape."""

    @wraps(func)
    def wrapped_function(
        img: np.ndarray, *args: P.args, **kwargs: P.kwargs
    ) -> np.ndarray:
        # Record the shape up front, before func has a chance to touch img.
        original_shape = img.shape
        return func(img, *args, **kwargs).reshape(original_shape)

    return wrapped_function
def preserve_channel_dim(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Decorator that restores a trailing singleton channel axis dropped by `func`.

    When the input is 4-dimensional with a last axis of size 1 and the result
    comes back 3-dimensional, a new last axis is appended to the result.
    Otherwise the result is returned untouched.
    """

    @wraps(func)
    def wrapped_function(
        img: np.ndarray, *args: P.args, **kwargs: P.kwargs
    ) -> np.ndarray:
        input_shape = img.shape
        output = func(img, *args, **kwargs)

        had_dummy_channel = len(input_shape) == 4 and input_shape[-1] == 1
        if had_dummy_channel and len(output.shape) == 3:
            return np.expand_dims(output, axis=-1)
        return output

    return wrapped_function
def ensure_contiguous(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Decorator that hands the wrapped function a C-contiguous version of the image."""

    @wraps(func)
    def wrapped_function(
        img: np.ndarray, *args: P.args, **kwargs: P.kwargs
    ) -> np.ndarray:
        contiguous_img = np.require(img, requirements=["C_CONTIGUOUS"])
        return func(contiguous_img, *args, **kwargs)

    return wrapped_function
def is_rgb_image(image: np.ndarray) -> bool:
    """Return True for a 4-dimensional image with 3 channels on the last axis and dtype uint8 or float32."""
    if len(image.shape) != 4:
        return False
    if image.shape[-1] != 3:
        return False
    accepted_dtypes = {np.dtype("uint8"), np.dtype("float32")}
    return image.dtype in accepted_dtypes
def is_grayscale_image(image: np.ndarray) -> bool:
    """Return True for a 3-dimensional image, or a 4-dimensional one whose last axis has size 1."""
    shape = image.shape
    rank = len(shape)
    if rank == 3:
        return True
    return rank == 4 and shape[-1] == 1
def is_uint8_or_float32(image: np.ndarray) -> bool:
    """Return True when the image's dtype is `uint8` or `float32`."""
    accepted_dtypes = {np.dtype("uint8"), np.dtype("float32")}
    return image.dtype in accepted_dtypes
def is_multispectral_image(image: np.ndarray) -> bool:
    """Return True for a 4-dimensional image whose last axis has a size other than 1 or 3."""
    if len(image.shape) != 4:
        return False
    channel_count = image.shape[-1]
    return channel_count not in [1, 3]
def get_num_channels(image: np.ndarray) -> int:
    """Return the size of axis 3 for a 4-dimensional image, and 1 for any other rank."""
    shape = image.shape
    if len(shape) == 4:
        return shape[3]
    return 1
def non_rgb_warning(image: np.ndarray) -> None:
    """Raise a `ValueError` when `image` is not an RGB image.

    Despite the name, no warning is issued: an image that fails `is_rgb_image`
    causes an exception whose message is extended with a hint for grayscale
    images and/or a note for multi-spectral images.

    Raises:
        ValueError: If `is_rgb_image(image)` is False.
    """
    if is_rgb_image(image):
        return

    parts = ["This transformation expects 3-channel images"]
    if is_grayscale_image(image):
        parts.append(
            "\nYou can convert your grayscale image to RGB using cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))"
        )
    # Any image with a number of channels other than 1 and 3
    if is_multispectral_image(image):
        parts.append("\nThis transformation cannot be applied to multi-spectral images")

    raise ValueError("".join(parts))
def _maybe_process_in_chunks(
    process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs
) -> Callable[[np.ndarray], np.ndarray]:
    """
    Wrap OpenCV function to enable processing images with more than 4 channels.

    Images for which `get_num_channels` reports at most 4 channels are passed to
    `process_fn` unchanged. Otherwise the image is split along its last
    (channel) axis into groups of up to 4 channels, each group is processed
    separately and the results are joined back along the last axis. A trailing
    group of exactly 2 channels is processed one channel at a time.

    Limitations:
        This wrapper requires image to be the first argument and rest must be sent via named arguments.

    Args:
        process_fn: Transform function (e.g cv2.resize).
        kwargs: Additional parameters forwarded to `process_fn` on every call.

    Returns:
        Callable: A function that takes an image and returns the transformed image.
    """

    @wraps(process_fn)
    def __process_fn(img: np.ndarray) -> np.ndarray:
        num_channels = get_num_channels(img)
        if num_channels <= 4:
            return process_fn(img, **kwargs)

        def run(chunk: np.ndarray) -> np.ndarray:
            # Restore the channel axis if process_fn dropped a singleton one,
            # so every chunk can be joined along the last axis.
            processed = process_fn(chunk, **kwargs)
            if processed.ndim < img.ndim:
                processed = np.expand_dims(processed, -1)
            return processed

        chunks = []
        for index in range(0, num_channels, 4):
            if num_channels - index == 2:
                # Many OpenCV functions cannot work with 2-channel images
                for offset in (index, index + 1):
                    chunks.append(run(img[..., offset : offset + 1]))
            else:
                chunks.append(run(img[..., index : index + 4]))

        # The channel count comes from the last axis (see get_num_channels), so
        # chunks are sliced and re-joined on that same axis rather than axis 2.
        return np.concatenate(chunks, axis=-1)

    return __process_fn


def _maybe_process_by_channel(
    process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs
) -> Callable[[np.ndarray], np.ndarray]:
    """
    Wrap OpenCV or Scipy function to enable processing channeled images of any length.

    Images with more than one channel, or with more than 3 dimensions, are
    processed one channel at a time (indexing the last axis) and the results
    are stacked along axis 3. Any other image is passed to `process_fn` as is.

    Limitations:
        This wrapper requires image to be the first argument and rest must be sent via named arguments.

    Args:
        process_fn: Transform function (e.g scipy.ndimage.zoom).
        kwargs: Additional parameters forwarded to `process_fn` on every call.

    Returns:
        Callable: A function that takes an image and returns the transformed image.
    """

    @wraps(process_fn)
    def __process_fn(img: np.ndarray) -> np.ndarray:
        num_channels = get_num_channels(img)
        if num_channels > 1 or len(img.shape) > 3:
            per_channel = [
                process_fn(img[..., channel], **kwargs)
                for channel in range(num_channels)
            ]
            return np.stack(per_channel, axis=3)
        return process_fn(img, **kwargs)

    return __process_fn