public class TensorShapeValidation
extends java.lang.Object
Modifier and Type | Method and Description |
---|---|
static long[] | checkAllShapesMatch(java.util.Collection<long[]> shapes) |
static long[] | checkAllShapesMatch(long[]... shapes) |
static long[] | checkAllShapesMatch(java.lang.String errorMessage, java.util.Collection<long[]> shapes) |
static long[] | checkAllShapesMatch(java.lang.String errorMessage, long[]... shapes) |
static void | checkDimensionExistsInShape(int dimension, long[] shape): Checks whether the given dimension exists within the shape. |
static long[] | checkHasSingleNonScalarShapeOrAllScalar(long[]... shapes): Ensures there is at most a single non-scalar shape. |
static void | checkIndexIsValid(long[] shape, long... index) |
static void | checkRankIsAtLeastTwo(long[] shape) |
static void | checkShapeIsSquareMatrix(long[] shape) |
static long[] | checkShapesCanBeConcatenated(int dimension, long[]... shapes) |
static void | checkTensorsMatchNonScalarShapeOrAreScalar(long[] proposalShape, long[]... shapes): A common check that the tensors either have the same shape as the proposal in question OR are scalar. |
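As a rough illustration of the dimension check summarized above, here is a minimal, hypothetical Java sketch; it is not the library's implementation. The documentation only says the check fails if the dimension exceeds the rank of the shape, so treating `dimension == rank` and negative dimensions as invalid is an assumption (the real method may, for example, accept negative dimensions counted from the end). The class name `DimensionCheckSketch` is illustrative.

```java
import java.util.Arrays;

// Hypothetical sketch of a dimension-existence check; not the library's code.
final class DimensionCheckSketch {

    private DimensionCheckSketch() {
        // static utility, no instances
    }

    // Assumption: valid dimensions are 0 .. rank - 1, and negative dimensions are rejected.
    static void checkDimensionExistsInShape(int dimension, long[] shape) {
        if (dimension < 0 || dimension >= shape.length) {
            throw new IllegalArgumentException(
                "Dimension " + dimension + " does not exist in shape "
                    + Arrays.toString(shape) + " of rank " + shape.length);
        }
    }
}
```

Under this reading, dimension 1 is accepted for the shape `{2, 3}`, while dimension 2 is rejected.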
public static void checkTensorsMatchNonScalarShapeOrAreScalar(long[] proposalShape, long[]... shapes)
Parameters:
proposalShape - the tensor shape being validated
shapes - the tensors being validated against
Throws:
java.lang.IllegalArgumentException - if there is more than one non-scalar shape OR if the non-scalar shape does not match the proposal shape.

public static void checkDimensionExistsInShape(int dimension, long[] shape)
Parameters:
dimension - Proposed dimension
shape - Shape to check
Throws:
java.lang.IllegalArgumentException - if the dimension exceeds the rank of the shape

public static long[] checkHasSingleNonScalarShapeOrAllScalar(long[]... shapes)
Parameters:
shapes - the tensors for shape checking
Throws:
java.lang.IllegalArgumentException - if there is more than one non-scalar shape

public static void checkShapeIsSquareMatrix(long[] shape)
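The two scalar-related checks documented above can be sketched together, since checking against a proposal shape builds naturally on finding the single non-scalar shape. This is a hypothetical reimplementation under stated assumptions, not the library's code: a shape is treated as scalar when it holds exactly one element (so both `new long[0]` and `{1, 1}` count), "at most a single non-scalar shape" is read as at most one *distinct* non-scalar shape, and when every shape is scalar the first shape is returned. The class `ScalarShapeChecksSketch` and its `isScalar` helper are illustrative names.

```java
import java.util.Arrays;

// Hypothetical sketch of the scalar/non-scalar shape checks; not the library's code.
final class ScalarShapeChecksSketch {

    private ScalarShapeChecksSketch() {
    }

    // Assumption: a shape is scalar when it holds exactly one element,
    // so both the rank-0 shape {} and the shape {1, 1} count as scalar.
    static boolean isScalar(long[] shape) {
        long length = 1;
        for (long dim : shape) {
            length *= dim;
        }
        return length == 1;
    }

    // Returns the single distinct non-scalar shape, or the first shape if all are scalar.
    // Throws if two different non-scalar shapes are present.
    static long[] checkHasSingleNonScalarShapeOrAllScalar(long[]... shapes) {
        if (shapes.length == 0) {
            throw new IllegalArgumentException("At least one shape is required");
        }
        long[] nonScalar = null;
        for (long[] shape : shapes) {
            if (isScalar(shape)) {
                continue;
            }
            if (nonScalar == null) {
                nonScalar = shape;
            } else if (!Arrays.equals(nonScalar, shape)) {
                throw new IllegalArgumentException("More than one non-scalar shape: "
                    + Arrays.toString(nonScalar) + " and " + Arrays.toString(shape));
            }
        }
        return nonScalar != null ? nonScalar : shapes[0];
    }

    // Every shape must be scalar or equal to proposalShape; an empty list passes trivially.
    static void checkTensorsMatchNonScalarShapeOrAreScalar(long[] proposalShape, long[]... shapes) {
        if (shapes.length == 0) {
            return;
        }
        long[] nonScalar = checkHasSingleNonScalarShapeOrAllScalar(shapes);
        if (!isScalar(nonScalar) && !Arrays.equals(nonScalar, proposalShape)) {
            throw new IllegalArgumentException("Non-scalar shape " + Arrays.toString(nonScalar)
                + " does not match proposal shape " + Arrays.toString(proposalShape));
        }
    }
}
```

Under these assumptions, passing a scalar, `{2, 3}` and another `{2, 3}` returns `{2, 3}`, while `{2, 3}` together with `{3, 2}` is rejected.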
public static long[] checkAllShapesMatch(long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage, long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage, java.util.Collection<long[]> shapes)
public static long[] checkAllShapesMatch(java.util.Collection<long[]> shapes)
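The four `checkAllShapesMatch` overloads above differ only in how the shapes are passed and whether a custom error message is supplied. A minimal sketch, assuming they all return the common shape and throw `IllegalArgumentException` when any two shapes differ, funnels every overload into one `Collection`-based method. The default message text and the class name `AllShapesMatchSketch` are hypothetical.

```java
import java.util.Arrays;
import java.util.Collection;

// Hypothetical sketch of the checkAllShapesMatch overloads; not the library's code.
final class AllShapesMatchSketch {

    // Hypothetical default message used when the caller supplies none.
    private static final String DEFAULT_MESSAGE = "Shapes must match";

    private AllShapesMatchSketch() {
    }

    static long[] checkAllShapesMatch(long[]... shapes) {
        return checkAllShapesMatch(DEFAULT_MESSAGE, Arrays.asList(shapes));
    }

    static long[] checkAllShapesMatch(String errorMessage, long[]... shapes) {
        return checkAllShapesMatch(errorMessage, Arrays.asList(shapes));
    }

    static long[] checkAllShapesMatch(Collection<long[]> shapes) {
        return checkAllShapesMatch(DEFAULT_MESSAGE, shapes);
    }

    // Returns the common shape; throws if any shape differs from the first one.
    static long[] checkAllShapesMatch(String errorMessage, Collection<long[]> shapes) {
        long[] first = null;
        for (long[] shape : shapes) {
            if (first == null) {
                first = shape;
            } else if (!Arrays.equals(first, shape)) {
                throw new IllegalArgumentException(errorMessage + ": "
                    + Arrays.toString(first) + " vs " + Arrays.toString(shape));
            }
        }
        if (first == null) {
            throw new IllegalArgumentException("At least one shape is required");
        }
        return first;
    }
}
```

Delegating the varargs overloads through `Arrays.asList` keeps the comparison logic in a single method, so the overloads cannot drift apart.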
public static long[] checkShapesCanBeConcatenated(int dimension, long[]... shapes)
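`checkShapesCanBeConcatenated` has no description here. A plausible reading, and only an assumption, is that all shapes must have the same rank and agree on every dimension except the one being concatenated along, with the returned array being the shape of the result (the lengths along `dimension` summed). A hypothetical sketch of that reading:

```java
import java.util.Arrays;

// Hypothetical sketch of a concatenation-compatibility check; not the library's code.
final class ConcatShapeSketch {

    private ConcatShapeSketch() {
    }

    // Assumption: all shapes share one rank and agree on every dimension except
    // `dimension`; the returned array is the shape of the concatenated result.
    static long[] checkShapesCanBeConcatenated(int dimension, long[]... shapes) {
        if (shapes.length == 0) {
            throw new IllegalArgumentException("At least one shape is required");
        }
        long[] result = shapes[0].clone();
        if (dimension < 0 || dimension >= result.length) {
            throw new IllegalArgumentException("Dimension " + dimension
                + " does not exist in shape " + Arrays.toString(shapes[0]));
        }
        for (int i = 1; i < shapes.length; i++) {
            long[] shape = shapes[i];
            if (shape.length != result.length) {
                throw new IllegalArgumentException("Cannot concatenate shapes of different rank: "
                    + Arrays.toString(shapes[0]) + " and " + Arrays.toString(shape));
            }
            for (int d = 0; d < shape.length; d++) {
                if (d == dimension) {
                    result[d] += shape[d]; // lengths add up along the concatenation dimension
                } else if (shape[d] != result[d]) {
                    throw new IllegalArgumentException("Dimension " + d + " differs: "
                        + Arrays.toString(shapes[0]) + " and " + Arrays.toString(shape));
                }
            }
        }
        return result;
    }
}
```

For example, concatenating `{2, 3}` and `{4, 3}` along dimension 0 yields `{6, 3}`, while `{2, 3}` and `{2, 4}` are rejected because dimension 1 differs.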
public static void checkIndexIsValid(long[] shape, long... index)
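For `checkIndexIsValid`, a common convention, assumed here rather than taken from the documentation, is one zero-based index component per dimension, each strictly less than the corresponding dimension length. A hypothetical sketch:

```java
import java.util.Arrays;

// Hypothetical sketch of an index-bounds check; not the library's code.
final class IndexCheckSketch {

    private IndexCheckSketch() {
    }

    // Assumption: one zero-based component per dimension, no negative (wrap-around) indices.
    static void checkIndexIsValid(long[] shape, long... index) {
        if (index.length != shape.length) {
            throw new IllegalArgumentException("Index " + Arrays.toString(index)
                + " has " + index.length + " components but shape " + Arrays.toString(shape)
                + " has rank " + shape.length);
        }
        for (int i = 0; i < index.length; i++) {
            if (index[i] < 0 || index[i] >= shape[i]) {
                throw new IllegalArgumentException("Index " + Arrays.toString(index)
                    + " is out of bounds for shape " + Arrays.toString(shape)
                    + " at dimension " + i);
            }
        }
    }
}
```

With shape `{2, 3}`, the index `(1, 2)` passes, while `(2, 0)` and the too-short index `(1)` throw.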
public static void checkRankIsAtLeastTwo(long[] shape)
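`checkRankIsAtLeastTwo` and the related `checkShapeIsSquareMatrix` listed earlier are also undocumented. A minimal sketch, assuming the first rejects shapes with fewer than two dimensions and the second accepts only rank-2 shapes whose two dimensions are equal (the real method may, for instance, also accept batched matrices); the class name `RankChecksSketch` is hypothetical:

```java
import java.util.Arrays;

// Hypothetical sketches of two rank-related checks; not the library's code.
final class RankChecksSketch {

    private RankChecksSketch() {
    }

    // Assumption: rejects any shape with fewer than two dimensions.
    static void checkRankIsAtLeastTwo(long[] shape) {
        if (shape.length < 2) {
            throw new IllegalArgumentException("Expected rank of at least 2 but shape "
                + Arrays.toString(shape) + " has rank " + shape.length);
        }
    }

    // Assumption: accepts only rank-2 shapes {n, n}; batched matrices are rejected here.
    static void checkShapeIsSquareMatrix(long[] shape) {
        if (shape.length != 2 || shape[0] != shape[1]) {
            throw new IllegalArgumentException("Expected a square matrix but got shape "
                + Arrays.toString(shape));
        }
    }
}
```

Under these assumptions, `{2, 3, 4}` passes the rank check and `{3, 3}` passes the square-matrix check, while `{5}` and `{2, 3}` respectively are rejected.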