Package org.nd4j.linalg.dataset
Class DataSet
- java.lang.Object
-
- org.nd4j.linalg.dataset.DataSet
-
- All Implemented Interfaces:
Serializable, Iterable<DataSet>, DataSet
public class DataSet extends Object implements DataSet
- See Also:
- Serialized Form
-
-
Constructor Summary
Constructors
Constructor Description
DataSet()
DataSet(INDArray first, INDArray second)
Creates a dataset with the specified input matrix and labels
DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask)
Create a dataset with the specified input INDArray and labels (output) INDArray, plus (optionally) mask arrays for the features and labels
-
Method Summary
Modifier and Type Method Description
void
addFeatureVector(INDArray toAdd)
Adds a feature for each example onto the current feature vector
void
addFeatureVector(INDArray feature, int example)
Adds the given feature to the example at the given row number
void
addRow(DataSet d, int i)
List<DataSet>
asList()
Extract each example in the DataSet into its own DataSet object, and return all of them as a list
List<DataSet>
batchBy(int num)
Partitions a dataset into mini-batches, where each DataSet in the returned list contains the specified number of examples
List<DataSet>
batchByNumLabels()
void
binarize()
Same as calling binarize(0)
void
binarize(double cutoff)
Binarizes the dataset such that any value greater than cutoff becomes 1 and any other value becomes 0
DataSet
copy()
Clone the dataset
List<DataSet>
dataSetBatches(int num)
Partitions the dataset by the specified number.
void
detach()
This method detaches this DataSet from the current Workspace (if any)
void
divideBy(int num)
Divide the features by a scalar
static DataSet
empty()
Returns a single dataset (all fields are null)
boolean
equals(Object o)
INDArray
exampleMaxs()
INDArray
exampleMeans()
INDArray
exampleSums()
void
filterAndStrip(int[] labels)
Strips the dataset down to the specified labels and remaps them
DataSet
filterBy(int[] labels)
Strips the dataset of all but the passed-in labels
DataSet
get(int i)
Gets a copy of example i
DataSet
get(int[] i)
Gets a copy of the examples at the specified indices
List<String>
getColumnNames()
Deprecated.
List<Serializable>
getExampleMetaData()
Get the example metadata, or null if no metadata has been set
<T extends Serializable> List<T>
getExampleMetaData(Class<T> metaDataType)
Get the example metadata, or null if no metadata has been set
Note: this method results in an unchecked cast - care should be taken when using this!
INDArray
getFeatures()
Returns the features array for the DataSet
INDArray
getFeaturesMaskArray()
Input mask array: a mask array for input, where each value is in {0,1} in order to specify whether an input is actually present or not.
String
getLabelName(int idx)
List<String>
getLabelNames()
Deprecated.
List<String>
getLabelNames(INDArray idxs)
List<String>
getLabelNamesList()
Gets the optional label names
INDArray
getLabels()
Returns the labels for the dataset
INDArray
getLabelsMaskArray()
Labels (output) mask array: a mask array for the labels, where each value is in {0,1} in order to specify whether an output is actually present or not.
long
getMemoryFootprint()
This method returns the memory used by this DataSet
DataSet
getRange(int from, int to)
int
hashCode()
boolean
hasMaskArrays()
Whether the labels or input (features) mask arrays are present for this DataSet
String
id()
boolean
isEmpty()
boolean
isPreProcessed()
DataSetIterator
iterateWithMiniBatches()
Iterator<DataSet>
iterator()
Map<Integer,Double>
labelCounts()
Calculate and return a count of each label, by index.
void
load(File from)
Load the contents of the DataSet from the specified File.
void
load(InputStream from)
Load the contents of the DataSet from the specified InputStream.
void
markAsPreProcessed()
static DataSet
merge(List<? extends DataSet> data)
Merge the list of datasets into one dataset.
void
migrate()
This method migrates this DataSet into the current Workspace (if any)
void
multiplyBy(double num)
Multiply the features by a scalar
void
normalize()
Normalize this DataSet to mean 0, stdev 1 per input.
void
normalizeZeroMeanZeroUnitVariance()
Deprecated.
int
numExamples()
Number of examples in the DataSet
int
numInputs()
The number of inputs in the feature matrix
int
numOutcomes()
Returns the number of outcomes (size of the labels array for each example)
int
outcome()
DataSet
reshape(int rows, int cols)
Reshapes the input into the given rows and columns
void
roundToTheNearest(int roundTo)
DataSet
sample(int numSamples)
Sample without replacement, using a random number generator
DataSet
sample(int numSamples, boolean withReplacement)
Sample a dataset numSamples times
DataSet
sample(int numSamples, Random rng)
Sample without replacement
DataSet
sample(int numSamples, Random rng, boolean withReplacement)
Sample a dataset
void
save(File to)
Save this DataSet to a file.
void
save(OutputStream to)
Write the contents of this DataSet to the specified OutputStream
void
scale()
Divides the input data by the max number in each row
void
scaleMinAndMax(double min, double max)
void
setColumnNames(List<String> columnNames)
Deprecated.
void
setExampleMetaData(List<? extends Serializable> exampleMetaData)
Set the metadata for this DataSet
By convention: the metadata can be any serializable object, one per example in the DataSet
void
setFeatures(INDArray features)
Set the features array for the DataSet
void
setFeaturesMaskArray(INDArray featuresMask)
Set the features mask array in this DataSet
void
setLabelNames(List<String> labelNames)
Sets the label names; throws an exception if the number of label names passed in does not equal the number of outcomes
void
setLabels(INDArray labels)
void
setLabelsMaskArray(INDArray labelsMask)
Set the labels mask array in this DataSet
void
setNewNumberOfLabels(int labels)
Clears the outcome matrix, setting a new number of labels
void
setOutcome(int example, int label)
Sets the outcome of a particular example
void
shuffle()
Shuffle the order of the rows in the DataSet.
void
shuffle(long seed)
Shuffles the dataset in place, given a seed for a random number generator.
List<DataSet>
sortAndBatchByNumLabels()
Sorts the dataset by label: splits the dataset such that examples are sorted by their labels.
void
sortByLabel()
Organizes the dataset to minimize sampling error while still allowing efficient batching.
SplitTestAndTrain
splitTestAndTrain(double fractionTrain)
Split the DataSet into two DataSets randomly
SplitTestAndTrain
splitTestAndTrain(int numHoldout)
Splits a dataset into test and train
SplitTestAndTrain
splitTestAndTrain(int numHoldout, Random rng)
Splits a dataset into test and train randomly.
void
squishToRange(double min, double max)
Squeezes input data to a max and a min
MultiDataSet
toMultiDataSet()
String
toString()
void
validate()
-
Methods inherited from class java.lang.Object
clone, finalize, getClass, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface java.lang.Iterable
forEach, spliterator
-
-
-
-
Constructor Detail
-
DataSet
public DataSet()
-
DataSet
public DataSet(INDArray first, INDArray second)
Creates a dataset with the specified input matrix and labels
- Parameters:
first
- the feature matrix
second
- the labels (these should be binarized, i.e. one-hot, label matrices, in which each example has a value of 1 in the column of its label)
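The binarized label layout described for the second parameter can be sketched in plain Java. This is only an illustration using double[][] in place of INDArray; the class and method names (OneHotLabelsSketch, oneHot) are hypothetical and are not part of ND4J.

```java
import java.util.Arrays;

/** Plain-Java sketch of binarized (one-hot) label rows; hypothetical, not ND4J code. */
class OneHotLabelsSketch {

    /** Builds a [numExamples x numClasses] matrix with a single 1.0 per row. */
    static double[][] oneHot(int[] classIndices, int numClasses) {
        double[][] labels = new double[classIndices.length][numClasses];
        for (int row = 0; row < classIndices.length; row++) {
            int cls = classIndices[row];
            if (cls < 0 || cls >= numClasses) {
                throw new IllegalArgumentException("Class index out of range: " + cls);
            }
            labels[row][cls] = 1.0; // all other columns stay 0.0
        }
        return labels;
    }

    public static void main(String[] args) {
        // Three examples with class indices 2, 0 and 1, out of 3 classes.
        System.out.println(Arrays.deepToString(oneHot(new int[] {2, 0, 1}, 3)));
    }
}
```

A matrix of this shape, wrapped in an INDArray, is the kind of labels argument the constructor description refers to.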
-
DataSet
public DataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask)
Create a dataset with the specified input INDArray and labels (output) INDArray, plus (optionally) mask arrays for the features and labels
- Parameters:
features
- Features (input)
labels
- Labels (output)
featuresMask
- Mask array for features, may be null
labelsMask
- Mask array for labels, may be null
-
-
Method Detail
-
getExampleMetaData
public List<Serializable> getExampleMetaData()
Description copied from interface: DataSet
Get the example metadata, or null if no metadata has been set
- Specified by:
getExampleMetaData
in interface DataSet
- Returns:
- List of metadata instances
-
getExampleMetaData
public <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType)
Description copied from interface: DataSet
Get the example metadata, or null if no metadata has been set
Note: this method results in an unchecked cast - care should be taken when using this!
- Specified by:
getExampleMetaData
in interface DataSet
- Type Parameters:
T
- Type of metadata
- Parameters:
metaDataType
- Class of the metadata (used for type information)
- Returns:
- List of metadata objects
-
setExampleMetaData
public void setExampleMetaData(List<? extends Serializable> exampleMetaData)
Description copied from interface: DataSet
Set the metadata for this DataSet
By convention: the metadata can be any serializable object, one per example in the DataSet
- Specified by:
setExampleMetaData
in interface DataSet
- Parameters:
exampleMetaData
- Example metadata to set
-
isPreProcessed
public boolean isPreProcessed()
-
markAsPreProcessed
public void markAsPreProcessed()
-
empty
public static DataSet empty()
Returns a single dataset (all fields are null)
- Returns:
- an empty dataset (all fields are null)
-
merge
public static DataSet merge(List<? extends DataSet> data)
Merge the list of datasets into one dataset. All the rows are merged into a single dataset
- Parameters:
data
- the data to merge
- Returns:
- a single dataset
-
load
public void load(InputStream from)
Description copied from interface: DataSet
Load the contents of the DataSet from the specified InputStream. The current contents of the DataSet (if any) will be replaced.
The InputStream should contain a DataSet that has been serialized with DataSet.save(OutputStream)
-
load
public void load(File from)
Description copied from interface: DataSet
Load the contents of the DataSet from the specified File. The current contents of the DataSet (if any) will be replaced.
The File should contain a DataSet that has been serialized with DataSet.save(File)
-
save
public void save(OutputStream to)
Description copied from interface: DataSet
Write the contents of this DataSet to the specified OutputStream
-
save
public void save(File to)
Description copied from interface: DataSet
Save this DataSet to a file. Can be loaded again using
-
iterateWithMiniBatches
public DataSetIterator iterateWithMiniBatches()
- Specified by:
iterateWithMiniBatches
in interface DataSet
-
getFeatures
public INDArray getFeatures()
Description copied from interface: DataSet
Returns the features array for the DataSet
- Specified by:
getFeatures
in interface DataSet
- Returns:
- features array
-
setFeatures
public void setFeatures(INDArray features)
Description copied from interface: DataSet
Set the features array for the DataSet
- Specified by:
setFeatures
in interface DataSet
- Parameters:
features
- Features to set
-
labelCounts
public Map<Integer,Double> labelCounts()
Description copied from interface: DataSet
Calculate and return a count of each label, by index. Assumes labels are a one-hot INDArray, for classification
- Specified by:
labelCounts
in interface DataSet
- Returns:
- Map of counts
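As an illustration of counting one-hot labels by index, here is a minimal plain-Java sketch. It uses double[][] instead of INDArray, and the names LabelCountsSketch and labelCounts are hypothetical; it mirrors the return type above but is not the library's implementation.

```java
import java.util.Map;
import java.util.TreeMap;

/** Plain-Java sketch of counting one-hot labels by index; hypothetical, not ND4J code. */
class LabelCountsSketch {

    /** Returns a map from label index to number of examples carrying that label. */
    static Map<Integer, Double> labelCounts(double[][] oneHotLabels) {
        Map<Integer, Double> counts = new TreeMap<>();
        for (double[] row : oneHotLabels) {
            // The label index is the column holding the largest value (the 1.0 in a one-hot row).
            int labelIndex = 0;
            for (int col = 1; col < row.length; col++) {
                if (row[col] > row[labelIndex]) {
                    labelIndex = col;
                }
            }
            counts.merge(labelIndex, 1.0, Double::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0, 0}, {0, 0, 1}, {0, 0, 1}};
        System.out.println(labelCounts(labels));
    }
}
```

Labels that never occur are simply absent from the map in this sketch; whether the real method reports them with a zero count is not stated here.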
-
copy
public DataSet copy()
Clone the dataset
-
reshape
public DataSet reshape(int rows, int cols)
Reshapes the input into the given rows and columns
-
multiplyBy
public void multiplyBy(double num)
Description copied from interface: DataSet
Multiply the features by a scalar
- Specified by:
multiplyBy
in interface DataSet
-
divideBy
public void divideBy(int num)
Description copied from interface: DataSet
Divide the features by a scalar
-
shuffle
public void shuffle()
Description copied from interface: DataSet
Shuffle the order of the rows in the DataSet. Note that this generally won't make any difference in practice unless the DataSet is later split.
-
shuffle
public void shuffle(long seed)
Shuffles the dataset in place, given a seed for a random number generator, for reproducibility. This will modify the dataset in place!
- Parameters:
seed
- Seed to use for the random Number Generator
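The reproducibility that a seed provides can be sketched in plain Java: the same seed yields the same row order, and that order is applied to features and labels alike so the rows stay aligned. The class and method names below are hypothetical, and the code uses java.util.Random rather than ND4J's random number generator.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of a seeded, in-place row shuffle; hypothetical, not ND4J code. */
class SeededShuffleSketch {

    /** Returns a permutation of 0..numExamples-1 determined entirely by the seed. */
    static int[] shuffledOrder(int numExamples, long seed) {
        List<Integer> order = new ArrayList<>(numExamples);
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));
        int[] result = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            result[i] = order.get(i);
        }
        return result;
    }

    /** Reorders rows in place so that row i becomes the old row order[i]. */
    static void applyInPlace(double[][] rows, int[] order) {
        double[][] old = rows.clone(); // shallow copy: keeps references to the original rows
        for (int i = 0; i < rows.length; i++) {
            rows[i] = old[order[i]];
        }
    }

    public static void main(String[] args) {
        double[][] features = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[][] labels = {{1, 0}, {0, 1}, {1, 0}, {0, 1}};
        int[] order = shuffledOrder(features.length, 42L);
        applyInPlace(features, order); // same order for both arrays keeps rows aligned
        applyInPlace(labels, order);
        System.out.println(Arrays.deepToString(features));
        System.out.println(Arrays.deepToString(labels));
    }
}
```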
-
squishToRange
public void squishToRange(double min, double max)
Squeezes input data to a max and a min
- Specified by:
squishToRange
in interface DataSet
- Parameters:
min
- the min value to occur in the dataset
max
- the max value to occur in the dataset
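One plausible reading of "squeezing" data to a min and a max is element-wise clamping: values below min become min and values above max become max. The sketch below assumes that reading; it is plain Java on double[][], the names are hypothetical, and the actual ND4J behaviour should be checked against the library source.

```java
import java.util.Arrays;

/** Plain-Java sketch of clamping values to [min, max]; an assumed reading, not ND4J code. */
class SquishToRangeSketch {

    /** Clamps every element of features to the closed interval [min, max], in place. */
    static void clampInPlace(double[][] features, double min, double max) {
        if (min > max) {
            throw new IllegalArgumentException("min must not exceed max");
        }
        for (double[] row : features) {
            for (int j = 0; j < row.length; j++) {
                row[j] = Math.max(min, Math.min(max, row[j]));
            }
        }
    }

    public static void main(String[] args) {
        double[][] features = {{-5.0, 0.5, 9.0}};
        clampInPlace(features, 0.0, 1.0);
        System.out.println(Arrays.deepToString(features));
    }
}
```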
-
scaleMinAndMax
public void scaleMinAndMax(double min, double max)
- Specified by:
scaleMinAndMax
in interface DataSet
-
scale
public void scale()
Divides the input data by the max number in each row
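Dividing each row by its own maximum can be sketched as follows. This is a plain-Java illustration on double[][] with hypothetical names; skipping rows whose maximum is zero is an assumption made here to avoid division by zero, not documented behaviour.

```java
import java.util.Arrays;

/** Plain-Java sketch of scaling each row by its maximum; hypothetical, not ND4J code. */
class RowScaleSketch {

    /** Divides every element of each row by that row's maximum value, in place. */
    static void scaleByRowMax(double[][] features) {
        for (double[] row : features) {
            if (row.length == 0) {
                continue;
            }
            double max = row[0];
            for (double v : row) {
                max = Math.max(max, v);
            }
            if (max == 0.0) {
                continue; // assumption: leave the row unchanged rather than divide by zero
            }
            for (int j = 0; j < row.length; j++) {
                row[j] /= max;
            }
        }
    }

    public static void main(String[] args) {
        double[][] features = {{1.0, 2.0, 4.0}, {0.0, 0.0, 0.0}};
        scaleByRowMax(features);
        System.out.println(Arrays.deepToString(features));
    }
}
```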
-
addFeatureVector
public void addFeatureVector(INDArray toAdd)
Adds a feature for each example onto the current feature vector
- Specified by:
addFeatureVector
in interface DataSet
- Parameters:
toAdd
- the feature vector to add
-
addFeatureVector
public void addFeatureVector(INDArray feature, int example)
Adds the given feature to the example at the given row number
- Specified by:
addFeatureVector
in interface DataSet
- Parameters:
feature
- the feature vector to add
example
- the number of the example to append to
-
normalize
public void normalize()
Description copied from interface: DataSet
Normalize this DataSet to mean 0, stdev 1 per input. This calculates statistics based on the values in a single DataSet only. For normalization over multiple DataSet objects, use NormalizerStandardize
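A minimal sketch of per-input standardization follows, in plain Java on double[][]. The names are hypothetical, and two details are assumptions made for the sketch: it uses the population standard deviation (dividing by n), and it maps a constant column to zeros.

```java
import java.util.Arrays;

/** Plain-Java sketch of per-column standardization; hypothetical, not ND4J code. */
class StandardizeSketch {

    /** Rescales each column to mean 0 and standard deviation 1, in place. */
    static void standardizeColumns(double[][] x) {
        int n = x.length;
        if (n == 0) {
            return;
        }
        int cols = x[0].length;
        for (int j = 0; j < cols; j++) {
            double mean = 0.0;
            for (double[] row : x) {
                mean += row[j];
            }
            mean /= n;
            double sumSq = 0.0;
            for (double[] row : x) {
                double d = row[j] - mean;
                sumSq += d * d;
            }
            double std = Math.sqrt(sumSq / n); // assumption: population standard deviation
            for (double[] row : x) {
                row[j] = (std == 0.0) ? 0.0 : (row[j] - mean) / std;
            }
        }
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 10.0}, {3.0, 10.0}};
        standardizeColumns(x);
        System.out.println(Arrays.deepToString(x));
    }
}
```

Because the statistics come from this one array only, applying the same function to a second array would use different means and deviations, which is the limitation the description above points out.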
-
binarize
public void binarize()
Same as calling binarize(0)
-
binarize
public void binarize(double cutoff)
Binarizes the dataset such that any value greater than cutoff becomes 1 and any other value becomes 0
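The thresholding rule can be sketched in plain Java as below. The class and method names are hypothetical and the code operates on double[][] rather than INDArray; note that a value exactly equal to the cutoff maps to 0, since only values strictly greater than the cutoff map to 1.

```java
import java.util.Arrays;

/** Plain-Java sketch of thresholding features at a cutoff; hypothetical, not ND4J code. */
class BinarizeSketch {

    /** Sets each element to 1.0 if it is strictly greater than cutoff, else 0.0, in place. */
    static void binarize(double[][] features, double cutoff) {
        for (double[] row : features) {
            for (int j = 0; j < row.length; j++) {
                row[j] = row[j] > cutoff ? 1.0 : 0.0;
            }
        }
    }

    public static void main(String[] args) {
        double[][] features = {{-1.0, 0.0, 0.5}, {2.0, 0.5, 0.49}};
        binarize(features, 0.5);
        System.out.println(Arrays.deepToString(features));
    }
}
```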
-
normalizeZeroMeanZeroUnitVariance
@Deprecated public void normalizeZeroMeanZeroUnitVariance()
Deprecated.
- Specified by:
normalizeZeroMeanZeroUnitVariance
in interface DataSet
-
numInputs
public int numInputs()
The number of inputs in the feature matrix
-
setNewNumberOfLabels
public void setNewNumberOfLabels(int labels)
Clears the outcome matrix, setting a new number of labels
- Specified by:
setNewNumberOfLabels
in interface DataSet
- Parameters:
labels
- the number of labels/columns in the outcome matrix. Note that this clears the labels for each example
-
setOutcome
public void setOutcome(int example, int label)
Sets the outcome of a particular example
- Specified by:
setOutcome
in interface DataSet
- Parameters:
example
- the example to transform
label
- the label of the outcome
-
get
public DataSet get(int i)
Gets a copy of example i
-
get
public DataSet get(int[] i)
Gets a copy of the examples at the specified indices
-
batchBy
public List<DataSet> batchBy(int num)
Partitions a dataset into mini-batches, where each DataSet in the returned list contains the specified number of examples
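Mini-batch partitioning can be sketched with a generic list of examples. This plain-Java illustration uses hypothetical names and assumes that the final batch is allowed to be smaller when the number of examples is not a multiple of the batch size; the description above does not say how ND4J handles that remainder.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch of splitting examples into mini-batches; hypothetical, not ND4J code. */
class BatchBySketch {

    /** Splits examples into consecutive batches of at most batchSize elements. */
    static <T> List<List<T>> batchBy(List<T> examples, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < examples.size(); start += batchSize) {
            int end = Math.min(start + batchSize, examples.size());
            batches.add(new ArrayList<>(examples.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        System.out.println(batchBy(Arrays.asList("a", "b", "c", "d", "e"), 2));
    }
}
```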
-
filterBy
public DataSet filterBy(int[] labels)
Strips the dataset of all but the passed-in labels
-
filterAndStrip
public void filterAndStrip(int[] labels)
Strips the dataset down to the specified labels and remaps them
- Specified by:
filterAndStrip
in interface DataSet
- Parameters:
labels
- the labels to strip down to
-
dataSetBatches
public List<DataSet> dataSetBatches(int num)
Partitions the dataset by the specified number.
- Specified by:
dataSetBatches
in interface DataSet
- Parameters:
num
- the number to split by
- Returns:
- the partitioned dataset
-
sortAndBatchByNumLabels
public List<DataSet> sortAndBatchByNumLabels()
Sorts the dataset by label: splits the dataset such that examples are sorted by their labels. A ten-label dataset would produce lists with batches like the following:
x1 y = 1
x2 y = 2
...
x10 y = 10
- Specified by:
sortAndBatchByNumLabels
in interface DataSet
- Returns:
- a list of data sets partitioned by outcomes
-
batchByNumLabels
public List<DataSet> batchByNumLabels()
- Specified by:
batchByNumLabels
in interface DataSet
-
asList
public List<DataSet> asList()
Description copied from interface: DataSet
Extract each example in the DataSet into its own DataSet object, and return all of them as a list
-
splitTestAndTrain
public SplitTestAndTrain splitTestAndTrain(int numHoldout, Random rng)
Splits a dataset into test and train randomly. This will modify the dataset in place to shuffle it before splitting into test/train!
- Specified by:
splitTestAndTrain
in interface DataSet
- Parameters:
numHoldout
- the number to hold out for training
rng
- Random Number Generator to use to shuffle the dataset
- Returns:
- the pair of datasets for the train test split
-
splitTestAndTrain
public SplitTestAndTrain splitTestAndTrain(int numHoldout)
Splits a dataset into test and train
- Specified by:
splitTestAndTrain
in interface DataSet
- Parameters:
numHoldout
- the number to hold out for training
- Returns:
- the pair of datasets for the train test split
-
getLabels
public INDArray getLabels()
Returns the labels for the dataset
-
getLabelName
public String getLabelName(int idx)
- Specified by:
getLabelName
in interface DataSet
- Parameters:
idx
- the index at which to get the string label value out of the list, if it exists
- Returns:
- the label name
-
getLabelNames
public List<String> getLabelNames(INDArray idxs)
- Specified by:
getLabelNames
in interface DataSet
- Parameters:
idxs
- list of indexes at which to get the string label values out of the list, if they exist
- Returns:
- the label names
-
sortByLabel
public void sortByLabel()
Organizes the dataset to minimize sampling error while still allowing efficient batching.
- Specified by:
sortByLabel
in interface DataSet
-
exampleSums
public INDArray exampleSums()
- Specified by:
exampleSums
in interface DataSet
-
exampleMaxs
public INDArray exampleMaxs()
- Specified by:
exampleMaxs
in interface DataSet
-
exampleMeans
public INDArray exampleMeans()
- Specified by:
exampleMeans
in interface DataSet
-
sample
public DataSet sample(int numSamples)
Sample without replacement, using a random number generator
-
sample
public DataSet sample(int numSamples, boolean withReplacement)
Sample a dataset numSamples times
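The difference between sampling with and without replacement can be sketched on example indices. This is a plain-Java illustration with hypothetical names; rejecting a request for more samples than examples when sampling without replacement is a choice made for the sketch, not documented ND4J behaviour.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of sampling example indices; hypothetical, not ND4J code. */
class SampleSketch {

    /** Picks numSamples indices in [0, numExamples), with or without replacement. */
    static int[] sampleIndices(int numExamples, int numSamples, Random rng, boolean withReplacement) {
        if (numExamples <= 0 || numSamples < 0) {
            throw new IllegalArgumentException("numExamples must be positive and numSamples non-negative");
        }
        int[] picked = new int[numSamples];
        if (withReplacement) {
            // Each draw is independent, so the same index may appear more than once.
            for (int i = 0; i < numSamples; i++) {
                picked[i] = rng.nextInt(numExamples);
            }
            return picked;
        }
        if (numSamples > numExamples) {
            throw new IllegalArgumentException("Cannot draw more distinct samples than examples");
        }
        // Without replacement: shuffle all indices and keep the first numSamples.
        List<Integer> all = new ArrayList<>(numExamples);
        for (int i = 0; i < numExamples; i++) {
            all.add(i);
        }
        Collections.shuffle(all, rng);
        for (int i = 0; i < numSamples; i++) {
            picked[i] = all.get(i);
        }
        return picked;
    }

    public static void main(String[] args) {
        Random rng = new Random(7L);
        System.out.println(Arrays.toString(sampleIndices(5, 3, rng, false)));
        System.out.println(Arrays.toString(sampleIndices(5, 8, rng, true)));
    }
}
```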
-
roundToTheNearest
public void roundToTheNearest(int roundTo)
- Specified by:
roundToTheNearest
in interface DataSet
-
numOutcomes
public int numOutcomes()
Description copied from interface: DataSet
Returns the number of outcomes (size of the labels array for each example)
- Specified by:
numOutcomes
in interface DataSet
-
numExamples
public int numExamples()
Description copied from interface: DataSet
Number of examples in the DataSet
- Specified by:
numExamples
in interface DataSet
-
getLabelNames
@Deprecated public List<String> getLabelNames()
Deprecated.
Gets the optional label names
- Specified by:
getLabelNames
in interface DataSet
- Returns:
-
getLabelNamesList
public List<String> getLabelNamesList()
Gets the optional label names
- Specified by:
getLabelNamesList
in interface DataSet
- Returns:
-
setLabelNames
public void setLabelNames(List<String> labelNames)
Sets the label names; throws an exception if the number of label names passed in does not equal the number of outcomes
- Specified by:
setLabelNames
in interface DataSet
- Parameters:
labelNames
- the label names to use
-
getColumnNames
@Deprecated public List<String> getColumnNames()
Deprecated.
Optional column names of the dataset; this is mainly used for interpreting what columns are in the dataset
- Specified by:
getColumnNames
in interface DataSet
- Returns:
-
setColumnNames
@Deprecated public void setColumnNames(List<String> columnNames)
Deprecated.
Sets the column names; throws an exception if the column names don't match the number of columns
- Specified by:
setColumnNames
in interface DataSet
- Parameters:
columnNames
-
-
splitTestAndTrain
public SplitTestAndTrain splitTestAndTrain(double fractionTrain)
Description copied from interface: DataSet
Split the DataSet into two DataSets randomly
- Specified by:
splitTestAndTrain
in interface DataSet
- Parameters:
fractionTrain
- Fraction (in range 0 to 1) of examples to be returned in the training DataSet object
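A random fractional split can be sketched on example indices as follows. This is a plain-Java illustration with hypothetical names; the seed parameter and the rounding of fractionTrain * numExamples to the nearest integer are assumptions of the sketch, since the description above does not specify either.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch of a random train/test split over indices; hypothetical, not ND4J code. */
class TrainTestSplitSketch {

    /** Returns {trainIndices, testIndices} after shuffling 0..numExamples-1 with the seed. */
    static int[][] split(int numExamples, double fractionTrain, long seed) {
        if (fractionTrain < 0.0 || fractionTrain > 1.0) {
            throw new IllegalArgumentException("fractionTrain must be in the range 0 to 1");
        }
        List<Integer> order = new ArrayList<>(numExamples);
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));
        int numTrain = (int) Math.round(fractionTrain * numExamples); // assumption: round to nearest
        int[] train = new int[numTrain];
        int[] test = new int[numExamples - numTrain];
        for (int i = 0; i < numExamples; i++) {
            if (i < numTrain) {
                train[i] = order.get(i);
            } else {
                test[i - numTrain] = order.get(i);
            }
        }
        return new int[][] {train, test};
    }

    public static void main(String[] args) {
        int[][] parts = split(10, 0.8, 42L);
        System.out.println("train: " + Arrays.toString(parts[0]));
        System.out.println("test:  " + Arrays.toString(parts[1]));
    }
}
```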
-
getFeaturesMaskArray
public INDArray getFeaturesMaskArray()
Description copied from interface: DataSet
Input mask array: a mask array for input, where each value is in {0,1} in order to specify whether an input is actually present or not. Typically used for situations such as RNNs with variable length inputs
- Specified by:
getFeaturesMaskArray
in interface DataSet
- Returns:
- Input mask array
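To illustrate what a {0,1} mask for variable-length inputs looks like, the sketch below builds one from per-example sequence lengths. It is plain Java with hypothetical names, and it assumes a mask of shape [numExamples, maxLength] with shorter sequences padded at the end; the mask layout that ND4J expects should be confirmed in the library documentation.

```java
import java.util.Arrays;

/** Plain-Java sketch of a {0,1} mask for padded sequences; hypothetical, not ND4J code. */
class SequenceMaskSketch {

    /** mask[i][t] is 1.0 when time step t holds real data for example i, and 0.0 for padding. */
    static double[][] lengthMask(int[] lengths) {
        int maxLength = 0;
        for (int len : lengths) {
            maxLength = Math.max(maxLength, len);
        }
        double[][] mask = new double[lengths.length][maxLength];
        for (int i = 0; i < lengths.length; i++) {
            for (int t = 0; t < lengths[i]; t++) {
                mask[i][t] = 1.0;
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        // Two sequences of length 3 and 1, padded to a common length of 3.
        System.out.println(Arrays.deepToString(lengthMask(new int[] {3, 1})));
    }
}
```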
-
setFeaturesMaskArray
public void setFeaturesMaskArray(INDArray featuresMask)
Description copied from interface: DataSet
Set the features mask array in this DataSet
- Specified by:
setFeaturesMaskArray
in interface DataSet
-
getLabelsMaskArray
public INDArray getLabelsMaskArray()
Description copied from interface: DataSet
Labels (output) mask array: a mask array for the labels, where each value is in {0,1} in order to specify whether an output is actually present or not. Typically used for situations such as RNNs with variable length inputs or many-to-one situations.
- Specified by:
getLabelsMaskArray
in interface DataSet
- Returns:
- Labels (output) mask array
-
setLabelsMaskArray
public void setLabelsMaskArray(INDArray labelsMask)
Description copied from interface: DataSet
Set the labels mask array in this DataSet
- Specified by:
setLabelsMaskArray
in interface DataSet
-
hasMaskArrays
public boolean hasMaskArrays()
Description copied from interface: DataSet
Whether the labels or input (features) mask arrays are present for this DataSet
- Specified by:
hasMaskArrays
in interface DataSet
-
getMemoryFootprint
public long getMemoryFootprint()
This method returns the memory used by this DataSet
- Specified by:
getMemoryFootprint
in interface DataSet
- Returns:
-
migrate
public void migrate()
Description copied from interface: DataSet
This method migrates this DataSet into the current Workspace (if any)
-
detach
public void detach()
Description copied from interface: DataSet
This method detaches this DataSet from current Workspace (if any)
-
isEmpty
public boolean isEmpty()
-
toMultiDataSet
public MultiDataSet toMultiDataSet()
- Specified by:
toMultiDataSet
in interface DataSet
-
-