public class RandomForest extends java.lang.Object implements SoftClassifier<smile.data.Tuple>, DataFrameClassifier, TreeSHAP
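Each tree is constructed using the following algorithm:
1. Draw a training sample for the tree from the data: with replacement when subsample = 1.0 (a bootstrap sample), or without replacement when subsample < 1.0.
2. At each node, select mtry of the input variables at random and split the node on the best split among those variables only.
3. Grow the tree until the limits set by maxDepth, maxNodes and nodeSize stop further splitting.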
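A minimal sketch of how one tree's randomness could be drawn, based on the mtry and subsample parameters documented for fit. It covers the row sample (bootstrap when subsample = 1.0, sampling without replacement otherwise) and the random choice of mtry candidate variables, with floor(sqrt(p)) as the default. TreeSampling and its methods are hypothetical helpers, not part of the Smile API, and this version does not stratify the sample by class.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper (not part of Smile) illustrating per-tree randomization.
final class TreeSampling {
    private TreeSampling() {}

    /**
     * Returns how many times each of the n training rows is used by one tree.
     * subsample == 1.0: n draws with replacement (a bootstrap sample).
     * subsample < 1.0: round(n * subsample) distinct rows, drawn without replacement.
     */
    static int[] sampleRows(int n, double subsample, Random rng) {
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("subsample must be in (0, 1]: " + subsample);
        }
        int[] counts = new int[n];
        if (subsample == 1.0) {
            for (int i = 0; i < n; i++) {
                counts[rng.nextInt(n)]++;
            }
        } else {
            int[] perm = permutation(n, rng);
            int m = (int) Math.round(n * subsample);
            for (int i = 0; i < m; i++) {
                counts[perm[i]] = 1;
            }
        }
        return counts;
    }

    /** Picks mtry distinct variable indices out of p: the candidates for one node split. */
    static int[] sampleVariables(int p, int mtry, Random rng) {
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("mtry must be in [1, p]: " + mtry);
        }
        return Arrays.copyOf(permutation(p, rng), mtry);
    }

    /** The documented rule of thumb floor(sqrt(p)), but at least 1. */
    static int defaultMtry(int p) {
        return Math.max(1, (int) Math.floor(Math.sqrt(p)));
    }

    /** Fisher-Yates shuffle of 0..n-1. */
    private static int[] permutation(int n, Random rng) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) a[i] = i;
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int t = a[i];
            a[i] = a[j];
            a[j] = t;
        }
        return a;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        System.out.println(Arrays.toString(sampleRows(10, 1.0, rng)));
        System.out.println(Arrays.toString(sampleRows(10, 0.7, rng)));
        System.out.println(Arrays.toString(sampleVariables(16, defaultMtry(16), rng)));
    }
}
```

With subsample = 1.0 some rows appear several times and others not at all; those left out are the out-of-bag rows of that tree.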
Modifier and Type | Class and Description |
---|---|
static class | RandomForest.Model: The base model. |
Constructor and Description |
---|
RandomForest(smile.data.formula.Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance): Constructor. |
RandomForest(smile.data.formula.Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance, smile.util.IntSet labels): Constructor. |
Modifier and Type | Method and Description |
---|---|
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data): Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample): Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight): Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, java.util.stream.LongStream seeds): Fits a random forest for classification. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop): Fits a random forest for classification. |
smile.data.formula.Formula | formula(): Returns the formula associated with the model. |
double[] | importance(): Returns the variable importance. |
RandomForest | merge(RandomForest other): Merges two random forests. |
ClassificationMetrics | metrics(): Returns the overall out-of-bag metric estimates. |
RandomForest.Model[] | models(): Returns the base models. |
int | predict(smile.data.Tuple x): Predicts the class label of an instance. |
int | predict(smile.data.Tuple x, double[] posteriori): Predicts the class label of an instance and also calculates the a posteriori probabilities. |
RandomForest | prune(smile.data.DataFrame test): Returns a new random forest by reduced error pruning. |
smile.data.type.StructType | schema(): Returns the design matrix schema. |
int | size(): Returns the number of trees in the model. |
int[][] | test(smile.data.DataFrame data): Tests the model on a validation dataset. |
DecisionTree[] | trees(): Returns the decision trees. |
RandomForest | trim(int ntrees): Trims the tree model set to a smaller size in case of over-fitting. |
int | vote(smile.data.Tuple x, double[] posteriori): Predicts the class label and estimates the probabilities by voting. |
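The classWeight parameter of the fit overloads is documented as class priors roughly matching the ratio of samples per class, e.g. [1, 4] for 100 negative (label 0) and 400 positive (label 1) samples. The sketch below derives weights in that form from the labels and then, as an assumption not confirmed by this page, applies them by adding a drawn row's class weight to its sample count during bootstrap sampling. ClassWeights is a hypothetical helper, not a Smile class.

```java
import java.util.Random;

// Hypothetical helper illustrating the documented classWeight convention; not part of Smile.
final class ClassWeights {
    private ClassWeights() {}

    /**
     * Weights roughly proportional to class frequencies, scaled so the smallest
     * non-empty class gets weight 1 (100 vs 400 samples gives [1, 4]).
     * Labels are assumed to lie in 0..k-1.
     */
    static int[] fromLabels(int[] y, int k) {
        int[] counts = new int[k];
        for (int label : y) counts[label]++;
        int min = Integer.MAX_VALUE;
        for (int c : counts) {
            if (c > 0) min = Math.min(min, c);
        }
        int[] weights = new int[k];
        for (int i = 0; i < k; i++) {
            // Empty classes get weight 0; they can never be drawn anyway.
            weights[i] = counts[i] == 0 ? 0 : Math.round((float) counts[i] / min);
        }
        return weights;
    }

    /**
     * Assumed usage: a bootstrap draw of n rows in which each drawn row adds
     * its class weight to its sample count instead of 1.
     */
    static int[] weightedBootstrap(int[] y, int[] classWeight, Random rng) {
        int n = y.length;
        int[] samples = new int[n];
        for (int i = 0; i < n; i++) {
            int row = rng.nextInt(n);
            samples[row] += classWeight[y[row]];
        }
        return samples;
    }
}
```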
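Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from the implemented interfaces: applyAsDouble, applyAsInt, score, and further predict overloads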
public RandomForest(smile.data.formula.Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance)
formula - a symbolic description of the model to be fitted.
k - the number of classes.
models - forest of decision trees.
metrics - the overall out-of-bag metric estimation.
importance - the feature importance.
public RandomForest(smile.data.formula.Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance, smile.util.IntSet labels)
formula - a symbolic description of the model to be fitted.
k - the number of classes.
models - the base models.
metrics - the overall out-of-bag metric estimation.
importance - the feature importance.
labels - the class labels.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
rule - the decision tree split rule.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
rule - the decision tree split rule.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
classWeight - the priors of the classes. The weight of each class is roughly proportional to the number of samples in that class. For example, with 400 positive and 100 negative samples, classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, java.util.stream.LongStream seeds)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
rule - the decision tree split rule.
classWeight - the priors of the classes. The weight of each class is roughly proportional to the number of samples in that class. For example, with 400 positive and 100 negative samples, classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
seeds - optional RNG seeds for each decision tree.
public smile.data.formula.Formula formula()
Returns the formula associated with the model.
Specified by: formula in interface DataFrameClassifier
Specified by: formula in interface TreeSHAP
public smile.data.type.StructType schema()
Returns the design matrix schema.
Specified by: schema in interface DataFrameClassifier
public ClassificationMetrics metrics()
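In the usual random forest formulation, the out-of-bag estimate lets each training row be voted on only by the trees whose sample did not include it, so no separate validation set is needed. A minimal sketch of OOB accuracy under that assumption; OutOfBag and its inBag/treePredictions inputs are hypothetical, and the ClassificationMetrics returned by metrics() may be computed differently.

```java
// Hypothetical sketch of out-of-bag accuracy; not Smile's implementation.
final class OutOfBag {
    private OutOfBag() {}

    /**
     * inBag[t][i]: how many times row i was drawn for tree t (0 means out-of-bag).
     * treePredictions[t][i]: class predicted by tree t for row i.
     * Returns the accuracy over rows that are out-of-bag for at least one tree,
     * or NaN if there are none. Vote ties go to the smallest class index.
     */
    static double accuracy(int[][] inBag, int[][] treePredictions, int[] y, int k) {
        int correct = 0;
        int evaluated = 0;
        for (int i = 0; i < y.length; i++) {
            int[] votes = new int[k];
            boolean any = false;
            for (int t = 0; t < inBag.length; t++) {
                if (inBag[t][i] == 0) { // only trees that never saw row i may vote
                    votes[treePredictions[t][i]]++;
                    any = true;
                }
            }
            if (!any) continue;
            int best = 0;
            for (int c = 1; c < k; c++) {
                if (votes[c] > votes[best]) best = c;
            }
            evaluated++;
            if (best == y[i]) correct++;
        }
        return evaluated == 0 ? Double.NaN : (double) correct / evaluated;
    }
}
```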
public double[] importance()
public int size()
public RandomForest.Model[] models()
public DecisionTree[] trees()
Specified by: trees in interface TreeSHAP
public RandomForest trim(int ntrees)
ntrees - the new (smaller) size of the tree model set.
public RandomForest merge(RandomForest other)
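trim(ntrees) and merge(other) both return a RandomForest. A minimal sketch of those semantics on a hypothetical immutable Forest class, assuming trim keeps the first ntrees base models and merge concatenates the two model arrays; neither behavior, nor any validation merge may perform (for example, checking that both forests share a formula), is confirmed by this page.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical immutable forest illustrating trim/merge semantics; not Smile's RandomForest.
final class Forest<M> {
    private final M[] models;

    Forest(M[] models) {
        this.models = models.clone(); // defensive copy keeps instances immutable
    }

    int size() {
        return models.length;
    }

    /** Returns a new forest holding only the first ntrees models (assumed trimming order). */
    Forest<M> trim(int ntrees) {
        if (ntrees < 1 || ntrees > models.length) {
            throw new IllegalArgumentException("ntrees must be in [1, " + models.length + "]: " + ntrees);
        }
        return new Forest<>(Arrays.copyOf(models, ntrees));
    }

    /** Returns a new forest with this forest's models followed by other's. */
    Forest<M> merge(Forest<M> other) {
        M[] all = Arrays.copyOf(models, models.length + other.models.length);
        System.arraycopy(other.models, 0, all, models.length, other.models.length);
        return new Forest<>(all);
    }

    /** Read-only view of the base models, in order. */
    List<M> models() {
        return List.of(models);
    }
}
```

Keeping both operations non-destructive lets a trimmed forest be compared with the full one on the same validation data.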
public int predict(smile.data.Tuple x)
Predicts the class label of an instance.
Specified by: predict in interface Classifier<smile.data.Tuple>
Specified by: predict in interface DataFrameClassifier
x - the instance to be classified.
public int predict(smile.data.Tuple x, double[] posteriori)
Predicts the class label of an instance and also calculates the a posteriori probabilities.
Specified by: predict in interface SoftClassifier<smile.data.Tuple>
x - an instance to be classified.
posteriori - a posteriori probabilities on output.
public int vote(smile.data.Tuple x, double[] posteriori)
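vote(x, posteriori) predicts the class and estimates the probabilities by voting. A minimal sketch of plain majority voting, assuming each tree casts one vote for its predicted class and posteriori[c] is the fraction of trees voting for class c; Voting is hypothetical and takes the per-tree labels directly instead of evaluating real trees.

```java
import java.util.Arrays;

// Hypothetical sketch of majority voting; not Smile's implementation.
final class Voting {
    private Voting() {}

    /**
     * treeLabels[t] is the class predicted by tree t (0 <= label < posteriori.length).
     * Fills posteriori with the vote fractions and returns the majority class;
     * ties go to the smallest class index.
     */
    static int vote(int[] treeLabels, double[] posteriori) {
        if (treeLabels.length == 0) throw new IllegalArgumentException("no trees");
        Arrays.fill(posteriori, 0.0);
        for (int label : treeLabels) {
            posteriori[label] += 1.0;
        }
        int best = 0;
        for (int c = 0; c < posteriori.length; c++) {
            posteriori[c] /= treeLabels.length; // vote count -> fraction
            if (posteriori[c] > posteriori[best]) best = c;
        }
        return best;
    }

    public static void main(String[] args) {
        double[] p = new double[3];
        int label = vote(new int[] {2, 0, 2, 1, 2}, p);
        System.out.println(label + " " + Arrays.toString(p));
    }
}
```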
public int[][] test(smile.data.DataFrame data)
data - the test data set.
public RandomForest prune(smile.data.DataFrame test)
test - the test data set to evaluate the errors of nodes.
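prune(test) applies reduced error pruning, using the test set to evaluate the errors of nodes. A minimal sketch of that technique on a single hypothetical binary tree (PrunableNode is not Smile's DecisionTree): working bottom-up, a split is collapsed into a leaf that predicts the node's majority training label whenever the leaf makes no more test errors than the subtree below it. This sketch prunes in place for brevity, whereas prune returns a new forest; how Smile applies the procedure to each tree is not shown here.

```java
// Hypothetical binary decision tree over double[] features; not Smile's DecisionTree.
final class PrunableNode {
    final int output;      // majority training label at this node
    int feature = -1;      // split variable; -1 marks a leaf
    double threshold;      // rows with x[feature] <= threshold go left
    PrunableNode left;
    PrunableNode right;

    PrunableNode(int output) {
        this.output = output;
    }

    PrunableNode(int output, int feature, double threshold, PrunableNode left, PrunableNode right) {
        this.output = output;
        this.feature = feature;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
    }

    boolean isLeaf() {
        return feature < 0;
    }

    int predict(double[] x) {
        if (isLeaf()) return output;
        return (x[feature] <= threshold ? left : right).predict(x);
    }

    /**
     * Prunes the children first, then collapses this node into a leaf if that does
     * not increase the error on the test rows reaching it (nodes reached by no test
     * rows are collapsed too). idx lists those rows; returns the errors after pruning.
     */
    int prune(double[][] x, int[] y, int[] idx) {
        int leafErrors = 0;
        for (int i : idx) if (y[i] != output) leafErrors++;
        if (isLeaf()) return leafErrors;

        // Route the test rows to the children.
        int nl = 0;
        for (int i : idx) if (x[i][feature] <= threshold) nl++;
        int[] li = new int[nl];
        int[] ri = new int[idx.length - nl];
        int a = 0;
        int b = 0;
        for (int i : idx) {
            if (x[i][feature] <= threshold) li[a++] = i; else ri[b++] = i;
        }
        int subtreeErrors = left.prune(x, y, li) + right.prune(x, y, ri);

        if (leafErrors <= subtreeErrors) {
            feature = -1; // collapse into a leaf
            left = null;
            right = null;
            return leafErrors;
        }
        return subtreeErrors;
    }
}
```

Using <= rather than < favors the smaller tree on ties, which is the usual choice in reduced error pruning.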