
roll

Shifts elements of an array along the specified axis. Elements that roll beyond the last position wrap around and are re-introduced at the first.
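For a sparse COO array, rolling never touches the stored values. Each nonzero's coordinate along the rolled axis is shifted and wrapped modulo that axis length. The sketch below illustrates this idea in plain Python. It assumes a hypothetical representation of a 2-D sparse array as a dict mapping `(row, col)` tuples to values, and the helpers `roll_sparse_2d` and `to_dense` are illustrative only, not part of the sparse library.

```python
def roll_sparse_2d(entries, shape, shift, axis):
    """Roll a 2-D sparse array stored as {(row, col): value} along one axis.

    Hypothetical helper for illustration; not the sparse library's code.
    """
    n = shape[axis]
    rolled = {}
    for coord, value in entries.items():
        new_coord = list(coord)
        # Shift the coordinate and wrap it back into [0, n). Python's %
        # returns a non-negative result for positive n, so negative shifts work.
        new_coord[axis] = (coord[axis] + shift) % n
        rolled[tuple(new_coord)] = value
    return rolled


def to_dense(entries, shape, fill_value=0):
    """Materialize the dict representation as a list of lists."""
    rows, cols = shape
    dense = [[fill_value] * cols for _ in range(rows)]
    for (r, c), value in entries.items():
        dense[r][c] = value
    return dense


# A 2x3 array with nonzeros at (0, 2) and (1, 0).
entries = {(0, 2): 5, (1, 0): 7}
shape = (2, 3)
rolled = roll_sparse_2d(entries, shape, shift=1, axis=1)
print(to_dense(rolled, shape))  # [[5, 0, 0], [0, 7, 0]]
```

Because shifting modulo `n` is a bijection on `[0, n)`, two distinct coordinates never land on the same position, so rolling cannot create duplicate entries.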

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | COO | Input array. | required |
| shift | int or tuple of ints | Number of index positions by which elements are shifted. If a tuple is provided, axis must be a tuple of the same length, and each listed axis is shifted by the corresponding number. If shift is an int while axis is a tuple of ints, the same shift is broadcast to all listed axes. | required |
| axis | int or tuple of ints, optional | Axis or tuple of axes along which to shift. By default (None), the array is flattened before shifting, after which the original shape is restored. | None |
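With the default `axis=None`, the array is treated as if flattened, rolled as a 1-D array, and reshaped back. In coordinate terms, that amounts to converting each multi-index to a linear index, shifting it modulo the total size, and converting it back. The helpers `ravel_index`, `unravel_index`, and `roll_flat` below are hypothetical, and the sketch assumes row-major (C-order) flattening, which we take to be what `reshape((-1,))` uses here.

```python
from math import prod


def ravel_index(coord, shape):
    """Row-major (C-order) linear index of a multi-index."""
    linear = 0
    for c, n in zip(coord, shape):
        linear = linear * n + c
    return linear


def unravel_index(linear, shape):
    """Inverse of ravel_index: linear index back to a multi-index."""
    coord = []
    for n in reversed(shape):
        linear, c = divmod(linear, n)
        coord.append(c)
    return tuple(reversed(coord))


def roll_flat(entries, shape, shift):
    """Roll {coord: value} entries as if the array were flattened first."""
    size = prod(shape)
    return {
        unravel_index((ravel_index(coord, shape) + shift) % size, shape): value
        for coord, value in entries.items()
    }


# In a 2x3 array, the last element (1, 2) wraps to the very first position.
print(roll_flat({(0, 1): 4, (1, 2): 9}, (2, 3), shift=1))
# {(0, 2): 4, (0, 0): 9}
```

Note how `(1, 2)` wraps to `(0, 0)`, the first flat position, whereas rolling along axis 1 alone would send it to `(1, 0)`. Flattened rolling and per-axis rolling therefore give different results.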

Returns:

| Name | Type | Description |
| --- | --- | --- |
| res | COO | Output array, with the same shape as a. |

Source code in sparse/numba_backend/_coo/common.py
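# Relies on module-level names from common.py: Iterable (collections.abc),
# np (numpy), and the package helpers normalize_axis, can_store and
# is_unsigned_dtype.
def roll(a, shift, axis=None):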
    """
    Shifts elements of an array along the specified axis. Elements that roll
    beyond the last position wrap around and are re-introduced at the first.

    Parameters
    ----------
    a : COO
        Input array
    shift : int or tuple of ints
        Number of index positions that elements are shifted. If a tuple is
        provided, then axis must be a tuple of the same size, and each of the
        given axes is shifted by the corresponding number. If an int while axis
        is a tuple of ints, then broadcasting is used so the same shift is
        applied to all axes.
    axis : int or tuple of ints, optional
        Axis or tuple specifying multiple axes. By default, the
        array is flattened before shifting, after which the original shape is
        restored.

    Returns
    -------
    res : COO
        Output array, with the same shape as a.
    """
    from .core import COO, as_coo

    a = as_coo(a)

    # roll flattened array
    if axis is None:
        return roll(a.reshape((-1,)), shift, 0).reshape(a.shape)

    # roll across specified axis
    # parse axis input, wrap in tuple
    axis = normalize_axis(axis, a.ndim)
    if not isinstance(axis, tuple):
        axis = (axis,)

    # make shift iterable
    if not isinstance(shift, Iterable):
        shift = (shift,)

    elif np.ndim(shift) > 1:
        raise ValueError("'shift' and 'axis' must be integers or 1D sequences.")

    # handle broadcasting
    if len(shift) == 1:
        shift = np.full(len(axis), shift)

    # check if dimensions are consistent
    if len(axis) != len(shift):
        raise ValueError("If 'shift' is a 1D sequence, 'axis' must have equal length.")

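    # Largest magnitude a coordinate can reach before the modulo wraps it back.
    # (a.shape + shift would concatenate for tuples and fail to broadcast for
    # arrays whose length differs from a.ndim, so compute it per rolled axis.)
    max_coord = max(a.shape[ax] + abs(int(sh)) for sh, ax in zip(shift, axis, strict=True))
    if not can_store(a.coords.dtype, max_coord):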
        raise ValueError(
            f"cannot roll with coords.dtype {a.coords.dtype} and shift {shift}. Try casting coords to a larger dtype."
        )

    # shift elements
    coords, data = np.copy(a.coords), np.copy(a.data)
    try:
        for sh, ax in zip(shift, axis, strict=True):
            coords[ax] += sh
            coords[ax] %= a.shape[ax]
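    except TypeError as e:
        if is_unsigned_dtype(coords.dtype):
            raise ValueError(
                f"rolling with coords.dtype as {coords.dtype} is not safe. Try using a signed dtype."
            ) from e
        # Any other TypeError is a genuine error; re-raise it instead of
        # silently returning partially shifted coordinates.
        raise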

    return COO(
        coords,
        data=data,
        shape=a.shape,
        has_duplicates=False,
        fill_value=a.fill_value,
    )
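The argument handling above (wrapping `axis` in a tuple, broadcasting a scalar or length-1 `shift`, and checking that the lengths agree) can be mirrored in a few lines of plain Python. The `normalize_roll_args` helper below is a hypothetical sketch, not part of the library. It folds negative axes into range, which we assume is roughly what `normalize_axis` does, and it additionally reduces each shift modulo the axis length, a variation that `roll` itself does not apply.

```python
from collections.abc import Iterable


def normalize_roll_args(shift, axis, shape):
    """Return (shifts, axes) as equal-length tuples, mirroring roll's checks.

    Hypothetical helper for illustration. The axis=None (flatten) case is not
    handled, and every rolled axis is assumed to have nonzero length.
    """
    ndim = len(shape)
    # Wrap a single axis in a tuple and fold negative axes into range.
    axes = tuple(axis) if isinstance(axis, Iterable) else (axis,)
    axes = tuple(ax + ndim if ax < 0 else ax for ax in axes)
    if any(not 0 <= ax < ndim for ax in axes):
        raise ValueError(f"axis out of range for an array with {ndim} dimensions")

    # A scalar or length-1 shift is broadcast to every requested axis.
    shifts = tuple(shift) if isinstance(shift, Iterable) else (shift,)
    if len(shifts) == 1:
        shifts = shifts * len(axes)
    if len(shifts) != len(axes):
        raise ValueError("If 'shift' is a 1D sequence, 'axis' must have equal length.")

    # Optional variation: reducing each shift into [0, n) keeps coord + shift
    # non-negative, which would also be safe for unsigned coordinate dtypes.
    shifts = tuple(sh % shape[ax] for sh, ax in zip(shifts, axes))
    return shifts, axes


print(normalize_roll_args(-1, (0, -1), (4, 5, 6)))  # ((3, 5), (0, 2))
```

Reducing shifts first bounds every intermediate coordinate by `2 * n - 2`, so a range check could be made against that value instead of the raw shift. This is a design alternative, not the behavior of the function above.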