public abstract class CART extends java.lang.Object implements SHAP<smile.data.Tuple>, java.io.Serializable
Modifier and Type | Field and Description |
---|---|
protected smile.data.formula.Formula |
formula
The model formula.
|
protected double[] |
importance
Variable importance.
|
protected int[] |
index
An index of samples to their original locations in training dataset.
|
protected int |
maxDepth
The maximum depth of the tree.
|
protected int |
maxNodes
The maximum number of leaf nodes in the tree.
|
protected int |
mtry
The number of input variables to be used to determine the decision
at a node of the tree.
|
protected int |
nodeSize
The number of instances in a node below which the tree will
not split. Setting nodeSize = 5 generally gives good results.
|
protected int[][] |
order
An index of training values.
|
protected smile.data.type.StructField |
response
The schema of response variable.
|
protected Node |
root
The root of decision tree.
|
protected int[] |
samples
The samples for training this node.
|
protected smile.data.type.StructType |
schema
The schema of predictors.
|
protected smile.data.DataFrame |
x
The training data.
|
Constructor and Description |
---|
CART(smile.data.DataFrame x,
smile.data.type.StructField y,
int maxDepth,
int maxNodes,
int nodeSize,
int mtry,
int[] samples,
int[][] order)
Constructor.
|
CART(smile.data.formula.Formula formula,
smile.data.type.StructType schema,
smile.data.type.StructField response,
Node root,
double[] importance)
Constructor.
|
Modifier and Type | Method and Description |
---|---|
protected void |
clear()
Clear the workspace of building tree.
|
java.lang.String |
dot()
Returns the graphic representation in Graphviz dot format.
|
protected abstract java.util.Optional<Split> |
findBestSplit(LeafNode node,
int column,
double impurity,
int lo,
int hi)
Finds the best split for given column.
|
protected java.util.Optional<Split> |
findBestSplit(LeafNode node,
int lo,
int hi,
boolean[] unsplittable)
Finds the best attribute to split on a set of samples.
|
double[] |
importance()
Returns the variable importance.
|
protected abstract double |
impurity(LeafNode node)
Returns the impurity of node.
|
protected abstract LeafNode |
newNode(int[] nodeSamples)
Creates a new leaf node.
|
static int[][] |
order(smile.data.DataFrame x)
Returns the index of ordered samples for each ordinal column.
|
protected smile.data.Tuple |
predictors(smile.data.Tuple x)
Returns the predictors by the model formula if it is not null.
|
Node |
root()
Returns the root node.
|
double[] |
shap(smile.data.DataFrame data)
Returns the average of absolute SHAP values over a data frame.
|
double[] |
shap(smile.data.Tuple x)
Returns the SHAP values.
|
int |
size()
Returns the number of nodes in the tree.
|
protected boolean |
split(Split split,
java.util.PriorityQueue<Split> queue)
Split a node into two children nodes.
|
java.lang.String |
toString()
Returns a text representation of the tree in R's rpart format.
|
protected smile.data.formula.Formula formula
protected smile.data.type.StructType schema
protected smile.data.type.StructField response
protected Node root
protected int maxDepth
protected int maxNodes
protected int nodeSize
protected int mtry
protected double[] importance
protected transient smile.data.DataFrame x
protected transient int[] samples
protected transient int[] index
protected transient int[][] order
public CART(smile.data.formula.Formula formula, smile.data.type.StructType schema, smile.data.type.StructField response, Node root, double[] importance)
public CART(smile.data.DataFrame x, smile.data.type.StructField y, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order)
x - the data frame of the explanatory variable.
y - the response variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the minimum size of leaf nodes.
mtry - the number of input variables to pick to split on at each
node. It seems that sqrt(p) gives generally good performance,
where p is the number of variables.
samples - the sample set of instances for stochastic learning.
samples[i] is the number of times instance i is sampled.
order - the index of training values in ascending order. Note
that only numeric attributes need be sorted.
public int size()
public static int[][] order(smile.data.DataFrame x)
protected smile.data.Tuple predictors(smile.data.Tuple x)
protected void clear()
protected boolean split(Split split, java.util.PriorityQueue<Split> queue)
protected java.util.Optional<Split> findBestSplit(LeafNode node, int lo, int hi, boolean[] unsplittable)
node - the leaf node to split.
lo - the inclusive lower bound of the data partition in the reordered sample index array.
hi - the exclusive upper bound of the data partition in the reordered sample index array.
unsplittable - unsplittable[j] is true if the column j cannot be split further in the node.
protected abstract double impurity(LeafNode node)
protected abstract LeafNode newNode(int[] nodeSamples)
protected abstract java.util.Optional<Split> findBestSplit(LeafNode node, int column, double impurity, int lo, int hi)
public double[] importance()
public Node root()
public java.lang.String dot()
public java.lang.String toString()
Overrides: toString in class java.lang.Object
public double[] shap(smile.data.DataFrame data)
public double[] shap(smile.data.Tuple x)
SHAP
The SHAP values are of size p x k, where p is the number of
features and k is the number of classes. The first k elements are
the SHAP values of the first feature over the k classes, respectively. The
remaining features follow accordingly.