public class RandomForest extends java.lang.Object implements Regression<smile.data.Tuple>, DataFrameRegression
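Each tree is constructed from a sample of the training data drawn at rate subsample, and each node split is chosen among mtry input variables; see the fit parameters for details.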
| Constructor and Description |
|---|
| `RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance)` Constructor. |

| Modifier and Type | Method and Description |
|---|---|
| `double` | `error()` Returns the out-of-bag estimate of RMSE. |
| `static RandomForest` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data)` Learns a random forest for regression. |
| `static RandomForest` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample)` Learns a random forest for regression. |
| `static RandomForest` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.function.LongSupplier seedGenerator)` Learns a random forest for regression. |
| `static RandomForest` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<java.util.stream.LongStream> seeds)` Learns a random forest for regression. |
| `static RandomForest` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)` Learns a random forest for regression. |
| `smile.data.formula.Formula` | `formula()` Returns the formula associated with the model. |
| `double[]` | `importance()` Returns the variable importance. |
| `RandomForest` | `merge(RandomForest other)` Merges together two random forests and returns a new forest consisting of trees from both input forests. |
| `double` | `predict(smile.data.Tuple x)` Predicts the dependent variable of an instance. |
| `smile.data.type.StructType` | `schema()` Returns the design matrix schema. |
| `int` | `size()` Returns the number of trees in the model. |
| `double[][]` | `test(smile.data.DataFrame data)` Tests the model on a validation dataset. |
| `RegressionTree[]` | `trees()` Returns the regression trees. |
| `void` | `trim(int ntrees)` Trims the tree model set to a smaller size in case of over-fitting. |
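
Two of the `fit` overloads above accept random seeds, either as a `java.util.function.LongSupplier` (`seedGenerator`) or as a `java.util.Optional<java.util.stream.LongStream>` (`seeds`). A minimal sketch of building reproducible values of both types with the JDK alone follows. The class `ForestSeeds` and its method names are hypothetical and not part of Smile, and the sketch does not call `fit` itself.

```java
import java.util.Optional;
import java.util.Random;
import java.util.function.LongSupplier;
import java.util.stream.LongStream;

// Hypothetical helper, not part of Smile: builds seed sources in the
// shapes expected by the seedGenerator and seeds parameters.
final class ForestSeeds {
    /** Seed generator that replays the same sequence for a given master seed. */
    static LongSupplier generator(long masterSeed) {
        Random rng = new Random(masterSeed);
        return rng::nextLong;
    }

    /** Exactly ntrees seeds derived from one master seed. */
    static Optional<LongStream> seeds(long masterSeed, int ntrees) {
        return Optional.of(new Random(masterSeed).longs(ntrees));
    }

    public static void main(String[] args) {
        long[] first = seeds(42L, 5).get().toArray();
        long[] second = seeds(42L, 5).get().toArray();
        // The same master seed yields the same per-tree seeds.
        System.out.println(java.util.Arrays.equals(first, second));
    }
}
```

The same master seed produces the same seed sequence on every run. How `fit` consumes the seeds internally is not described in this reference.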

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface Regression: applyAsDouble, predict

Methods inherited from interface DataFrameRegression: predict

`public RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance)`

Constructor.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `trees` - forest of regression trees.
- `error` - out-of-bag estimate of RMSE.
- `importance` - variable importance.

`public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)`

Learns a random forest for regression.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `data` - the data frame of the explanatory and response variables.

`public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)`

Learns a random forest for regression.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `data` - the data frame of the explanatory and response variables.

`public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample)`

Learns a random forest for regression.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `data` - the data frame of the explanatory and response variables.
- `ntrees` - the number of trees.
- `mtry` - the number of input variables to be used to determine the decision at a node of the tree. p/3 generally gives good performance, where p is the number of variables.
- `maxDepth` - the maximum depth of the tree.
- `maxNodes` - the maximum number of leaf nodes in the tree.
- `nodeSize` - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
- `subsample` - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.

`public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.function.LongSupplier seedGenerator)`

Learns a random forest for regression.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `data` - the data frame of the explanatory and response variables.
- `ntrees` - the number of trees.
- `mtry` - the number of input variables to be used to determine the decision at a node of the tree. p/3 generally gives good performance, where p is the number of variables.
- `maxDepth` - the maximum depth of the tree.
- `maxNodes` - the maximum number of leaf nodes in the tree.
- `nodeSize` - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
- `subsample` - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
- `seedGenerator` - RNG seed generator.

`public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.Optional<java.util.stream.LongStream> seeds)`

Learns a random forest for regression.

Parameters:
- `formula` - a symbolic description of the model to be fitted.
- `data` - the data frame of the explanatory and response variables.
- `ntrees` - the number of trees.
- `mtry` - the number of input variables to be used to determine the decision at a node of the tree. p/3 generally gives good performance, where p is the number of variables.
- `maxDepth` - the maximum depth of the tree.
- `maxNodes` - the maximum number of leaf nodes in the tree.
- `nodeSize` - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
- `subsample` - the sampling rate for training each tree. 1.0 means sampling with replacement; a value < 1.0 means sampling without replacement.
- `seeds` - optional RNG seeds for each regression tree.

`public RandomForest merge(RandomForest other)`

Merges together two random forests and returns a new forest consisting of trees from both input forests.
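
As a rough illustration of what merging means for prediction, the sketch below models each tree as a plain `ToDoubleFunction<double[]>` and a forest as a list of such functions. `MergeSketch` and its methods are hypothetical stand-ins, not Smile classes. Averaging the tree outputs is the usual rule for regression forests and is assumed here. How the merged forest's `error()` and `importance()` are derived is not covered.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

// Hypothetical stand-in for a forest: each "tree" maps features to a value.
final class MergeSketch {
    /** Returns a new list holding the trees of a followed by the trees of b. */
    static List<ToDoubleFunction<double[]>> merge(
            List<ToDoubleFunction<double[]>> a,
            List<ToDoubleFunction<double[]>> b) {
        List<ToDoubleFunction<double[]>> merged = new ArrayList<>(a);
        merged.addAll(b);
        return merged;
    }

    /** Assumed prediction rule: the mean of the individual tree outputs. */
    static double predict(List<ToDoubleFunction<double[]>> trees, double[] x) {
        double sum = 0.0;
        for (ToDoubleFunction<double[]> tree : trees) {
            sum += tree.applyAsDouble(x);
        }
        return sum / trees.size();
    }

    public static void main(String[] args) {
        ToDoubleFunction<double[]> t1 = x -> 1.0;
        ToDoubleFunction<double[]> t2 = x -> 3.0;
        ToDoubleFunction<double[]> t3 = x -> 5.0;
        List<ToDoubleFunction<double[]>> a = new ArrayList<>();
        a.add(t1);
        a.add(t2);
        List<ToDoubleFunction<double[]>> b = new ArrayList<>();
        b.add(t3);

        List<ToDoubleFunction<double[]>> merged = merge(a, b);
        System.out.println(merged.size());                        // 3
        System.out.println(predict(merged, new double[] {0.0}));  // 3.0
    }
}
```

Neither input list is modified, which mirrors the description of `merge` as returning a new forest.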
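
`public smile.data.formula.Formula formula()`

Returns the formula associated with the model.

Specified by: `formula` in interface `DataFrameRegression`

`public smile.data.type.StructType schema()`

Returns the design matrix schema.

Specified by: `schema` in interface `DataFrameRegression`

`public double error()`

Returns the out-of-bag estimate of RMSE.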
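
The out-of-bag (OOB) RMSE reported by `error()` can be illustrated as follows, assuming the usual definition: each training instance is predicted only by the trees whose sample did not contain it, those predictions are averaged, and the root mean squared error is taken over the instances. The class `OobRmse`, its inputs (`treePred`, `inBag`), and the handling of instances with no OOB tree are assumptions for this sketch, not Smile internals.

```java
// Hypothetical sketch of an out-of-bag RMSE computation.
final class OobRmse {
    /**
     * @param y        observed responses, one per training instance
     * @param treePred treePred[t][i] is the prediction of tree t for instance i
     * @param inBag    inBag[t][i] is true if instance i was in the sample of tree t
     */
    static double rmse(double[] y, double[][] treePred, boolean[][] inBag) {
        double sumSquaredError = 0.0;
        int counted = 0;
        for (int i = 0; i < y.length; i++) {
            double sum = 0.0;
            int votes = 0;
            for (int t = 0; t < treePred.length; t++) {
                if (!inBag[t][i]) {   // only trees that never saw instance i
                    sum += treePred[t][i];
                    votes++;
                }
            }
            if (votes == 0) {
                continue;             // assumption: skip instances with no OOB tree
            }
            double residual = y[i] - sum / votes;
            sumSquaredError += residual * residual;
            counted++;
        }
        return counted == 0 ? Double.NaN : Math.sqrt(sumSquaredError / counted);
    }

    public static void main(String[] args) {
        double[] y = {1.0, 2.0};
        double[][] treePred = {{3.0, 9.0}, {5.0, 5.0}};
        boolean[][] inBag = {{false, true}, {false, false}};
        System.out.println(rmse(y, treePred, inBag)); // 3.0
    }
}
```

In the example, instance 0 is out of bag for both trees (mean prediction 4.0) and instance 1 only for the second tree (prediction 5.0), so both residuals are -3.0 and the RMSE is 3.0.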
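
`public double[] importance()`

Returns the variable importance.

`public int size()`

Returns the number of trees in the model.

`public RegressionTree[] trees()`

Returns the regression trees.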
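
A common use of the array returned by `importance()` is to rank variables. The sketch below sorts variable indices by descending score. It assumes that a larger score means a more important variable and that position `i` in the array corresponds to the `i`-th explanatory variable. `ImportanceRanking` is a hypothetical helper, not part of Smile.

```java
import java.util.Arrays;

// Hypothetical helper: orders variable indices by importance score.
final class ImportanceRanking {
    /** Returns variable indices from highest to lowest score; ties keep index order. */
    static int[] rank(double[] importance) {
        Integer[] order = new Integer[importance.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Descending by score; the object sort is stable, so equal scores stay in index order.
        Arrays.sort(order, (a, b) -> Double.compare(importance[b], importance[a]));
        return Arrays.stream(order).mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        double[] importance = {0.2, 1.5, 0.7};
        System.out.println(Arrays.toString(rank(importance))); // [1, 2, 0]
    }
}
```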
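
`public void trim(int ntrees)`

Trims the tree model set to a smaller size in case of over-fitting.

Parameters:
- `ntrees` - the new (smaller) size of tree model set.

`public double predict(smile.data.Tuple x)`

Predicts the dependent variable of an instance.

Specified by: `predict` in interface `DataFrameRegression`; `predict` in interface `Regression<smile.data.Tuple>`

Parameters:
- `x` - an instance.

`public double[][] test(smile.data.DataFrame data)`

Tests the model on a validation dataset.

Parameters:
- `data` - the test data set.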