array_api_extra.one_hot
- array_api_extra.one_hot(x, /, num_classes, *, dtype=None, axis=-1, xp=None)
One-hot encode the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at the given index set to one.
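The encoding rule above can be illustrated with a short sketch. It compares every index against the range 0, 1, ..., num_classes - 1 and relies on broadcasting to produce the new trailing axis. This is not the library's implementation. NumPy stands in for a standard-compatible namespace, and `one_hot_sketch` is a hypothetical helper written only for illustration.

```python
import numpy as np

def one_hot_sketch(x, num_classes, dtype=np.float64):
    # Hypothetical helper, not part of array_api_extra.
    # x[..., np.newaxis] has shape x.shape + (1,); comparing it with
    # arange(num_classes) broadcasts to x.shape + (num_classes,).
    classes = np.arange(num_classes)
    # The boolean "index == class" mask becomes 1 at the index and 0 elsewhere.
    return (x[..., np.newaxis] == classes).astype(dtype)

encoded = one_hot_sketch(np.asarray([1, 2, 0]), 3)
print(encoded.tolist())  # → [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
```

The comparison trick keeps the sketch free of Python loops and works for inputs of any shape.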
- Parameters:
x (array) – An array with integral dtype whose values lie between 0 and num_classes - 1, inclusive.
num_classes (int) – Number of classes in the one-hot dimension.
dtype (DType, optional) – The dtype of the return value. Default: the namespace's default floating-point dtype (usually float64).
axis (int, optional) – Position of the new axis within the expanded result, which has one more dimension than x. Default: -1.
xp (array_namespace, optional) – The standard-compatible namespace for x. Default: infer.
- Returns:
An array with the same shape as x plus one new axis of size num_classes, inserted at position axis. With the default axis=-1, the new axis is appended as the last dimension.
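To make the shape rule concrete, the following sketch builds the encoding with the class axis last and then moves it to `axis`, which is counted in the expanded result. It uses NumPy rather than array_api_extra, and `one_hot_axis_sketch` is a hypothetical helper; it is meant only to illustrate the documented shapes, not to reproduce the library's code.

```python
import numpy as np

def one_hot_axis_sketch(x, num_classes, axis=-1):
    # Hypothetical helper, not part of array_api_extra.
    # Encode with the class axis last: shape x.shape + (num_classes,).
    encoded = (x[..., np.newaxis] == np.arange(num_classes)).astype(np.float64)
    # np.moveaxis interprets the destination in the (x.ndim + 1)-dimensional
    # result, matching "position of the new axis within the expanded result".
    return np.moveaxis(encoded, -1, axis)

x = np.asarray([[0, 1, 2], [2, 1, 0]])  # shape (2, 3)
print(one_hot_axis_sketch(x, 4).shape)          # → (2, 3, 4)
print(one_hot_axis_sketch(x, 4, axis=0).shape)  # → (4, 2, 3)
print(one_hot_axis_sketch(x, 4, axis=1).shape)  # → (2, 4, 3)
```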
If x < 0 or x >= num_classes, the result is undefined: the call may raise an exception or may even leave the program in a bad state, because x is not checked.
- Return type:
array
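Because x is not validated, a caller who cannot guarantee the index range may want to check it before encoding. The sketch below shows one way to do that. The helper name `checked_indices` and the choice of `ValueError` are assumptions made for illustration, and NumPy is used in place of a generic array namespace.

```python
import numpy as np

def checked_indices(x, num_classes):
    # Hypothetical guard, not part of array_api_extra: enforce the documented
    # range 0 <= x <= num_classes - 1 before passing x to one_hot.
    if x.size and (np.min(x) < 0 or np.max(x) >= num_classes):
        raise ValueError(f"indices must lie in [0, {num_classes - 1}]")
    return x

checked_indices(np.asarray([1, 2, 0]), 3)  # in range, returned unchanged
try:
    checked_indices(np.asarray([1, 3]), 3)
except ValueError as e:
    print(e)  # → indices must lie in [0, 2]
```

The check costs a full reduction over x, which is presumably why the function itself leaves validation to the caller.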
Examples
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
Array([[0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.]], dtype=array_api_strict.float64)