public class Hyperparameters
extends java.lang.Object
This class collects hyperparameters, each given as a fixed value, an array of candidate values, or a value range, and generates candidate configurations for grid search or random search. Hyperparameters can be classified in two groups. Model hyperparameters cannot be inferred while fitting the model to the training set because they belong to the model selection task. Algorithm hyperparameters in principle have no influence on the performance of the model but affect the speed and quality of the learning process. For example, the topology and size of a neural network are model hyperparameters, while the learning rate and mini-batch size are algorithm hyperparameters.
The example below shows how to tune the hyperparameters of a random forest.
```java
import smile.io.*;
import smile.data.formula.Formula;
import smile.validation.*;
import smile.validation.metric.*;   // Accuracy and ConfusionMatrix
import smile.classification.RandomForest;

var hp = new Hyperparameters()
    .add("smile.random.forest.trees", 100)                 // a fixed value
    .add("smile.random.forest.mtry", new int[] {2, 3, 4})  // an array of values to choose from
    .add("smile.random.forest.max.nodes", 100, 500, 50);   // range [100, 500] with step 50

var train = Read.arff("data/weka/segment-challenge.arff");
var test = Read.arff("data/weka/segment-test.arff");
var formula = Formula.lhs("class");
var testy = formula.y(test).toIntArray();

// Train and evaluate one model per hyperparameter combination.
hp.grid().forEach(prop -> {
    var model = RandomForest.fit(formula, train, prop);
    var pred = model.predict(test);
    System.out.println(prop);
    System.out.format("Accuracy = %.2f%%%n", 100.0 * Accuracy.of(testy, pred));
    System.out.println(ConfusionMatrix.of(testy, pred));
});
```
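Because grid search takes the Cartesian product of all candidate values, the number of models trained is the product of the per-parameter counts. Assuming the inclusive range [100, 500] with step 50 contains 100, 150, ..., 500, it has 9 values, so the grid above should contain 1 × 3 × 9 = 27 configurations, each of which trains a 100-tree random forest. The plain-Java sketch below (no Smile dependency; GridSize and its methods are hypothetical names) reproduces that arithmetic.

```java
import java.util.stream.IntStream;

// Hypothetical helper (not part of Smile) that reproduces the grid-size
// arithmetic for the random forest example.
class GridSize {
    // Values of the inclusive range [start, end] with a positive step,
    // e.g. [100, 500] with step 50 gives 100, 150, ..., 500.
    static int[] steppedRange(int start, int end, int step) {
        if (step <= 0) throw new IllegalArgumentException("step must be positive");
        return IntStream.iterate(start, v -> v <= end, v -> v + step).toArray();
    }

    // Grid size = product of the number of candidate values of each parameter.
    static int exampleGridSize() {
        int trees = 1;                                     // fixed value 100
        int[] mtryValues = {2, 3, 4};
        int mtry = mtryValues.length;                      // 3 candidate values
        int maxNodes = steppedRange(100, 500, 50).length;  // 9 values
        return trees * mtry * maxNodes;
    }

    public static void main(String[] args) {
        System.out.println("grid size = " + exampleGridSize()); // prints "grid size = 27"
    }
}
```

Since the grid grows multiplicatively, adding one more parameter with k candidate values multiplies the training cost by k.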
Constructor and Description |
---|
Hyperparameters() Constructor. |
Modifier and Type | Method and Description |
---|---|
Hyperparameters | add(java.lang.String name, double value) Adds a parameter with a fixed value. |
Hyperparameters | add(java.lang.String name, double[] values) Adds a parameter with an array of candidate values. |
Hyperparameters | add(java.lang.String name, double start, double end) Adds a parameter with the value range [start, end]. |
Hyperparameters | add(java.lang.String name, double start, double end, double step) Adds a parameter with the value range [start, end] and the given step size. |
Hyperparameters | add(java.lang.String name, int value) Adds a parameter with a fixed value. |
Hyperparameters | add(java.lang.String name, int[] values) Adds a parameter with an array of candidate values. |
Hyperparameters | add(java.lang.String name, int start, int end) Adds a parameter with the value range [start, end]. |
Hyperparameters | add(java.lang.String name, int start, int end, int step) Adds a parameter with the value range [start, end] and the given step size. |
Hyperparameters | add(java.lang.String name, java.lang.String value) Adds a parameter with a fixed value. |
Hyperparameters | add(java.lang.String name, java.lang.String[] values) Adds a parameter with an array of candidate values. |
java.util.stream.Stream<java.util.Properties> | grid() Generates a stream of hyperparameters for grid search. |
java.util.stream.Stream<java.util.Properties> | random() Generates a stream of hyperparameters for random search. |
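Both grid() and random() return a Stream<Properties>, one Properties object per candidate configuration, which the random forest example passes to RandomForest.fit. The sketch below illustrates the grid-search side as a Cartesian product over parameter names and candidate values. It is a stand-alone illustration of the technique, not Smile's implementation, and the class name CartesianGrid is hypothetical. Values are kept as strings because Properties stores string keys and values.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Stream;

// Hypothetical illustration of grid search (not Smile's source): expands a map of
// parameter name -> candidate values into one Properties object per combination.
class CartesianGrid {
    static Stream<Properties> grid(Map<String, List<String>> params) {
        List<Properties> combos = new ArrayList<>();
        combos.add(new Properties()); // start with a single empty combination
        for (Map.Entry<String, List<String>> e : params.entrySet()) {
            List<Properties> next = new ArrayList<>();
            for (Properties partial : combos) {
                for (String value : e.getValue()) {
                    Properties p = new Properties();
                    p.putAll(partial);                // copy the partial combination
                    p.setProperty(e.getKey(), value); // extend it with one more parameter
                    next.add(p);
                }
            }
            combos = next;
        }
        return combos.stream();
    }

    public static void main(String[] args) {
        Map<String, List<String>> params = new LinkedHashMap<>();
        params.put("smile.random.forest.trees", List.of("100"));
        params.put("smile.random.forest.mtry", List.of("2", "3", "4"));
        grid(params).forEach(System.out::println); // 3 combinations
    }
}
```

Materializing the product up front keeps the sketch short; a lazy version would avoid holding every combination in memory for very large grids.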
public Hyperparameters add(java.lang.String name, int value)
Parameters:
name - the parameter name.
value - a fixed value of the parameter.

public Hyperparameters add(java.lang.String name, double value)
Parameters:
name - the parameter name.
value - a fixed value of the parameter.

public Hyperparameters add(java.lang.String name, java.lang.String value)
Parameters:
name - the parameter name.
value - a fixed value of the parameter.

public Hyperparameters add(java.lang.String name, int[] values)
Parameters:
name - the parameter name.
values - an array of parameter values.

public Hyperparameters add(java.lang.String name, double[] values)
Parameters:
name - the parameter name.
values - an array of parameter values.

public Hyperparameters add(java.lang.String name, java.lang.String[] values)
Parameters:
name - the parameter name.
values - an array of parameter values.

public Hyperparameters add(java.lang.String name, int start, int end)
Parameters:
name - the parameter name.
start - the start of the value range (inclusive).
end - the end of the value range (inclusive).

public Hyperparameters add(java.lang.String name, int start, int end, int step)
Parameters:
name - the parameter name.
start - the start of the value range (inclusive).
end - the end of the value range (inclusive).
step - the step size.

public Hyperparameters add(java.lang.String name, double start, double end)
Parameters:
name - the parameter name.
start - the start of the value range (inclusive).
end - the end of the value range (inclusive).

public Hyperparameters add(java.lang.String name, double start, double end, double step)
Parameters:
name - the parameter name.
start - the start of the value range (inclusive).
end - the end of the value range (inclusive).
step - the step size.

public java.util.stream.Stream<java.util.Properties> random()
Generates a stream of hyperparameters for random search.

public java.util.stream.Stream<java.util.Properties> grid()
Generates a stream of hyperparameters for grid search.
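In random search, each configuration draws one value per parameter independently instead of enumerating every combination. The sketch below shows the technique under stated assumptions: a uniformly chosen element for a list of candidate values and a uniform draw for a continuous range. This page does not say whether Smile's random() stream is bounded, so if it is unbounded, apply limit(n) before forEach as the sketch does. RandomSearch and the key "hypothetical.learning.rate" are hypothetical names, not Smile API.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.stream.Stream;

// Hypothetical illustration of random search (not Smile's source).
class RandomSearch {
    private final Map<String, List<String>> choices = new LinkedHashMap<>();
    private final Map<String, double[]> ranges = new LinkedHashMap<>();
    private final Random rng;

    RandomSearch(long seed) {
        this.rng = new Random(seed); // fixed seed makes the sampled configurations reproducible
    }

    // A parameter with a list of candidate values.
    RandomSearch add(String name, String... values) {
        if (values.length == 0) throw new IllegalArgumentException("no values for " + name);
        choices.put(name, List.of(values));
        return this;
    }

    // A continuous parameter sampled uniformly from [start, end).
    RandomSearch add(String name, double start, double end) {
        if (start > end) throw new IllegalArgumentException("start > end for " + name);
        ranges.put(name, new double[] {start, end});
        return this;
    }

    // Unbounded stream: every element is one independently sampled configuration.
    Stream<Properties> random() {
        return Stream.generate(() -> {
            Properties p = new Properties();
            choices.forEach((name, values) ->
                    p.setProperty(name, values.get(rng.nextInt(values.size()))));
            ranges.forEach((name, r) ->
                    p.setProperty(name, Double.toString(r[0] + (r[1] - r[0]) * rng.nextDouble())));
            return p;
        });
    }

    public static void main(String[] args) {
        new RandomSearch(42)
                .add("smile.random.forest.mtry", "2", "3", "4")
                .add("hypothetical.learning.rate", 0.001, 0.1)
                .random()
                .limit(5) // required: this sketch's random() never ends on its own
                .forEach(System.out::println);
    }
}
```

Random search fixes the training budget (the limit) regardless of how many parameters are tuned, whereas the grid size multiplies with every added parameter.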