array_api_extra.expand_dims
- array_api_extra.expand_dims(a, /, *, axis=(0,), xp)
Expand the shape of an array.
Insert one or more new axes that will appear at the positions specified by axis in the expanded array shape.
This is xp.expand_dims, with axis allowed to be either an int or a tuple of ints. It is roughly equivalent to numpy.expand_dims for NumPy arrays.
- Parameters:
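One way to see how a tuple-valued axis relates to an int-only expand_dims is the minimal sketch below. It is not array_api_extra's implementation: the function name expand_dims_tuple is hypothetical, and NumPy's int-axis np.expand_dims stands in for a standard namespace's xp.expand_dims. Positions are normalized against the expanded number of dimensions, checked for duplicates, and then inserted one at a time in ascending order.

```python
import numpy as np


def expand_dims_tuple(a, axis=(0,)):
    """Hypothetical sketch: insert a size-1 axis at every position in `axis`.

    Only NumPy's int-axis expand_dims is called, standing in for a
    namespace whose expand_dims accepts a single int.
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = a.ndim + len(axis)
    # Negative positions count from the end of the *expanded* shape.
    positions = [ax + out_ndim if ax < 0 else ax for ax in axis]
    if any(not 0 <= p < out_ndim for p in positions):
        raise IndexError(f"axis {axis} out of bounds for a {out_ndim}-d result")
    # A positive and a negative index naming the same position count as duplicates.
    if len(set(positions)) != len(positions):
        raise ValueError(f"repeated axis in {axis}")
    # Inserting in ascending order keeps each earlier insertion at its final
    # index, so every position refers directly to the output shape.
    for p in sorted(positions):
        a = np.expand_dims(a, axis=p)
    return a


x = np.asarray([1, 2])
print(expand_dims_tuple(x, axis=(2, 0)).shape)  # (1, 2, 1)
```

Inserting in ascending order is the design choice that matters here: because lower positions are already in place, later positions need no adjustment.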
a (array) – Input array.
axis (int or tuple of ints, optional) – Position(s) in the expanded array's axes where the new axes are placed. If multiple positions are provided, they must be unique; note that a position given by a positive index could also be referred to by a negative index, and such a pair also results in an error. Default: (0,).
xp (array_namespace) – The standard-compatible namespace for a.
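The uniqueness rule for axis can be illustrated with NumPy, which this page describes as roughly equivalent. For a 1-D input and two new axes the result is 3-D, so position -3 and position 0 name the same axis. The snippet shows NumPy's behavior only; the exact exception type raised by array_api_extra is not stated on this page and is not assumed here.

```python
import numpy as np

x = np.asarray([1, 2])  # 1-D input, so adding two axes gives a 3-D result

# In a 3-D result, position -3 is the same axis as position 0,
# so (0, -3) names one position twice and NumPy rejects it.
try:
    np.expand_dims(x, axis=(0, -3))
except ValueError as exc:
    print("rejected:", exc)

# Distinct positions are fine: -1 is the last axis of the 3-D result.
print(np.expand_dims(x, axis=(0, -1)).shape)  # (1, 2, 1)
```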
- Returns:
res – a with an expanded shape.
- Return type:
array
Examples
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1, 2])
>>> x.shape
(2,)
The following is equivalent to x[xp.newaxis, :] or x[xp.newaxis]:
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
>>> y
Array([[1, 2]], dtype=array_api_strict.int64)
>>> y.shape
(1, 2)
The following is equivalent to x[:, xp.newaxis]:
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
>>> y
Array([[1], [2]], dtype=array_api_strict.int64)
>>> y.shape
(2, 1)
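The equivalences with newaxis indexing stated for the two int examples can be checked directly in NumPy, used here as a stand-in for an array API namespace (NumPy also exposes np.newaxis). The resulting shapes match those printed by the array_api_strict examples.

```python
import numpy as np

x = np.asarray([1, 2])

# axis=0: the new axis goes in front, like x[newaxis, :] or x[newaxis].
assert np.array_equal(np.expand_dims(x, axis=0), x[np.newaxis, :])
assert np.array_equal(np.expand_dims(x, axis=0), x[np.newaxis])

# axis=1: the new axis goes after the existing one, like x[:, newaxis].
assert np.array_equal(np.expand_dims(x, axis=1), x[:, np.newaxis])
print(x[:, np.newaxis].shape)  # (2, 1)
```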
axis may also be a tuple:
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
>>> y
Array([[[1, 2]]], dtype=array_api_strict.int64)
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
>>> y
Array([[[1], [2]]], dtype=array_api_strict.int64)
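For a tuple axis, the same idea extends to indexing with several newaxis entries. The check below again uses NumPy as a stand-in; the index expressions are illustrative equivalents worked out for this 1-D input, not something the array_api_extra documentation itself states.

```python
import numpy as np

x = np.asarray([1, 2])

# axis=(0, 1): two leading size-1 axes, like indexing with two newaxis entries.
assert np.array_equal(np.expand_dims(x, axis=(0, 1)), x[np.newaxis, np.newaxis, :])

# axis=(2, 0): order inside the tuple does not matter; positions 0 and 2 of
# the 3-D result are new, and the original axis lands at position 1.
y = np.expand_dims(x, axis=(2, 0))
assert np.array_equal(y, x[np.newaxis, :, np.newaxis])
print(y.shape)  # (1, 2, 1)
```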