"""
Module that provides convenience functions for converting data between
array layouts.
.. versionadded:: 0.4.0
"""
from ._ndarray_backend import xp
def get_allowed_array_layouts() -> list:
return [
"rgb",
"rgba",
"3d",
"2d",
]
[docs]
def convert_data_to_3d_array_layout(
data: xp.ndarray) -> tuple[xp.ndarray, str]:
"""Convert the data to the 3d array_layout
Returns
-------
data_out
3d version of the data
array_layout
original array layout for future reference
Notes
-----
If input is either a RGB or RGBA array layout as input, the first
channel is taken as the image to process. In other words, it is assumed
that all channels contain the same information, so the first channel is
used. 3D RGB/RGBA array layouts, such as (50, 256, 256, 3), are not allowed
(yet).
"""
if len(data.shape) == 3:
if data.shape[-1] in [1, 2, 3]:
# take the first slice (we have alpha or RGB information)
data, array_layout = _convert_rgb_to_3d(data)
elif data.shape[-1] == 4:
# take the first slice (we have alpha or RGB information)
data, array_layout = _convert_rgba_to_3d(data)
else:
# we have a 3D image stack (z, y, x)
data, array_layout = data, "3d"
elif len(data.shape) == 2:
# we have a 2D image (y, x). convert to (z, y, z)
data, array_layout = _convert_2d_to_3d(data)
else:
raise ValueError(f"data_input shape must be 2d or 3d, "
f"got shape {data.shape}.")
return data.copy(), array_layout
[docs]
def convert_3d_data_to_array_layout(
data: xp.ndarray, array_layout: str) -> xp.ndarray:
"""Convert the 3d data to the desired `array_layout`.
Returns
-------
data_out : xp.ndarray
input `data` with the given `array layout`
Notes
-----
Currently, this function is limited to converting from 3d to other
array layouts. Perhaps if there is demand in the future,
this can be generalised for other conversions.
"""
assert array_layout in get_allowed_array_layouts(), (
f"`array_layout` not allowed. "
f"Allowed layouts are: {get_allowed_array_layouts()}.")
assert len(data.shape) == 3, (
f"The data should be 3d, got {len(data.shape)=}")
data = data.copy()
if array_layout == "rgb":
data = _convert_3d_to_rgb(data)
elif array_layout == "rgba":
data = _convert_3d_to_rgba(data)
elif array_layout == "3d":
data = data
else:
data = _convert_3d_to_2d(data)
return data
def _convert_rgb_to_3d(data_input: xp.ndarray) -> tuple[xp.ndarray, str]:
data = data_input[:, :, 0]
data = data[xp.newaxis, :, :]
array_layout = "rgb"
return data, array_layout
def _convert_rgba_to_3d(data_input: xp.ndarray) -> tuple[xp.ndarray, str]:
data, _ = _convert_rgb_to_3d(data_input)
array_layout = "rgba"
return data, array_layout
def _convert_2d_to_3d(data_input: xp.ndarray) -> tuple[xp.ndarray, str]:
data = data_input[xp.newaxis, :, :]
array_layout = "2d"
return data, array_layout
def _convert_3d_to_rgb(data_input: xp.ndarray) -> xp.ndarray:
data = data_input[0]
data = xp.dstack((data, data, data))
return data
def _convert_3d_to_rgba(data_input: xp.ndarray) -> xp.ndarray:
data = data_input[0]
data = xp.dstack((data, data, data, xp.ones_like(data)))
return data
def _convert_3d_to_2d(data_input: xp.ndarray) -> xp.ndarray:
return data_input[0]