
unstack

Splits an array into a sequence of arrays along the given axis.

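Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | SparseArray | Input sparse array. | required |
| axis | int | Axis along which the array is split. Negative values count from the last axis; the value must lie in [-x.ndim, x.ndim). | 0 |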
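The axis rule can be checked on its own. The sketch below mirrors the validation and reordering performed in the source code, using only shape arithmetic. The helper names `normalize_axis` and `front_axis_order` are hypothetical and are not part of `sparse`.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Map an axis in [-ndim, ndim) to its non-negative equivalent."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")
    # A negative axis counts back from the last dimension.
    return axis + ndim if axis < 0 else axis


def front_axis_order(axis: int, ndim: int) -> tuple[int, ...]:
    """Permutation that moves `axis` to position 0 and keeps the rest in order."""
    axis = normalize_axis(axis, ndim)
    return (axis,) + tuple(i for i in range(ndim) if i != axis)


print(front_axis_order(-1, 3))  # (2, 0, 1)
print(front_axis_order(1, 3))   # (1, 0, 2)
```

Keeping the remaining axes in their original order is what makes each output slice keep the input's axis order, minus the split axis.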

Returns:

| Name | Type | Description |
|------|------|-------------|
| out | Tuple[SparseArray, ...] | Tuple of x.shape[axis] slices taken along the given axis. All slices have the same shape: the shape of x with axis removed. |
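As a plain-Python sketch of this shape contract, the hypothetical helper `unstack_shapes` below works on shape tuples only (no arrays are involved) and returns the shape of every output slice.

```python
def unstack_shapes(shape: tuple[int, ...], axis: int = 0) -> list[tuple[int, ...]]:
    """Shapes of the slices produced by splitting an array of `shape` along `axis`."""
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")
    axis %= ndim  # e.g. -1 % 3 == 2
    # Every slice drops the split axis and keeps the others in order.
    piece = shape[:axis] + shape[axis + 1:]
    # There is one slice per position along the split axis.
    return [piece] * shape[axis]


print(unstack_shapes((2, 3, 4), axis=1))  # [(2, 4), (2, 4), (2, 4)]
```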

Source code in sparse/numba_backend/_common.py
def unstack(x, axis=0):
    """
    Splits an array into a sequence of arrays along the given axis.

    Parameters
    ----------
    x : SparseArray
        Input sparse array.
    axis : int
        Axis along which the array will be split. Negative values count from the last axis.

    Returns
    -------
    out : Tuple[SparseArray,...]
        Tuple of slices along the given dimension. All the arrays have the same shape.
    """
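    # Check the type first so a non-SparseArray input fails with TypeError
    # rather than AttributeError on `x.ndim`.
    if not isinstance(x, SparseArray):
        raise TypeError("`x` must be a SparseArray.")

    ndim = x.ndim

    if not (-ndim <= axis < ndim):
        raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")

    # Normalize a negative axis to its non-negative equivalent.
    if axis < 0:
        axis = ndim + axis
    # Move the split axis to the front, keeping the other axes in order.
    new_order = (axis,) + tuple(i for i in range(ndim) if i != axis)
    x = x.transpose(new_order)
    # Iterating over the leading axis yields one slice per index along it.
    return (*x,)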
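The source above moves the split axis to the front with `transpose` and then iterates over the leading axis. A different way to obtain the same slices, sketched here on a bare COO-style triple (a list of coordinate tuples, a list of values, and a shape) rather than on `sparse`'s own classes, is to bucket each stored element by its coordinate on the split axis and drop that coordinate. The function `unstack_coo` and its input format are hypothetical illustrations, not the library's implementation.

```python
Coord = tuple[int, ...]


def unstack_coo(coords: list[Coord], data: list, shape: Coord, axis: int = 0):
    """Split a COO-style (coords, data, shape) triple along `axis`.

    Returns a list of (coords, data, shape) triples, one per index of `axis`.
    """
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")
    axis %= ndim
    out_shape = shape[:axis] + shape[axis + 1:]
    # One bucket per position along the split axis; slices with no stored
    # elements stay empty but are still returned.
    buckets = [([], []) for _ in range(shape[axis])]
    for coord, value in zip(coords, data):
        bucket_coords, bucket_data = buckets[coord[axis]]
        # Drop the split axis from the coordinate.
        bucket_coords.append(coord[:axis] + coord[axis + 1:])
        bucket_data.append(value)
    return [(c, d, out_shape) for c, d in buckets]


# A 2x3 matrix with stored values at (0, 1), (1, 0) and (1, 2).
coords = [(0, 1), (1, 0), (1, 2)]
data = [5, 7, 9]
for piece in unstack_coo(coords, data, (2, 3), axis=0):
    print(piece)
# ([(1,)], [5], (3,))
# ([(0,), (2,)], [7, 9], (3,))
```

Bucketing touches each stored element once, whereas the transpose-then-iterate approach reuses existing array operations; which is cheaper in practice depends on the storage format and is not something this sketch measures.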