public class TensorShapeValidation
extends java.lang.Object
Modifier and Type | Method and Description |
---|---|
static int[] | checkAllShapesMatch(java.util.Collection<int[]> shapes) |
static int[] | checkAllShapesMatch(int[]... shapes) |
static int[] | checkHasSingleNonScalarShapeOrAllScalar(int[]... shapes) This ensures there is at most a single non-scalar shape. |
static void | checkTensorsMatchNonScalarShapeOrAreScalar(int[] proposalShape, int[]... shapes) This is a common function to check that tensors are either the same shape as the proposal in question OR scalar. |
public static void checkTensorsMatchNonScalarShapeOrAreScalar(int[] proposalShape, int[]... shapes)
Parameters:
proposalShape - the tensor shape being validated
shapes - the tensors being validated against
Throws:
java.lang.IllegalArgumentException - if there is more than one non-scalar shape OR if the non-scalar shape does not match the proposal shape.

public static int[] checkHasSingleNonScalarShapeOrAllScalar(int[]... shapes)
Parameters:
shapes - the tensors for shape checking
Throws:
java.lang.IllegalArgumentException - if there is more than one non-scalar shape

public static int[] checkAllShapesMatch(int[]... shapes)
public static int[] checkAllShapesMatch(java.util.Collection<int[]> shapes)
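
The contracts documented above can be illustrated with a small standalone sketch. It is not the library's implementation: the class `ShapeChecksSketch` and the helper `isScalar` are hypothetical, and three points are assumptions that the documentation does not confirm: a shape counts as scalar when it holds exactly one element, "more than one non-scalar shape" means more than one distinct non-scalar shape, and each `int[]`-returning method returns the shape it validated.

```java
import java.util.Arrays;

/** Hypothetical stand-in for the documented checks; not the library's code. */
public class ShapeChecksSketch {

    // Assumption: a shape is scalar when the product of its dimensions is 1,
    // which covers [], [1] and [1, 1].
    public static boolean isScalar(int[] shape) {
        long elements = 1;
        for (int dim : shape) {
            elements *= dim;
        }
        return elements == 1;
    }

    // Assumption: throws unless all shapes are identical, then returns that shape.
    public static int[] checkAllShapesMatch(int[]... shapes) {
        if (shapes.length == 0) {
            throw new IllegalArgumentException("No shapes given");
        }
        for (int[] shape : shapes) {
            if (!Arrays.equals(shapes[0], shape)) {
                throw new IllegalArgumentException("Shapes must match");
            }
        }
        return shapes[0];
    }

    // Assumption: returns the single non-scalar shape if there is one,
    // otherwise the first (scalar) shape, or an empty shape for no input.
    public static int[] checkHasSingleNonScalarShapeOrAllScalar(int[]... shapes) {
        int[] nonScalar = null;
        for (int[] shape : shapes) {
            if (isScalar(shape)) {
                continue;
            }
            if (nonScalar != null && !Arrays.equals(nonScalar, shape)) {
                throw new IllegalArgumentException("More than one non-scalar shape");
            }
            nonScalar = shape;
        }
        if (nonScalar != null) {
            return nonScalar;
        }
        return shapes.length > 0 ? shapes[0] : new int[0];
    }

    // Every shape must be scalar or equal to proposalShape.
    public static void checkTensorsMatchNonScalarShapeOrAreScalar(int[] proposalShape, int[]... shapes) {
        int[] found = checkHasSingleNonScalarShapeOrAllScalar(shapes);
        if (!isScalar(found) && !Arrays.equals(found, proposalShape)) {
            throw new IllegalArgumentException("Non-scalar shape does not match the proposal shape");
        }
    }

    public static void main(String[] args) {
        int[] matrix = {2, 3};
        int[] scalar = {1, 1};

        // Passes: one shape is scalar, the other equals the proposal shape.
        checkTensorsMatchNonScalarShapeOrAreScalar(matrix, scalar, matrix);

        // Prints [2, 3]
        System.out.println(Arrays.toString(checkHasSingleNonScalarShapeOrAllScalar(scalar, matrix)));

        try {
            checkHasSingleNonScalarShapeOrAllScalar(matrix, new int[] {4, 5});
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Under these assumptions, scalars are always accepted alongside a non-scalar shape, which is the usual expectation when scalar parameters are broadcast against a larger tensor.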