public class INDArrayShim
extends java.lang.Object
To work around another issue in ND4J, where a higher-rank tensor cannot be broadcast onto a lower-rank tensor, the shim's broadcast operations ensure that the operation is always applied to the higher-rank tensor. For subtraction and division this requires a small change in the logic, because these operations are not commutative: A - B != B - A and A / B != B / A.
Constructor and Description |
---|
INDArrayShim() |
Modifier and Type | Method and Description |
---|---|
static org.nd4j.linalg.api.ndarray.INDArray |
addi(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
atan2(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
divi(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
eq(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
getGreaterThanMask(org.nd4j.linalg.api.ndarray.INDArray mask,
org.nd4j.linalg.api.ndarray.INDArray right,
org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType) |
static org.nd4j.linalg.api.ndarray.INDArray |
getGreaterThanOrEqualToMask(org.nd4j.linalg.api.ndarray.INDArray mask,
org.nd4j.linalg.api.ndarray.INDArray right,
org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType) |
static org.nd4j.linalg.api.ndarray.INDArray |
getLessThanMask(org.nd4j.linalg.api.ndarray.INDArray mask,
org.nd4j.linalg.api.ndarray.INDArray right,
org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType) |
static org.nd4j.linalg.api.ndarray.INDArray |
getLessThanOrEqualToMask(org.nd4j.linalg.api.ndarray.INDArray mask,
org.nd4j.linalg.api.ndarray.INDArray right,
org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType) |
static org.nd4j.linalg.api.ndarray.INDArray |
gt(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
lt(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
max(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
min(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
muli(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
pow(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
slice(org.nd4j.linalg.api.ndarray.INDArray tensor,
int dimension,
long index) |
static void |
startNewThreadForNd4j() |
static org.nd4j.linalg.api.ndarray.INDArray |
subi(org.nd4j.linalg.api.ndarray.INDArray left,
org.nd4j.linalg.api.ndarray.INDArray right) |
static org.nd4j.linalg.api.ndarray.INDArray |
sum(org.nd4j.linalg.api.ndarray.INDArray tensor,
int... overDimensions) |
public static void startNewThreadForNd4j()
public static org.nd4j.linalg.api.ndarray.INDArray muli(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray divi(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray addi(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray subi(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray pow(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray max(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray min(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray atan2(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray lt(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray gt(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray eq(org.nd4j.linalg.api.ndarray.INDArray left, org.nd4j.linalg.api.ndarray.INDArray right)
public static org.nd4j.linalg.api.ndarray.INDArray getGreaterThanMask(org.nd4j.linalg.api.ndarray.INDArray mask, org.nd4j.linalg.api.ndarray.INDArray right, org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType)
public static org.nd4j.linalg.api.ndarray.INDArray getGreaterThanOrEqualToMask(org.nd4j.linalg.api.ndarray.INDArray mask, org.nd4j.linalg.api.ndarray.INDArray right, org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType)
public static org.nd4j.linalg.api.ndarray.INDArray getLessThanMask(org.nd4j.linalg.api.ndarray.INDArray mask, org.nd4j.linalg.api.ndarray.INDArray right, org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType)
public static org.nd4j.linalg.api.ndarray.INDArray getLessThanOrEqualToMask(org.nd4j.linalg.api.ndarray.INDArray mask, org.nd4j.linalg.api.ndarray.INDArray right, org.nd4j.linalg.api.buffer.DataBuffer.Type bufferType)
public static org.nd4j.linalg.api.ndarray.INDArray sum(org.nd4j.linalg.api.ndarray.INDArray tensor, int... overDimensions)
public static org.nd4j.linalg.api.ndarray.INDArray slice(org.nd4j.linalg.api.ndarray.INDArray tensor, int dimension, long index)