array_api_extra.partition

array_api_extra.partition(a, kth, /, axis=-1, *, xp=None)

Return a partitioned copy of an array.

Creates a copy of the array and partially sorts it so that the element at position kth has the value it would have in a fully sorted array. In the output array, all elements smaller than the k-th element are located to the left of this element and all equal or greater are located to its right. The ordering of the elements in the two partitions on either side of the k-th element in the output array is undefined.
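The contract above can be checked mechanically. The sketch below uses NumPy's `np.partition`, which documents the same contract, as a stand-in input; `is_partitioned` is a hypothetical helper for illustration, not part of array_api_extra. Ties are treated leniently: elements equal to the pivot are accepted on either side.

```python
import numpy as np


def is_partitioned(out, kth):
    """Check the partition contract for a 1-D array at index kth.

    Hypothetical helper for illustration; not part of array_api_extra.
    Assumes 0 <= kth < len(out).
    """
    out = np.asarray(out)
    pivot = out[kth]
    # The pivot holds the value a full sort would place at index kth ...
    in_sorted_position = pivot == np.sort(out)[kth]
    # ... nothing to its left is larger ...
    left_ok = np.all(out[:kth] <= pivot)
    # ... and nothing to its right is smaller.
    right_ok = np.all(out[kth + 1:] >= pivot)
    return bool(in_sorted_position and left_ok and right_ok)


x = np.asarray([7, 1, 5, 3, 9, 2])
out = np.partition(x, 2)  # NumPy stand-in with the same contract
print(out[2])                  # 3, the third-smallest value
print(is_partitioned(out, 2))  # True
```

Only the value at index 2 is fixed; the order of the two elements before it and the three after it may differ between implementations.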

Parameters:
  • a (Array) – Input array.

  • kth (int) – Element index to partition by.

  • axis (int or None, optional) – Axis along which to partition. The default is -1 (the last axis). If None, the flattened array is used.

  • xp (array_namespace, optional) – The standard-compatible namespace for a. Default: infer.
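To illustrate how axis is interpreted, the sketch below again uses NumPy's `np.partition` as a stand-in, since it accepts the same axis values (an integer axis, or None to flatten). It shows only which slices are partitioned; the placement of non-pivot elements is not guaranteed to match array_api_extra.

```python
import numpy as np

a = np.asarray([[4, 1, 3],
                [9, 7, 8]])

# axis=-1 (default): each row is partitioned independently.
rows = np.partition(a, 1, axis=-1)
print(rows[:, 1])   # [3 8], the median of each row

# axis=0: each column is partitioned independently.
cols = np.partition(a, 0, axis=0)
print(cols[0])      # [4 1 3], the minimum of each column

# axis=None: the array is flattened first, so the result is 1-D.
flat = np.partition(a, 2, axis=None)
print(flat.shape, flat[2])  # (6,) 4
```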

Returns:

partitioned_array – Array of the same type as a, with the same shape as a (1-D if axis is None).

Return type:

Array

Notes

If xp implements partition or an equivalent function (e.g. topk for torch), complexity will likely be O(n). If not, this function simply calls xp.sort and complexity is O(n log n).
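The sort fallback is correct because a fully sorted array places every element in its sorted position, so it satisfies the partition contract for any kth; it is just slower. The following is a minimal sketch of such a fallback, assuming a NumPy-like namespace passed as `xp`. `partition_by_sort` is a hypothetical name, and the kth range check is an assumption added for the sketch; this is not array_api_extra's actual implementation.

```python
import numpy as np


def partition_by_sort(a, kth, axis=-1, xp=np):
    """Sort-based partition fallback, O(n log n).

    Hypothetical sketch; not array_api_extra's actual implementation.
    """
    if axis is None:
        # Mirror the documented axis=None behaviour: work on the flattened array.
        a = xp.reshape(a, (-1,))
        axis = -1
    n = a.shape[axis]
    # kth is not needed by the sort itself; it is only range-checked here
    # (an assumption of this sketch).
    if not -n <= kth < n:
        raise IndexError(f"kth={kth} is out of bounds for an axis of size {n}")
    # A fully sorted array has every element in its sorted position, so it
    # satisfies the partition contract for any kth.
    return xp.sort(a, axis=axis)


x = np.asarray([7, 1, 5, 3, 9, 2])
out = partition_by_sort(x, 2)
print(out)                              # [1 2 3 5 7 9]
print(out[2] == np.partition(x, 2)[2])  # True
```

A selection-based partition only needs to place one element correctly, which is why a native partition (or an equivalent such as topk) can do the same job in expected linear time.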