public class RandomForest extends java.lang.Object implements Regression<smile.data.Tuple>, DataFrameRegression, TreeSHAP
Each tree is constructed using the following algorithm:
Modifier and Type | Class | Description |
---|---|---|
static class | RandomForest.Model | The base model. |

Nested classes/interfaces inherited from interface Regression: Regression.Metric
Constructor | Description |
---|---|
RandomForest(smile.data.formula.Formula formula, RandomForest.Model[] models, RegressionMetrics metrics, double[] importance) | Constructor. |
Modifier and Type | Method | Description |
---|---|---|
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.stream.LongStream seeds) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop) | Learns a random forest for regression. |
smile.data.formula.Formula | formula() | Returns the formula associated with the model. |
double[] | importance() | Returns the variable importance. |
RandomForest | merge(RandomForest other) | Merges two random forests. |
RegressionMetrics | metrics() | Returns the overall out-of-bag metric estimations. |
RandomForest.Model[] | models() | Returns the base models. |
double | predict(smile.data.Tuple x) | Predicts the dependent variable of an instance. |
smile.data.type.StructType | schema() | Returns the schema of predictors. |
int | size() | Returns the number of trees in the model. |
double[][] | test(smile.data.DataFrame data) | Tests the model on a validation dataset. |
RegressionTree[] | trees() | Returns the classification/regression trees. |
RandomForest | trim(int ntrees) | Trims the tree model set to a smaller size in case of over-fitting. |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Regression: applyAsDouble, metric, metric, predict
Methods inherited from interface DataFrameRegression: predict
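The `subsample` and `seeds` parameters of `fit` control how each tree's training rows are drawn: 1.0 means sampling with replacement (a bootstrap sample), and a value below 1.0 means sampling without replacement. The JDK-only sketch below illustrates those two modes with one RNG seed per tree. The class `TreeSampling`, its `sample` method, and the choice of `round(n * subsample)` rows in the without-replacement case are assumptions for illustration, not Smile's implementation. A row whose count is 0 is out-of-bag for that tree.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.stream.LongStream;

/** Hypothetical sketch of per-tree row sampling controlled by a subsample rate. */
class TreeSampling {
    /**
     * Returns how many times each of the n rows is drawn for one tree.
     * subsample == 1.0: n draws with replacement (bootstrap).
     * subsample < 1.0: round(n * subsample) distinct rows, without replacement (assumed size).
     */
    static int[] sample(int n, double subsample, Random rng) {
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("subsample must be in (0, 1]");
        }
        int[] counts = new int[n];
        if (subsample == 1.0) {
            for (int i = 0; i < n; i++) {
                counts[rng.nextInt(n)]++;
            }
        } else {
            int k = (int) Math.round(n * subsample);
            // Partial Fisher-Yates shuffle: the first k slots become a uniform random subset.
            int[] perm = new int[n];
            for (int i = 0; i < n; i++) perm[i] = i;
            for (int i = 0; i < k; i++) {
                int j = i + rng.nextInt(n - i);
                int tmp = perm[i];
                perm[i] = perm[j];
                perm[j] = tmp;
                counts[perm[i]] = 1;
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // One seed per tree, playing the role of the optional LongStream seeds argument.
        long[] seeds = LongStream.range(0, 3).toArray();
        for (long seed : seeds) {
            System.out.println(Arrays.toString(sample(10, 0.7, new Random(seed))));
        }
    }
}
```

Reusing the same seed reproduces the same sample, which is the usual reason to pass explicit per-tree seeds.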
public RandomForest(smile.data.formula.Formula formula, RandomForest.Model[] models, RegressionMetrics metrics, double[] importance)
Constructor.
Parameters:
formula - a symbolic description of the model to be fitted.
models - the base models.
metrics - the overall out-of-bag metric estimations.
importance - the feature importance.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)
Learns a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
Learns a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample)
Learns a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables randomly chosen as split candidates at each node of the tree. mtry = p/3 generally gives good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.stream.LongStream seeds)
Learns a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables randomly chosen as split candidates at each node of the tree. mtry = p/3 generally gives good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
seeds - optional RNG seeds for each regression tree.

public smile.data.formula.Formula formula()
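Description copied from interface: DataFrameRegression
Returns the formula associated with the model.
Specified by: formula in interface TreeSHAP
Specified by: formula in interface DataFrameRegression

public smile.data.type.StructType schema()
Description copied from interface: DataFrameRegression
Returns the schema of predictors.
Specified by: schema in interface DataFrameRegression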
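In the `mtry` rule of thumb for `fit`, p is the number of predictor variables, i.e. the fields described by `schema()`. The JDK-only sketch below shows that rule and the drawing of `mtry` distinct candidate features at one node. Flooring p/3 and clamping it to at least 1 are assumptions, and `FeatureSubset` is a hypothetical name, not part of Smile.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of the mtry rule of thumb and per-node feature sampling. */
class FeatureSubset {
    /** p/3 rule of thumb, floored and clamped to at least one feature (assumed rounding). */
    static int defaultMtry(int p) {
        return Math.max(1, p / 3);
    }

    /** Picks mtry distinct feature indices out of p, uniformly at random, returned sorted. */
    static int[] candidates(int p, int mtry, Random rng) {
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("mtry must be in [1, p]");
        }
        int[] idx = new int[p];
        for (int i = 0; i < p; i++) idx[i] = i;
        // Partial Fisher-Yates shuffle over the feature indices.
        for (int i = 0; i < mtry; i++) {
            int j = i + rng.nextInt(p - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        int[] picked = Arrays.copyOf(idx, mtry);
        Arrays.sort(picked);
        return picked;
    }

    public static void main(String[] args) {
        int p = 13;
        int mtry = defaultMtry(p); // 13 / 3 = 4 with integer division
        System.out.println(mtry + " " + Arrays.toString(candidates(p, mtry, new Random(7))));
    }
}
```

Only these candidate features are searched for the best split at that node, which is what decorrelates the trees.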
public RegressionMetrics metrics()
Returns the overall out-of-bag metric estimations.
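Out-of-bag (OOB) estimation scores each training row using only the trees that did not see that row during training, so error estimates are available without a separate validation set. The JDK-only sketch below computes an OOB RMSE from per-tree sample counts and per-tree predictions. The array layout and the class name `OutOfBag` are hypothetical, and Smile's `RegressionMetrics` may report additional or differently computed statistics.

```java
/** Hypothetical sketch of an out-of-bag (OOB) RMSE estimate for a regression forest. */
class OutOfBag {
    /**
     * counts[t][i]: times row i was drawn for tree t (0 means row i is out-of-bag for tree t).
     * pred[t][i]:   tree t's prediction for row i.
     * Returns the RMSE of OOB-averaged predictions over rows with at least one OOB tree.
     */
    static double rmse(int[][] counts, double[][] pred, double[] y) {
        double sse = 0.0;
        int m = 0;
        for (int i = 0; i < y.length; i++) {
            double sum = 0.0;
            int k = 0;
            for (int t = 0; t < counts.length; t++) {
                if (counts[t][i] == 0) {
                    sum += pred[t][i];
                    k++;
                }
            }
            if (k > 0) {
                double r = y[i] - sum / k;
                sse += r * r;
                m++;
            }
        }
        if (m == 0) {
            throw new IllegalStateException("no row is out-of-bag for any tree");
        }
        return Math.sqrt(sse / m);
    }

    public static void main(String[] args) {
        int[][] counts = {{1, 0, 2}, {0, 1, 1}};
        double[][] pred = {{1.0, 2.0, 3.0}, {1.5, 2.5, 3.5}};
        double[] y = {1.0, 2.0, 3.0};
        // Row 0 uses tree 1 only, row 1 uses tree 0 only, row 2 has no OOB tree and is skipped.
        System.out.println(rmse(counts, pred, y));
    }
}
```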
public double[] importance()
Returns the variable importance.
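`importance()` returns one score per variable. Assuming `importance[j]` belongs to the j-th predictor (an assumption worth checking against `schema()` in your version), ranking predictors reduces to sorting indices by score, as in this hypothetical JDK-only helper:

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical helper: rank predictors by an importance vector (larger = more important). */
class ImportanceRank {
    static Integer[] rank(double[] importance) {
        Integer[] order = new Integer[importance.length];
        for (int j = 0; j < order.length; j++) order[j] = j;
        // Sort indices by descending importance; Arrays.sort on objects is stable, so ties keep index order.
        Arrays.sort(order, Comparator.comparingDouble((Integer j) -> importance[j]).reversed());
        return order;
    }

    public static void main(String[] args) {
        String[] names = {"x0", "x1", "x2"};   // hypothetical predictor names
        double[] importance = {0.2, 1.5, 0.7}; // hypothetical scores
        for (int j : rank(importance)) {
            System.out.println(names[j] + " " + importance[j]);
        }
    }
}
```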
public int size()
Returns the number of trees in the model.
public RandomForest.Model[] models()
Returns the base models.
public RegressionTree[] trees()
Description copied from interface: TreeSHAP
Returns the classification/regression trees.
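public RandomForest trim(int ntrees)
Trims the tree model set to a smaller size in case of over-fitting.
Parameters:
ntrees - the new (smaller) size of the tree model set.

public RandomForest merge(RandomForest other)
Merges two random forests.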
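If merging pools the trees of both forests (an assumption about `merge`'s internals) and a forest's prediction is the mean of its trees' predictions, then the merged prediction is the size-weighted average of the two forests' predictions, (n1 * m1 + n2 * m2) / (n1 + n2). The JDK-only sketch below checks that identity, modeling trees as hypothetical `ToDoubleFunction<double[]>` instances instead of Smile's `RegressionTree` and `Tuple` types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch: merging two forests by pooling their trees. */
class ForestMerge {
    /** Returns a new list holding the trees of a followed by the trees of b. */
    static List<ToDoubleFunction<double[]>> merge(List<ToDoubleFunction<double[]>> a,
                                                  List<ToDoubleFunction<double[]>> b) {
        List<ToDoubleFunction<double[]>> all = new ArrayList<>(a);
        all.addAll(b);
        return all;
    }

    /** Forest prediction as the mean of the individual tree predictions. */
    static double predict(List<ToDoubleFunction<double[]>> trees, double[] x) {
        if (trees.isEmpty()) {
            throw new IllegalArgumentException("forest has no trees");
        }
        double sum = 0.0;
        for (ToDoubleFunction<double[]> tree : trees) {
            sum += tree.applyAsDouble(x);
        }
        return sum / trees.size();
    }

    public static void main(String[] args) {
        double[] x = {0.0};
        // Stand-in trees that ignore x: forest a has mean 2.0 over 2 trees, forest b has 8.0 with 1 tree.
        List<ToDoubleFunction<double[]>> a = List.of(v -> 1.0, v -> 3.0);
        List<ToDoubleFunction<double[]>> b = List.of(v -> 8.0);
        // (2 * 2.0 + 1 * 8.0) / 3 = 4.0
        System.out.println(predict(merge(a, b), x)); // prints 4.0
    }
}
```

The weighting explains why merging a small forest into a large one shifts predictions only slightly.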
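public double predict(smile.data.Tuple x)
Description copied from interface: Regression
Predicts the dependent variable of an instance.
Specified by: predict in interface DataFrameRegression
Specified by: predict in interface Regression<smile.data.Tuple>
Parameters:
x - an instance.

public double[][] test(smile.data.DataFrame data)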
Tests the model on a validation dataset.
Parameters:
data - the test data set.
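`test` and `trim` can be combined to choose a forest size on held-out data. Assuming `test(data)` returns staged predictions, where row k holds the predictions made with the first k + 1 trees (an assumption; check the return-value documentation of your Smile version), the hypothetical JDK-only helper below picks the tree count with the lowest validation RMSE.

```java
/** Hypothetical sketch: choose a tree count for trim(ntrees) from staged predictions. */
class ChooseTreeCount {
    /**
     * staged[k][i]: prediction for validation row i using the first k + 1 trees (assumed layout).
     * Returns the tree count (k + 1) with the lowest validation RMSE; ties favor fewer trees.
     */
    static int best(double[][] staged, double[] y) {
        if (staged.length == 0 || y.length == 0) {
            throw new IllegalArgumentException("need at least one stage and one row");
        }
        int bestCount = 1;
        double bestRmse = Double.POSITIVE_INFINITY;
        for (int k = 0; k < staged.length; k++) {
            double sse = 0.0;
            for (int i = 0; i < y.length; i++) {
                double r = y[i] - staged[k][i];
                sse += r * r;
            }
            double rmse = Math.sqrt(sse / y.length);
            if (rmse < bestRmse) { // strict comparison keeps the smaller count on ties
                bestRmse = rmse;
                bestCount = k + 1;
            }
        }
        return bestCount;
    }

    public static void main(String[] args) {
        double[][] staged = {{0.0, 0.0}, {1.0, 1.0}, {0.5, 0.5}};
        double[] y = {1.0, 1.0};
        System.out.println(best(staged, y)); // prints 2
    }
}
```

Under that assumption, the chosen count would then be passed to `trim(ntrees)` to obtain the smaller forest.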