public class RandomForest extends java.lang.Object implements Regression<smile.data.Tuple>, DataFrameRegression
Each tree is constructed using the following algorithm:
Constructor | Description |
---|---|
RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance) | Constructor. |
Modifier and Type | Method | Description |
---|---|---|
double | error() | Returns the out-of-bag estimate of the RMSE. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.function.LongSupplier seedGenerator) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<java.util.stream.LongStream> seeds) | Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop) | Learns a random forest for regression. |
smile.data.formula.Formula | formula() | Returns the formula associated with the model. |
double[] | importance() | Returns the variable importance. |
RandomForest | merge(RandomForest other) | Merges two random forests and returns a new forest consisting of the trees from both input forests. |
double | predict(smile.data.Tuple x) | Predicts the dependent variable of an instance. |
smile.data.type.StructType | schema() | Returns the design matrix schema. |
int | size() | Returns the number of trees in the model. |
double[][] | test(smile.data.DataFrame data) | Tests the model on a validation data set. |
RegressionTree[] | trees() | Returns the regression trees. |
void | trim(int ntrees) | Trims the tree model set to a smaller size in case of over-fitting. |
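Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Regression: applyAsDouble, predict
Methods inherited from interface DataFrameRegression: predict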
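The fit overloads listed above take `mtry` and `subsample`. Their parameter notes say three things: `subsample` = 1.0 means sampling with replacement, a value below 1.0 means sampling without replacement, and `mtry` = p/3 is a reasonable choice. The stdlib-only sketch below shows what these two kinds of randomness could look like. The class `ForestSampling` and its methods are hypothetical. Drawing round(n * subsample) rows when `subsample` < 1.0 is an assumption, not something this page states.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper illustrating per-tree sampling; not Smile's implementation.
final class ForestSampling {

    /**
     * Returns how many times each of the n training rows is drawn for one tree.
     * Assumption: subsample == 1.0 draws n rows with replacement (a bootstrap),
     * while subsample < 1.0 draws round(n * subsample) distinct rows.
     */
    static int[] sampleRows(int n, double subsample, Random rng) {
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("subsample must be in (0, 1]");
        }
        int[] counts = new int[n];
        if (subsample == 1.0) {
            // With replacement: a row may be drawn several times or not at all.
            for (int i = 0; i < n; i++) {
                counts[rng.nextInt(n)]++;
            }
        } else {
            // Without replacement: the first m slots of a partial Fisher-Yates
            // shuffle form a uniformly random subset of m distinct rows.
            int m = (int) Math.round(n * subsample);
            int[] perm = new int[n];
            for (int i = 0; i < n; i++) perm[i] = i;
            for (int i = 0; i < m; i++) {
                int j = i + rng.nextInt(n - i);
                int tmp = perm[i]; perm[i] = perm[j]; perm[j] = tmp;
                counts[perm[i]] = 1;
            }
        }
        return counts;
    }

    /** Picks mtry distinct variable indices out of p, e.g. for one node split. */
    static int[] sampleFeatures(int p, int mtry, Random rng) {
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("mtry must be in [1, p]");
        }
        int[] perm = new int[p];
        for (int i = 0; i < p; i++) perm[i] = i;
        for (int i = 0; i < mtry; i++) {
            int j = i + rng.nextInt(p - i);
            int tmp = perm[i]; perm[i] = perm[j]; perm[j] = tmp;
        }
        return Arrays.copyOf(perm, mtry);
    }

    /** The p/3 rule of thumb from the mtry notes, clamped to at least one variable. */
    static int defaultMtry(int p) {
        return Math.max(1, p / 3);
    }
}
```

The `Random` is passed in explicitly so that a caller can seed it per tree and reproduce a run.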

public RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance)
Constructor.
Parameters:
formula - a symbolic description of the model to be fitted.
trees - forest of regression trees.
error - out-of-bag estimate of the RMSE.
importance - variable importance.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample)
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables randomly selected as split candidates at each node of the tree. p/3 generally seems to give good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.function.LongSupplier seedGenerator)
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables randomly selected as split candidates at each node of the tree. p/3 generally seems to give good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
seedGenerator - RNG seed generator.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<java.util.stream.LongStream> seeds)
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables randomly selected as split candidates at each node of the tree. p/3 generally seems to give good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; a value below 1.0 means sampling without replacement.
seeds - optional RNG seeds for each regression tree.

public RandomForest merge(RandomForest other)

public smile.data.formula.Formula formula()
Specified by: formula in interface DataFrameRegression

public smile.data.type.StructType schema()
Specified by: schema in interface DataFrameRegression
public double error()
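`error()` is documented as an out-of-bag (OOB) estimate of the RMSE. The usual definition works in three steps. Each training row is scored only by the trees whose training sample did not contain it. Those predictions are averaged. The RMSE is then taken over the rows that were out of bag for at least one tree. The sketch below assumes that definition. `OobError` and its array inputs are hypothetical, and this page does not describe how Smile tracks in-bag rows internally.

```java
// Hypothetical out-of-bag RMSE calculation; not Smile's source.
final class OobError {

    /**
     * @param y           true responses, length n
     * @param predictions predictions[t][i] is tree t's prediction for training row i
     * @param samples     samples[t][i] is how often row i was drawn for tree t (0 = out of bag)
     * @return RMSE over rows that were out of bag for at least one tree, or NaN if there are none
     */
    static double oobRmse(double[] y, double[][] predictions, int[][] samples) {
        double sse = 0.0;
        int counted = 0;
        for (int i = 0; i < y.length; i++) {
            double sum = 0.0;
            int trees = 0;
            // Average only the trees that never saw row i during training.
            for (int t = 0; t < predictions.length; t++) {
                if (samples[t][i] == 0) {
                    sum += predictions[t][i];
                    trees++;
                }
            }
            if (trees > 0) {
                double residual = y[i] - sum / trees;
                sse += residual * residual;
                counted++;
            }
        }
        return counted == 0 ? Double.NaN : Math.sqrt(sse / counted);
    }
}
```

Rows that every tree trained on are skipped. Including them would use in-sample predictions and bias the estimate downward.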
public double[] importance()
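`importance()` returns one score per input variable, but this page does not say how the score is defined. This sketch assumes one common choice for regression trees. Each split is credited with the reduction in residual sum of squares (RSS) it achieves, that reduction is added to the split variable's score, and the per-tree vectors are summed across the forest. `RssImportance` is a hypothetical helper that illustrates this bookkeeping only.

```java
// Hypothetical illustration of RSS-decrease importance; not Smile's source.
final class RssImportance {

    /** Residual sum of squares of the values around their mean. */
    static double rss(double[] v) {
        if (v.length == 0) return 0.0;
        double mean = 0.0;
        for (double x : v) mean += x;
        mean /= v.length;
        double s = 0.0;
        for (double x : v) s += (x - mean) * (x - mean);
        return s;
    }

    /** RSS reduction obtained by splitting a parent node into left and right children. */
    static double rssDecrease(double[] left, double[] right) {
        double[] parent = new double[left.length + right.length];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        return rss(parent) - rss(left) - rss(right);
    }

    /** Credits one split on variable `feature` to a tree's running importance vector. */
    static void credit(double[] importance, int feature, double[] left, double[] right) {
        importance[feature] += rssDecrease(left, right);
    }

    /** Sums per-tree importance vectors into a forest-level vector (assumed aggregation). */
    static double[] sum(double[][] perTree) {
        double[] total = new double[perTree[0].length];
        for (double[] imp : perTree) {
            for (int j = 0; j < total.length; j++) total[j] += imp[j];
        }
        return total;
    }
}
```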
public int size()
public RegressionTree[] trees()
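The model holds the trees returned by `trees()`, and `size()`, `merge()`, `trim()` and `predict()` all act on that collection. The sketch below models this with a hypothetical `TinyForest` class whose trees are plain `ToDoubleFunction` instances. It assumes the usual random forest regression rule, where the forest prediction is the arithmetic mean of the tree predictions. It also assumes that `trim` keeps the first `ntrees` trees. This page states neither assumption, and the sketch does not model how `merge` combines the two inputs' out-of-bag error and importance.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

// Hypothetical model of a forest as an ordered list of trees; not Smile's API.
final class TinyForest<T> {
    private final List<ToDoubleFunction<T>> trees;

    TinyForest(List<ToDoubleFunction<T>> trees) {
        if (trees.isEmpty()) throw new IllegalArgumentException("empty forest");
        this.trees = new ArrayList<>(trees);   // defensive copy
    }

    int size() {
        return trees.size();
    }

    /** Forest prediction: arithmetic mean of the tree predictions (assumed rule). */
    double predict(T x) {
        double sum = 0.0;
        for (ToDoubleFunction<T> tree : trees) sum += tree.applyAsDouble(x);
        return sum / trees.size();
    }

    /** Returns a new forest holding the trees of both inputs; neither input is modified. */
    TinyForest<T> merge(TinyForest<T> other) {
        List<ToDoubleFunction<T>> all = new ArrayList<>(trees);
        all.addAll(other.trees);
        return new TinyForest<>(all);
    }

    /** Keeps only the first ntrees trees (assumed semantics). */
    void trim(int ntrees) {
        if (ntrees < 1 || ntrees > trees.size()) {
            throw new IllegalArgumentException("ntrees must be in [1, size()]");
        }
        trees.subList(ntrees, trees.size()).clear();
    }
}
```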

public void trim(int ntrees)
Parameters:
ntrees - the new (smaller) size of the tree model set.

public double predict(smile.data.Tuple x)
Specified by: predict in interface DataFrameRegression
Specified by: predict in interface Regression<smile.data.Tuple>
Parameters:
x - an instance.

public double[][] test(smile.data.DataFrame data)
Parameters:
data - the test data set.
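`test()` returns a `double[][]`, and `trim()` is documented as a remedy for over-fitting. This page does not spell out the matrix layout; assume row k holds the predictions made with the first k + 1 trees. Under that assumption, one could score each prefix against the true responses and pass the forest size with the lowest RMSE to `trim()`. `TreeCountSelector` is a hypothetical sketch of that selection step.

```java
// Hypothetical helper; assumes prediction[k][i] is the prediction for row i using the first k + 1 trees.
final class TreeCountSelector {

    static double rmse(double[] y, double[] yhat) {
        double sse = 0.0;
        for (int i = 0; i < y.length; i++) {
            double r = y[i] - yhat[i];
            sse += r * r;
        }
        return Math.sqrt(sse / y.length);
    }

    /** Returns the number of trees (1-based) whose prefix gives the lowest validation RMSE. */
    static int bestTreeCount(double[] y, double[][] prediction) {
        int best = 1;
        double bestRmse = Double.POSITIVE_INFINITY;
        for (int k = 0; k < prediction.length; k++) {
            double e = rmse(y, prediction[k]);
            if (e < bestRmse) {   // strict comparison: ties keep the smaller forest
                bestRmse = e;
                best = k + 1;
            }
        }
        return best;
    }
}
```

With a fitted model, this would be used roughly as `forest.trim(TreeCountSelector.bestTreeCount(y, forest.test(validation)))`, where `y` (hypothetical) holds the validation responses. The result depends on the row layout assumed above.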