moveaxis

Move axes of an array to new positions.

Other axes remain in their original order.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| a | SparseArray | The array whose axes should be reordered. | required |
| source | int or List[int] | Original positions of the axes to move. These must be unique. | required |
| destination | int or List[int] | Destination positions for each of the original axes. These must also be unique. | required |

Returns:

| Type | Description |
|---|---|
| SparseArray | Array with moved axes. |

Examples:

>>> import numpy as np
>>> import sparse
>>> x = sparse.COO.from_numpy(np.ones((2, 3, 4, 5)))
>>> sparse.moveaxis(x, (0, 1), (2, 3))
<COO: shape=(4, 5, 2, 3), dtype=float64, nnz=120, fill_value=0.0>
Source code in sparse/numba_backend/_common.py
def moveaxis(a, source, destination):
    """
    Move axes of an array to new positions.

    Other axes remain in their original order.

    Parameters
    ----------
    a : SparseArray
        The array whose axes should be reordered.
    source : int or List[int]
        Original positions of the axes to move. These must be unique.
    destination : int or List[int]
        Destination positions for each of the original axes. These must also be unique.

    Returns
    -------
    SparseArray
        Array with moved axes.

    Examples
    --------
    >>> import numpy as np
    >>> import sparse
    >>> x = sparse.COO.from_numpy(np.ones((2, 3, 4, 5)))
    >>> sparse.moveaxis(x, (0, 1), (2, 3))
    <COO: shape=(4, 5, 2, 3), dtype=float64, nnz=120, fill_value=0.0>
    """

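    # Promote a bare int to a one-element tuple so both arguments are iterable.
    if not isinstance(source, Iterable):
        source = (source,)
    if not isinstance(destination, Iterable):
        destination = (destination,)

    # Normalize both axis lists against the number of dimensions.
    source = normalize_axis(source, a.ndim)
    destination = normalize_axis(destination, a.ndim)

    if len(source) != len(destination):
        raise ValueError("`source` and `destination` arguments must have the same number of elements")

    # Axes that are not being moved keep their relative order.
    order = [n for n in range(a.ndim) if n not in source]

    # Insert in ascending destination order so each moved axis lands at its final index.
    for dest, src in sorted(zip(destination, source, strict=True)):
        order.insert(dest, src)

    return a.transpose(order)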
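
The permutation logic in the listing above can be tried without any array library. The sketch below reproduces it on plain integers. `moveaxis_order` and its inner `normalize` helper are hypothetical names introduced here; `normalize` is only a stand-in for the library's `normalize_axis`, assuming that negative indices wrap around and that out-of-range axes are rejected.

```python
from collections.abc import Iterable


def moveaxis_order(ndim, source, destination):
    """Return the axis order that a moveaxis-style call would hand to transpose."""
    # Promote a bare int to a one-element tuple.
    if not isinstance(source, Iterable):
        source = (source,)
    if not isinstance(destination, Iterable):
        destination = (destination,)

    def normalize(axes):
        # Hypothetical stand-in for normalize_axis: map negatives to 0..ndim-1.
        out = []
        for ax in axes:
            if not -ndim <= ax < ndim:
                raise ValueError(f"axis {ax} is out of bounds for {ndim} dimensions")
            out.append(ax % ndim)
        return tuple(out)

    source = normalize(source)
    destination = normalize(destination)

    if len(source) != len(destination):
        raise ValueError("source and destination must have the same number of elements")

    # Axes that stay put, in their original order.
    order = [n for n in range(ndim) if n not in source]

    # Ascending destination order: each insert lands at its final index.
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)
    return order


shape = (2, 3, 4, 5)
order = moveaxis_order(len(shape), (0, 1), (2, 3))
new_shape = tuple(shape[i] for i in order)
print(order, new_shape)
```

For the documented example, `order` is `[2, 3, 0, 1]`, which gives the shape `(4, 5, 2, 3)` shown in the Examples section. Sorting the pairs by destination matters: inserting a higher destination first could shift it when a lower one is inserted afterwards.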