public class LinearRegressionGraph<OUTPUT> extends java.lang.Object implements ModelGraph<DoubleTensor,OUTPUT>
Nested Class Summary

Modifier and Type | Class and Description |
---|---|
static class | LinearRegressionGraph.OutputVertices<OUTPUT> |
Constructor Summary

Constructor and Description |
---|
LinearRegressionGraph(long[] featureShape, java.util.function.Function<DoubleVertex,LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform, DoubleVertex interceptVertex, DoubleVertex weightsVertex) |
Method Summary

Modifier and Type | Method and Description |
---|---|
DoubleVertex | getInterceptVertex() |
Vertex<OUTPUT> | getOutputVertex() |
DoubleVertex | getWeightVertex() |
void | observeValues(DoubleTensor input, OUTPUT output) |
OUTPUT | predict(DoubleTensor input) |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ModelGraph: getBayesianNetwork
public LinearRegressionGraph(long[] featureShape, java.util.function.Function<DoubleVertex,LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform, DoubleVertex interceptVertex, DoubleVertex weightsVertex)
public OUTPUT predict(DoubleTensor input)
public void observeValues(DoubleTensor input, OUTPUT output)
Specified by: observeValues in interface ModelGraph<DoubleTensor,OUTPUT>
public DoubleVertex getInterceptVertex()
public DoubleVertex getWeightVertex()