public class TensorShapeValidation
extends java.lang.Object
Modifier and Type | Method and Description |
---|---|
static long[] | `checkAllShapesMatch(java.util.Collection<long[]> shapes)` |
static long[] | `checkAllShapesMatch(long[]... shapes)` |
static long[] | `checkAllShapesMatch(java.lang.String errorMessage, java.util.Collection<long[]> shapes)` |
static long[] | `checkAllShapesMatch(java.lang.String errorMessage, long[]... shapes)` |
static void | `checkDimensionExistsInShape(int dimension, long[] shape)`<br>Checks whether the given dimension exists within the shape. |
static long[] | `checkHasOneNonLengthOneShapeOrAllLengthOne(long[]... shapes)`<br>Ensures there is at most a single non-length-one shape. |
static void | `checkIndexIsValid(long[] shape, long... index)` |
static long[] | `checkIsBroadcastable(long[] left, long[] right)` |
static void | `checkShapeIsSquareMatrix(long[] shape)` |
static long[] | `checkShapesCanBeConcatenated(int dimension, long[]... shapes)` |
static void | `checkShapesMatch(long[] actual, long[] expected)` |
static void | `checkTensorsAreScalar(java.lang.String message, long[]... shapes)` |
static void | `checkTensorsMatchNonLengthOneShapeOrAreLengthOne(long[] proposalShape, long[]... shapes)`<br>A common function to check that tensors are either the same shape as the proposal in question OR length one. |
static long[] | `checkTernaryConditionShapeIsValid(long[] predicate, long[] thn, long[] els)` |
static long[] | `getMatrixMultiplicationResultingShape(long[] left, long[] right)` |
static long[] | `getTensorMultiplyResultShape(long[] leftShape, long[] rightShape, int[] dimsLeft, int[] dimsRight)` |
public static void checkTensorsMatchNonLengthOneShapeOrAreLengthOne(long[] proposalShape, long[]... shapes)

Parameters:
- `proposalShape` - the tensor shape being validated
- `shapes` - the tensors being validated against

Throws:
- `java.lang.IllegalArgumentException` - if there is more than one non-length-one shape OR if the non-length-one shape does not match the proposal shape.

public static void checkDimensionExistsInShape(int dimension, long[] shape)
Parameters:
- `dimension` - proposed dimension
- `shape` - shape to check

Throws:
- `java.lang.IllegalArgumentException` - if the dimension exceeds the rank of the shape

public static void checkTensorsAreScalar(java.lang.String message, long[]... shapes)
public static long[] checkHasOneNonLengthOneShapeOrAllLengthOne(long[]... shapes)

Parameters:
- `shapes` - the tensors for shape checking

Throws:
- `java.lang.IllegalArgumentException` - if there is more than one non-length-one shape, or multiple ranks of length-one shapes

public static long[] checkIsBroadcastable(long[] left, long[] right)
public static long[] checkTernaryConditionShapeIsValid(long[] predicate, long[] thn, long[] els)

Parameters:
- `predicate` - shape of the predicate
- `thn` - shape of the "then" branch
- `els` - shape of the "else" branch

public static void checkShapeIsSquareMatrix(long[] shape)
public static void checkShapesMatch(long[] actual, long[] expected)
public static long[] checkAllShapesMatch(long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage, long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage, java.util.Collection<long[]> shapes)
public static long[] checkAllShapesMatch(java.util.Collection<long[]> shapes)
public static long[] checkShapesCanBeConcatenated(int dimension, long[]... shapes)
public static void checkIndexIsValid(long[] shape, long... index)
public static long[] getTensorMultiplyResultShape(long[] leftShape, long[] rightShape, int[] dimsLeft, int[] dimsRight)
public static long[] getMatrixMultiplicationResultingShape(long[] left, long[] right)