public class RandomForest extends java.lang.Object implements SoftClassifier<smile.data.Tuple>, DataFrameClassifier
Each tree is constructed using the following algorithm:
Constructor and Description |
---|
RandomForest(smile.data.formula.Formula formula, int k, java.util.List<smile.classification.RandomForest.Tree> trees, double error, double[] importance) - Constructor. |
RandomForest(smile.data.formula.Formula formula, int k, java.util.List<smile.classification.RandomForest.Tree> trees, double error, double[] importance, smile.util.IntSet labels) - Constructor. |
Modifier and Type | Method and Description |
---|---|
double | error() - Returns the out-of-bag estimation of error rate. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data) - Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample) - Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight) - Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight, java.util.function.LongSupplier seedGenerator) - Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight, java.util.Optional<java.util.stream.LongStream> seeds) - Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop) - Fits a random forest for classification. |
smile.data.formula.Formula | formula() - Returns the formula associated with the model. |
double[] | importance() - Returns the variable importance. |
int | predict(smile.data.Tuple x) - Predicts the class label of an instance. |
int | predict(smile.data.Tuple x, double[] posteriori) - Predicts the class label of an instance and also calculates the a posteriori probabilities. |
RandomForest | prune(smile.data.DataFrame test) - Returns a new random forest by reduced error pruning. |
smile.data.type.StructType | schema() - Returns the design matrix schema. |
int | size() - Returns the number of trees in the model. |
int[][] | test(smile.data.DataFrame data) - Tests the model on a validation dataset. |
DecisionTree[] | trees() - Returns the decision trees. |
void | trim(int ntrees) - Trims the tree model set to a smaller size in case of over-fitting. |
int | vote(smile.data.Tuple x, double[] posteriori) - Predicts and estimates the probability by voting. |
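Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Classifier: applyAsDouble, applyAsInt, f, predict
Methods inherited from interface DataFrameClassifier: predict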
public RandomForest(smile.data.formula.Formula formula, int k, java.util.List<smile.classification.RandomForest.Tree> trees, double error, double[] importance)
formula
- a symbolic description of the model to be fitted.
k
- the number of classes.
trees
- forest of decision trees.
error
- the out-of-bag estimation of error rate.
importance
- variable importance.
public RandomForest(smile.data.formula.Formula formula, int k, java.util.List<smile.classification.RandomForest.Tree> trees, double error, double[] importance, smile.util.IntSet labels)
formula
- a symbolic description of the model to be fitted.
k
- the number of classes.
trees
- forest of decision trees.
error
- the out-of-bag estimation of error rate.
importance
- variable importance.
labels
- class labels.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
ntrees
- the number of trees.
mtry
- the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
maxDepth
- the maximum depth of the tree.
maxNodes
- the maximum number of leaf nodes in the tree.
nodeSize
- the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
subsample
- the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
ntrees
- the number of trees.
mtry
- the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
maxDepth
- the maximum depth of the tree.
maxNodes
- the maximum number of leaf nodes in the tree.
nodeSize
- the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
subsample
- the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
classWeight
- priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight, java.util.function.LongSupplier seedGenerator)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
ntrees
- the number of trees.
mtry
- the number of input variables to be used to determine the decision at a node of the tree. p/3 seems to give generally good performance, where p is the number of variables.
maxDepth
- the maximum depth of the tree.
maxNodes
- the maximum number of leaf nodes in the tree.
nodeSize
- the number of instances in a node below which the tree will not split. Setting nodeSize = 5 generally gives good results.
subsample
- the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
classWeight
- priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
seedGenerator
- RNG seed generator.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<int[]> classWeight, java.util.Optional<java.util.stream.LongStream> seeds)
formula
- a symbolic description of the model to be fitted.
data
- the data frame of the explanatory and response variables.
ntrees
- the number of trees.
mtry
- the number of randomly selected features to be used to determine the decision at a node of the tree. floor(sqrt(dim)) seems to give generally good performance, where dim is the number of variables.
maxDepth
- the maximum depth of the tree.
maxNodes
- the maximum number of leaf nodes in the tree.
nodeSize
- the minimum size of leaf nodes.
subsample
- the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
rule
- decision tree split rule.
classWeight
- priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
seeds
- optional RNG seeds for each decision tree.
public smile.data.formula.Formula formula()
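Description copied from interface: DataFrameClassifier
Returns the formula associated with the model.
Specified by:
formula in interface DataFrameClassifier
public smile.data.type.StructType schema()
Description copied from interface: DataFrameClassifier
Returns the design matrix schema.
Specified by:
schema in interface DataFrameClassifier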
public double error()
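To illustrate what an out-of-bag (OOB) error rate measures, here is a minimal sketch in plain Java. It is not Smile's implementation. The class name, the `inBag` and `prediction` arrays, and the rule of skipping samples that no tree left out are all assumptions made for this example. Each sample is classified only by the trees that did not see it during training, and the error is the fraction of those majority votes that disagree with the true label.

```java
/** Sketch of an out-of-bag error estimate; hypothetical helper, not part of Smile. */
final class OobErrorSketch {
    /**
     * @param y          true class labels in [0, k), one per sample
     * @param inBag      inBag[t][i] is true if sample i was used to train tree t
     * @param prediction prediction[t][i] is the label tree t assigns to sample i
     * @param k          number of classes
     * @return fraction of evaluated samples whose OOB majority vote is wrong,
     *         or NaN if no sample was out of bag for any tree
     */
    static double oobError(int[] y, boolean[][] inBag, int[][] prediction, int k) {
        int evaluated = 0;
        int wrong = 0;
        for (int i = 0; i < y.length; i++) {
            int[] votes = new int[k];
            int voters = 0;
            for (int t = 0; t < inBag.length; t++) {
                if (!inBag[t][i]) {          // only trees that never saw sample i vote
                    votes[prediction[t][i]]++;
                    voters++;
                }
            }
            if (voters == 0) continue;       // sample was in every bag: skip it
            int best = 0;                    // ties go to the lowest label
            for (int c = 1; c < k; c++) {
                if (votes[c] > votes[best]) best = c;
            }
            evaluated++;
            if (best != y[i]) wrong++;
        }
        return evaluated == 0 ? Double.NaN : (double) wrong / evaluated;
    }

    public static void main(String[] args) {
        int[] y = {0, 1, 1, 0};
        boolean[][] inBag = {
            {true, false, false, true},
            {false, true, false, true},
            {false, false, true, true}
        };
        int[][] prediction = {
            {0, 1, 1, 0},
            {1, 1, 1, 0},
            {1, 1, 0, 0}
        };
        // Sample 3 is in every bag and is skipped; 1 of the 3 remaining samples is wrong.
        System.out.println(oobError(y, inBag, prediction, 2));
    }
}
```

Because every tree is scored only on data it was not trained on, an estimate of this kind does not need a separate validation set.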
public double[] importance()
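A common use of an importance array is to rank the predictors. The sketch below assumes the array holds one score per predictor and that larger values mean more important. The class name and the sample scores are hypothetical, and the code does not call Smile.

```java
import java.util.Arrays;

/** Sketch: order variable indices by descending importance; hypothetical helper. */
final class ImportanceRanking {
    /** Returns variable indices from most to least important (ties keep index order). */
    static int[] rank(double[] importance) {
        Integer[] order = new Integer[importance.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on an object array is stable, so equal scores keep their order.
        Arrays.sort(order, (a, b) -> Double.compare(importance[b], importance[a]));
        return Arrays.stream(order).mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        double[] importance = {0.1, 0.7, 0.2};   // made-up scores for three predictors
        System.out.println(Arrays.toString(rank(importance)));   // [1, 2, 0]
    }
}
```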
public int size()
public DecisionTree[] trees()
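The individual trees are combined into a single prediction, and methods such as vote(x, posteriori) fill a caller-supplied array with per-class probabilities. As a rough illustration of that calling convention, the sketch below performs plain majority voting over labels already produced by each tree. It is an assumption-laden stand-in, not Smile's code: the class name and the `treePredictions` input are hypothetical.

```java
import java.util.Arrays;

/** Sketch of majority voting with vote fractions written to an output array. */
final class VotingSketch {
    /**
     * @param treePredictions the label predicted by each tree, each in [0, posteriori.length)
     * @param posteriori      output array of length k; receives the fraction of votes per class
     * @return the class with the most votes (ties go to the lowest label)
     */
    static int vote(int[] treePredictions, double[] posteriori) {
        if (treePredictions.length == 0) {
            throw new IllegalArgumentException("at least one tree prediction is required");
        }
        Arrays.fill(posteriori, 0.0);
        for (int label : treePredictions) {
            posteriori[label] += 1.0;
        }
        int best = 0;
        for (int c = 0; c < posteriori.length; c++) {
            posteriori[c] /= treePredictions.length;   // turn counts into fractions
            if (posteriori[c] > posteriori[best]) best = c;
        }
        return best;
    }

    public static void main(String[] args) {
        double[] posteriori = new double[3];
        int label = vote(new int[] {1, 1, 0, 2, 1}, posteriori);
        System.out.println(label + " " + Arrays.toString(posteriori));
    }
}
```

The caller allocates `posteriori` with one slot per class, so the same array can be reused across many predictions.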
public void trim(int ntrees)
ntrees
- the new (smaller) size of tree model set.
public int predict(smile.data.Tuple x)
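Description copied from interface: Classifier
Predicts the class label of an instance.
Specified by:
predict in interface Classifier<smile.data.Tuple>
Specified by:
predict in interface DataFrameClassifier
x
- the instance to be classified.
public int predict(smile.data.Tuple x, double[] posteriori)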
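Description copied from interface: SoftClassifier
Predicts the class label of an instance and also calculates the a posteriori probabilities.
Specified by:
predict in interface SoftClassifier<smile.data.Tuple>
x
- an instance to be classified.
posteriori
- the array to store a posteriori probabilities on output.
public int vote(smile.data.Tuple x, double[] posteriori)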
public int[][] test(smile.data.DataFrame data)
data
- the test data set.
public RandomForest prune(smile.data.DataFrame test)
test
- the test data set to evaluate the errors of nodes.
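The classWeight description for the fit overloads gives a worked example: 400 positive and 100 negative samples yield [1, 4] when label 0 is negative. One way to derive such an array from the training labels is sketched below, under the assumption that "roughly the ratio of samples" means each class count divided by the smallest non-empty class count, rounded to an integer. The class and method names are hypothetical; wrap the result in java.util.Optional.of(...) to match the classWeight parameter type.

```java
import java.util.Arrays;

/** Sketch: derive integer class weights from label counts; hypothetical helper. */
final class ClassWeightSketch {
    /**
     * @param y class labels in [0, k)
     * @param k number of classes
     * @return weight[i] is about count[i] / (smallest non-empty class count), at least 1
     */
    static int[] fromLabels(int[] y, int k) {
        int[] count = new int[k];
        for (int label : y) {
            count[label]++;
        }
        int min = Integer.MAX_VALUE;
        for (int c : count) {
            if (c > 0) min = Math.min(min, c);
        }
        int[] weight = new int[k];
        for (int i = 0; i < k; i++) {
            // Empty classes would round to 0, so clamp every weight to at least 1.
            weight[i] = (int) Math.max(1L, Math.round((double) count[i] / min));
        }
        return weight;
    }

    public static void main(String[] args) {
        int[] y = new int[500];
        Arrays.fill(y, 100, 500, 1);   // 100 negatives (label 0), 400 positives (label 1)
        System.out.println(Arrays.toString(fromLabels(y, 2)));   // [1, 4]
    }
}
```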