public final class Tensor extends Object implements AutoCloseable
Instances of a Tensor are not thread-safe.
WARNING: Resources consumed by the Tensor object must be explicitly freed by invoking the close() method when the object is no longer needed. For example, using a try-with-resources block like:
    try (Tensor t = Tensor.create(...)) {
      doSomethingWith(t);
    }
Modifier and Type | Method and Description |
---|---|
boolean | booleanValue(): Returns the value in a scalar DataType.BOOL tensor. |
byte[] | bytesValue(): Returns the value in a scalar DataType.STRING tensor. |
void | close(): Release resources associated with the Tensor. |
<T> T | copyTo(T dst): Copies the contents of the tensor to dst and returns dst. |
static Tensor | create(DataType dataType, long[] shape, ByteBuffer data): Create a Tensor with data from the given buffer. |
static Tensor | create(long[] shape, DoubleBuffer data): Create a DataType.DOUBLE Tensor with data from the given buffer. |
static Tensor | create(long[] shape, FloatBuffer data): Create a DataType.FLOAT Tensor with data from the given buffer. |
static Tensor | create(long[] shape, IntBuffer data): Create an DataType.INT32 Tensor with data from the given buffer. |
static Tensor | create(long[] shape, LongBuffer data): Create an DataType.INT64 Tensor with data from the given buffer. |
static Tensor | create(Object obj): Create a Tensor from a Java object. |
DataType | dataType(): Returns the DataType of elements stored in the Tensor. |
double | doubleValue(): Returns the value in a scalar DataType.DOUBLE tensor. |
float | floatValue(): Returns the value in a scalar DataType.FLOAT tensor. |
int | intValue(): Returns the value in a scalar DataType.INT32 tensor. |
long | longValue(): Returns the value in a scalar DataType.INT64 tensor. |
int | numBytes(): Returns the size, in bytes, of the tensor data. |
int | numDimensions(): Returns the number of dimensions (sometimes referred to as rank) of the Tensor. |
int | numElements(): Returns the number of elements in a flattened (1-D) view of the tensor. |
long[] | shape(): Returns the shape of the Tensor, i.e., the sizes of each dimension. |
String | toString(): Returns a string describing the type and shape of the Tensor. |
void | writeTo(ByteBuffer dst): Write the tensor data into the given buffer. |
void | writeTo(DoubleBuffer dst): Write the data of a DataType.DOUBLE tensor into the given buffer. |
void | writeTo(FloatBuffer dst): Write the data of a DataType.FLOAT tensor into the given buffer. |
void | writeTo(IntBuffer dst): Write the data of a DataType.INT32 tensor into the given buffer. |
void | writeTo(LongBuffer dst): Write the data of a DataType.INT64 tensor into the given buffer. |
public static Tensor create(Object obj)
Create a Tensor from a Java object.
A Tensor is a multi-dimensional array of elements of a limited set of types (DataType), so not all Java objects can be converted to a Tensor. In particular, obj must be either a primitive (float, double, int, long, boolean) or a multi-dimensional array of one of those primitives. For example:
    // Valid: A 64-bit integer scalar.
    Tensor s = Tensor.create(42L);

    // Valid: A 3x2 matrix of floats.
    float[][] matrix = new float[3][2];
    Tensor m = Tensor.create(matrix);

    // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
    // does not fit into the TensorFlow type system.
    Tensor o = Tensor.create(new Object());

    // Invalid: Will throw an IllegalArgumentException since there are
    // a differing number of elements in each row of this 2-D array.
    int[][] twoD = new int[2][];
    twoD[0] = new int[1];
    twoD[1] = new int[2];
    Tensor x = Tensor.create(twoD);
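The last case, rows of differing length, can be detected before calling Tensor.create. The following is a minimal JDK-only sketch of such a check. ShapeSketch and shapeOf are hypothetical names introduced here for illustration, not part of the TensorFlow API, and the library's own validation may work differently.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper, not part of TensorFlow: infers the shape of a nested array. */
class ShapeSketch {

  /** Returns the shape of value; throws IllegalArgumentException for ragged or null rows. */
  static long[] shapeOf(Object value) {
    List<Long> dims = new ArrayList<>();
    fill(value, 0, dims);
    long[] shape = new long[dims.size()];
    for (int i = 0; i < shape.length; i++) {
      shape[i] = dims.get(i);
    }
    return shape;
  }

  private static void fill(Object node, int depth, List<Long> dims) {
    if (!node.getClass().isArray()) {
      return; // a scalar such as a boxed Long contributes no dimension
    }
    long length = Array.getLength(node);
    if (dims.size() == depth) {
      dims.add(length); // the first array seen at this depth fixes the dimension size
    } else if (dims.get(depth) != length) {
      throw new IllegalArgumentException("ragged array at depth " + depth);
    }
    if (node.getClass().getComponentType().isArray()) {
      for (int i = 0; i < length; i++) {
        Object child = Array.get(node, i);
        if (child == null) {
          throw new IllegalArgumentException("null row at depth " + depth);
        }
        fill(child, depth + 1, dims);
      }
    }
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(shapeOf(new float[3][2]))); // [3, 2]

    int[][] twoD = new int[2][];
    twoD[0] = new int[1];
    twoD[1] = new int[2];
    try {
      shapeOf(twoD);
    } catch (IllegalArgumentException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```

A scalar such as 42L yields an empty shape, which matches the convention that a scalar has zero dimensions.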
Throws:
IllegalArgumentException - if obj is not compatible with the TensorFlow type system, or if obj does not disambiguate between multiple DataTypes. In that case, consider using create(DataType, long[], ByteBuffer) instead.
public static Tensor create(long[] shape, IntBuffer data)
Create an DataType.INT32 Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this method.
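To make "starting from its current position" and "consumed" concrete, here is a JDK-only sketch that mimics the copy with a bulk get. BufferConsumptionSketch, elementCount and copyForShape are hypothetical stand-ins, not TensorFlow methods. The sketch also assumes that any mismatch between the remaining elements and the shape counts as "not compatible".

```java
import java.nio.IntBuffer;
import java.util.Arrays;

/** Hypothetical JDK-only stand-in; not part of the TensorFlow API. */
class BufferConsumptionSketch {

  /** Number of elements implied by a shape: {2, 3} gives 6, an empty shape (scalar) gives 1. */
  static int elementCount(long[] shape) {
    long n = 1;
    for (long dim : shape) {
      n *= dim;
    }
    return Math.toIntExact(n);
  }

  /**
   * Copies elementCount(shape) ints out of data, starting at its current position.
   * Assumption: remaining() must equal the element count exactly.
   */
  static int[] copyForShape(long[] shape, IntBuffer data) {
    int n = elementCount(shape);
    if (data.remaining() != n) {
      throw new IllegalArgumentException(
          "shape needs " + n + " elements but buffer has " + data.remaining() + " remaining");
    }
    int[] copy = new int[n];
    data.get(copy); // bulk get advances the buffer's position by n
    return copy;
  }

  public static void main(String[] args) {
    IntBuffer data = IntBuffer.wrap(new int[] {9, 9, 1, 2, 3, 4, 5, 6});
    data.position(2); // skip the first two values; 6 elements remain
    int[] copy = copyForShape(new long[] {2, 3}, data);
    System.out.println(Arrays.toString(copy)); // [1, 2, 3, 4, 5, 6]
    System.out.println(data.remaining());      // 0
  }
}
```

Because the position advances, a caller that wants to reuse the buffer afterwards would need to rewind or reposition it.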
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer
public static Tensor create(long[] shape, FloatBuffer data)
Create a DataType.FLOAT Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer
public static Tensor create(long[] shape, DoubleBuffer data)
Create a DataType.DOUBLE Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer
public static Tensor create(long[] shape, LongBuffer data)
Create an DataType.INT64 Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer
public static Tensor create(DataType dataType, long[] shape, ByteBuffer data)
Create a Tensor with data from the given buffer.
Creates a Tensor with the provided shape of any type where the tensor's data has been encoded into data as per the specification of the TensorFlow C API.
Parameters:
dataType - the tensor datatype.
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor datatype or shape is not compatible with the buffer
public void close()
Release resources associated with the Tensor.
WARNING: If not invoked, memory will be leaked.
The Tensor object is no longer usable after close returns.
Specified by:
close in interface AutoCloseable
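The close contract can be illustrated without TensorFlow on the classpath. In the sketch below, NativeHandle is a hypothetical stand-in for an object that owns native memory. Throwing IllegalStateException on use after close() is an assumption made for illustration, since this page does not say how a closed Tensor behaves.

```java
/** Hypothetical JDK-only illustration of the AutoCloseable contract; not TensorFlow code. */
class CloseSketch {

  /** Stand-in for an object that owns native memory. */
  static final class NativeHandle implements AutoCloseable {
    private boolean closed = false;

    int read() {
      if (closed) {
        // Assumption for illustration: a closed handle rejects further use.
        throw new IllegalStateException("used after close()");
      }
      return 42;
    }

    boolean isClosed() {
      return closed;
    }

    @Override
    public void close() {
      closed = true; // a real implementation would free native memory here
    }
  }

  /** Uses a handle inside try-with-resources; close() runs even though the body throws. */
  static NativeHandle useAndFail() {
    NativeHandle handle = new NativeHandle();
    try (NativeHandle h = handle) {
      h.read();
      throw new RuntimeException("simulated failure");
    } catch (RuntimeException e) {
      // By the time this handler runs, close() has already been called.
    }
    return handle;
  }

  public static void main(String[] args) {
    NativeHandle handle = useAndFail();
    System.out.println(handle.isClosed()); // true
  }
}
```

The point of the pattern is that cleanup does not depend on the body completing normally, which is why try-with-resources is preferred over calling close() by hand at the end of a method.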
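public int numDimensions()
Returns the number of dimensions (sometimes referred to as rank) of the Tensor.
Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor, etc.
public int numBytes()
Returns the size, in bytes, of the tensor data.
public int numElements()
Returns the number of elements in a flattened (1-D) view of the tensor.
public long[] shape()
Returns the shape of the Tensor, i.e., the sizes of each dimension.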
public float floatValue()
Returns the value in a scalar DataType.FLOAT tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a float scalar.
public double doubleValue()
Returns the value in a scalar DataType.DOUBLE tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a double scalar.
public int intValue()
Returns the value in a scalar DataType.INT32 tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent an int scalar.
public long longValue()
Returns the value in a scalar DataType.INT64 tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a long scalar.
public boolean booleanValue()
Returns the value in a scalar DataType.BOOL tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a boolean scalar.
public byte[] bytesValue()
Returns the value in a scalar DataType.STRING tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a DataType.STRING scalar.
public <T> T copyTo(T dst)
Copies the contents of the tensor to dst and returns dst.
For non-scalar tensors, this method copies the contents of the underlying tensor to a Java array. For scalar tensors, use one of floatValue(), doubleValue(), intValue(), longValue() or booleanValue() instead. The type and shape of dst must be compatible with the tensor. For example:
    int[][] matrix = {{1,2},{3,4}};
    try (Tensor t = Tensor.create(matrix)) {
      // Succeeds and prints "3"
      int[][] copy = new int[2][2];
      System.out.println(t.copyTo(copy)[1][0]);

      // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
      int[][] dst = new int[4][1];
      t.copyTo(dst);
    }
Throws:
IllegalArgumentException - if the tensor is a scalar or if dst is not compatible with the tensor (for example, mismatched data types or shapes).
public void writeTo(IntBuffer dst)
Write the data of a DataType.INT32 tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor
IllegalArgumentException - If the tensor datatype is not DataType.INT32
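The overflow condition follows standard java.nio behaviour, which the JDK-only sketch below reproduces. WriteToSketch, writeAll and tryWrite are hypothetical names, and the int array stands in for tensor data. With a real Tensor, the destination would presumably be sized from numElements().

```java
import java.nio.BufferOverflowException;
import java.nio.IntBuffer;

/** Hypothetical JDK-only stand-in for writing tensor data to a typed buffer. */
class WriteToSketch {

  /** Stand-in for tensor.writeTo(dst): writes every element or throws BufferOverflowException. */
  static void writeAll(int[] tensorData, IntBuffer dst) {
    dst.put(tensorData); // bulk put transfers nothing if dst.remaining() is too small
  }

  /** Returns true if the write succeeded, false if the destination overflowed. */
  static boolean tryWrite(int[] tensorData, IntBuffer dst) {
    try {
      writeAll(tensorData, dst);
      return true;
    } catch (BufferOverflowException e) {
      return false;
    }
  }

  public static void main(String[] args) {
    int[] tensorData = {1, 2, 3, 4, 5, 6}; // plays the role of a 2x3 INT32 tensor

    IntBuffer exact = IntBuffer.allocate(tensorData.length); // sized to the element count
    System.out.println(tryWrite(tensorData, exact)); // true

    IntBuffer tooSmall = IntBuffer.allocate(4);
    System.out.println(tryWrite(tensorData, tooSmall)); // false
  }
}
```

After a successful write the buffer's position sits at the end of the data, so it needs a flip() before the values can be read back.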
public void writeTo(FloatBuffer dst)
Write the data of a DataType.FLOAT tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor
IllegalArgumentException - If the tensor datatype is not DataType.FLOAT
public void writeTo(DoubleBuffer dst)
Write the data of a DataType.DOUBLE tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor
IllegalArgumentException - If the tensor datatype is not DataType.DOUBLE
public void writeTo(LongBuffer dst)
Write the data of a DataType.INT64 tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor
IllegalArgumentException - If the tensor datatype is not DataType.INT64
public void writeTo(ByteBuffer dst)
Write the tensor data into the given buffer.
Copies numBytes() bytes to the buffer in native byte order for primitive types.
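Since the bytes arrive in native order, a reader has to set the same order on the buffer before interpreting them, because a ByteBuffer defaults to big-endian. The JDK-only sketch below shows one way to do that for float data. NativeOrderSketch, writeNative and decode are hypothetical names, and writeNative only imitates what writeTo(ByteBuffer) is described as doing.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;

/** Hypothetical JDK-only sketch of reading raw tensor bytes in native byte order. */
class NativeOrderSketch {

  /** Stand-in for tensor.writeTo(dst): appends floats as raw bytes in native byte order. */
  static void writeNative(float[] values, ByteBuffer dst) {
    ByteOrder saved = dst.order();
    dst.order(ByteOrder.nativeOrder());
    for (float v : values) {
      dst.putFloat(v);
    }
    dst.order(saved); // leave the caller's order setting untouched
  }

  /** Decodes everything written so far, interpreting the bytes in native order. */
  static float[] decode(ByteBuffer written) {
    ByteBuffer view = written.duplicate(); // independent position and limit, shared content
    view.flip();                           // limit = bytes written, position = 0
    FloatBuffer floats = view.order(ByteOrder.nativeOrder()).asFloatBuffer();
    float[] out = new float[floats.remaining()];
    floats.get(out);
    return out;
  }

  public static void main(String[] args) {
    float[] values = {1.5f, -2.0f, 3.25f};
    ByteBuffer dst = ByteBuffer.allocate(values.length * Float.BYTES);
    writeNative(values, dst);
    System.out.println(Arrays.toString(decode(dst))); // [1.5, -2.0, 3.25]
  }
}
```

Decoding through a duplicate keeps the original buffer's position intact, which is convenient if more data will be appended later.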
Parameters:
dst - the destination buffer
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor
Copyright © 2015–2017. All rights reserved.