public interface Tensor<N,T extends Tensor<N,T>>
Modifier and Type | Interface and Description |
---|---|
static interface | Tensor.FlattenedView<N> |
Modifier and Type | Field and Description |
---|---|
static long[] | ONE_BY_ONE_SHAPE |
static long[] | SCALAR_SHAPE |
static long[] | SCALAR_STRIDE |
Modifier and Type | Method and Description |
---|---|
N[] | asFlatArray() |
default java.util.List<N> | asFlatList() |
T | broadcast(long... toShape) |
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR | create(DATA[] data, long[] shape) |
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR | createFilled(DATA data, long[] shape) |
T | diag() |
T | duplicate() |
BooleanTensor | elementwiseEquals(N value) |
default BooleanTensor | elementwiseEquals(Tensor that) |
static BooleanTensor | elementwiseEquals(Tensor a, Tensor b) |
default T | expandDims(int axis) |
T | get(BooleanTensor booleanIndex) |
Tensor.FlattenedView<N> | getFlattenedView() |
long | getLength() |
int | getRank() |
long[] | getShape() |
long[] | getStride(): Returns the stride for each dimension of the tensor (based on C ordering). |
default N | getValue(long... index): getValue returns a single primitive value from a specified index. |
default boolean | hasSameShapeAs(long[] shape) |
default boolean | hasSameShapeAs(Tensor that) |
default boolean | isLengthOne() |
default boolean | isMatrix() |
default boolean | isScalar() |
default boolean | isVector(): Returns true if the tensor is a vector. |
default T | moveAxis(int source, int destination) |
T | permute(int... rearrange) |
T | reshape(long... newShape) |
default N | scalar() |
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR | scalar(DATA data) |
default void | setValue(N value, long... index) |
T | slice(int dimension, long index) |
T | slice(Slicer slicer) |
default T | slice(java.lang.String sliceArg) |
default java.util.List<T> | sliceAlongDimension(int dimension, long indexStart, long indexEnd) |
java.util.List<T> | split(int dimension, long... splitAtIndices) |
default T | squeeze() |
default T | swapAxis(int axis1, int axis2) |
T | take(long... index) |
default T | transpose() |
static final long[] SCALAR_SHAPE
static final long[] SCALAR_STRIDE
static final long[] ONE_BY_ONE_SHAPE
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR scalar(DATA data)
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR createFilled(DATA data, long[] shape)
static <DATA,TENSOR extends Tensor<DATA,TENSOR>> TENSOR create(DATA[] data, long[] shape)
static BooleanTensor elementwiseEquals(Tensor a, Tensor b)
int getRank()
long[] getShape()
long[] getStride()
The stride is the distance moved in a flat representation of the tensor when the index in that dimension increases by one. For example, a 2x2 tensor is laid out (in C order) as [{0, 0}, {0, 1}, {1, 0}, {1, 1}], so its stride array is [2, 1].
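The stride rule above can be sketched as a small standalone helper. The class StrideSketch and its method cOrderStride are hypothetical names used only for illustration; they are not part of the Tensor API, and the sketch assumes nothing beyond the C (row-major) ordering stated above.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the Tensor API: computes C-order (row-major) strides.
class StrideSketch {
    // The last dimension varies fastest, so its stride is 1; each earlier
    // dimension's stride is the product of the lengths of all later dimensions.
    static long[] cOrderStride(long[] shape) {
        long[] stride = new long[shape.length];
        long step = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            stride[i] = step;
            step *= shape[i];
        }
        return stride;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(cOrderStride(new long[]{2, 2})));    // [2, 1]
        System.out.println(Arrays.toString(cOrderStride(new long[]{2, 3, 4}))); // [12, 4, 1]
    }
}
```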
long getLength()
default N getValue(long... index)
index
- the index of the scalar value.
T get(BooleanTensor booleanIndex)
booleanIndex
- a boolean tensor the same shape as this tensor, where true is specified if the element should be kept and false if not.
default void setValue(N value, long... index)
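Both getValue and setValue address a single element with one index per dimension. The following is a minimal sketch, assuming a row-major (C-order) flat buffer of doubles, of how such an index can be turned into a flat position by summing index[i] * stride[i]. FlatIndexSketch and its methods are hypothetical illustrations, not the library's implementation.

```java
import java.util.Arrays;

// Hypothetical illustration: a tiny row-major array of doubles, used only to show
// how a multi-dimensional index (as passed to getValue/setValue) maps to a flat position.
class FlatIndexSketch {
    private final double[] buffer;
    private final long[] shape;
    private final long[] stride;

    FlatIndexSketch(double[] buffer, long[] shape) {
        this.buffer = buffer;
        this.shape = shape.clone();
        this.stride = new long[shape.length];
        long step = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            stride[i] = step;  // C-order stride for dimension i
            step *= shape[i];
        }
        if (step != buffer.length) {
            throw new IllegalArgumentException("buffer length does not match shape " + Arrays.toString(shape));
        }
    }

    // Flat position = sum over dimensions of index[i] * stride[i].
    long flatIndex(long... index) {
        if (index.length != shape.length) {
            throw new IllegalArgumentException("expected " + shape.length + " indices");
        }
        long flat = 0;
        for (int i = 0; i < index.length; i++) {
            if (index[i] < 0 || index[i] >= shape[i]) {
                throw new IndexOutOfBoundsException("index " + index[i] + " out of range for dimension " + i);
            }
            flat += index[i] * stride[i];
        }
        return flat;
    }

    double getValue(long... index) {
        return buffer[(int) flatIndex(index)];
    }

    void setValue(double value, long... index) {
        buffer[(int) flatIndex(index)] = value;
    }

    public static void main(String[] args) {
        // 2x3 matrix [[1, 2, 3], [4, 5, 6]] stored row by row.
        FlatIndexSketch m = new FlatIndexSketch(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3});
        System.out.println(m.flatIndex(1, 2)); // 1*3 + 2*1 = 5
        System.out.println(m.getValue(1, 0));  // 4.0
        m.setValue(9, 0, 1);
        System.out.println(m.getValue(0, 1));  // 9.0
    }
}
```

Validating the index against the shape before computing the offset matters here: without it, an out-of-range index in one dimension could silently land on a valid position in another.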
default N scalar()
T duplicate()
T slice(int dimension, long index)
default T slice(java.lang.String sliceArg)
T take(long... index)
java.util.List<T> split(int dimension, long... splitAtIndices)
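dimension
- the dimension to split on
splitAtIndices
- the indices at which the dimension should be split. Each piece ends at a split index, so the pieces cover the whole dimension only when the last index equals the length of that dimension.
For example, with the 2x6 matrix A = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 1, 2, 3]]:
A.split(0, [1]) gives List([[1, 2, 3, 4, 5, 6]])
A.split(0, [1, 2]) gives List([[1, 2, 3, 4, 5, 6]], [[7, 8, 9, 1, 2, 3]])
A.split(1, [1, 3, 6]) gives List([[1], [7]], [[2, 3], [8, 9]], [[4, 5, 6], [1, 2, 3]])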
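The split examples above can be reproduced with a small sketch restricted to 2-D int matrices. SplitSketch is a hypothetical class for illustration only; it assumes that piece k covers the range from the previous split index (starting at 0) up to, but not including, splitAtIndices[k], which is what the examples imply. It is not the library's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the split semantics described above, limited to 2-D int matrices.
class SplitSketch {
    // Each split index closes a piece: piece k covers [previous index, splitAtIndices[k]).
    static List<int[][]> split(int[][] a, int dimension, int... splitAtIndices) {
        if (dimension != 0 && dimension != 1) {
            throw new IllegalArgumentException("this sketch only handles 2-D matrices");
        }
        int length = dimension == 0 ? a.length : (a.length == 0 ? 0 : a[0].length);
        List<int[][]> pieces = new ArrayList<>();
        int start = 0;
        for (int end : splitAtIndices) {
            if (end < start || end > length) {
                throw new IllegalArgumentException("split indices must be increasing and within the dimension");
            }
            if (dimension == 0) {
                // Rows start..end-1 (the row arrays are shared, not copied).
                pieces.add(Arrays.copyOfRange(a, start, end));
            } else {
                // Columns start..end-1 taken from every row.
                int[][] piece = new int[a.length][];
                for (int r = 0; r < a.length; r++) {
                    piece[r] = Arrays.copyOfRange(a[r], start, end);
                }
                pieces.add(piece);
            }
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 1, 2, 3}};
        for (int[][] piece : split(a, 1, 1, 3, 6)) {
            System.out.println(Arrays.deepToString(piece));
        }
        // [[1], [7]]
        // [[2, 3], [8, 9]]
        // [[4, 5, 6], [1, 2, 3]]
    }
}
```

Note that split(a, 0, 1) returns only the first row, matching A.split(0, [1]): nothing after the last split index is returned.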
default java.util.List<T> sliceAlongDimension(int dimension, long indexStart, long indexEnd)
T diag()
default T transpose()
N[] asFlatArray()
T reshape(long... newShape)
default T squeeze()
default T expandDims(int axis)
default T moveAxis(int source, int destination)
default T swapAxis(int axis1, int axis2)
T permute(int... rearrange)
T broadcast(long... toShape)
Tensor.FlattenedView<N> getFlattenedView()
default java.util.List<N> asFlatList()
default boolean isLengthOne()
default boolean isScalar()
default boolean isVector()
For example, (1, 2, 3) is a 1x3 vector, and the column (1) (2) (3) is a 3x1 vector.
default boolean isMatrix()
default boolean hasSameShapeAs(Tensor that)
default boolean hasSameShapeAs(long[] shape)
default BooleanTensor elementwiseEquals(Tensor that)
BooleanTensor elementwiseEquals(N value)