001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.analysis.function; 019 020 import org.apache.commons.math3.analysis.FunctionUtils; 021 import org.apache.commons.math3.analysis.UnivariateFunction; 022 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; 023 import org.apache.commons.math3.analysis.ParametricUnivariateFunction; 024 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 026 import org.apache.commons.math3.exception.NotStrictlyPositiveException; 027 import org.apache.commons.math3.exception.NullArgumentException; 028 import org.apache.commons.math3.exception.DimensionMismatchException; 029 import org.apache.commons.math3.util.FastMath; 030 031 /** 032 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function"> 033 * Generalised logistic</a> function. 034 * 035 * @since 3.0 036 * @version $Id: Logistic.java 1391927 2012-09-30 00:03:30Z erans $ 037 */ 038 public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction { 039 /** Lower asymptote. */ 040 private final double a; 041 /** Upper asymptote. */ 042 private final double k; 043 /** Growth rate. */ 044 private final double b; 045 /** Parameter that affects near which asymptote maximum growth occurs. */ 046 private final double oneOverN; 047 /** Parameter that affects the position of the curve along the ordinate axis. */ 048 private final double q; 049 /** Abscissa of maximum growth. */ 050 private final double m; 051 052 /** 053 * @param k If {@code b > 0}, value of the function for x going towards +∞. 054 * If {@code b < 0}, value of the function for x going towards -∞. 055 * @param m Abscissa of maximum growth. 056 * @param b Growth rate. 057 * @param q Parameter that affects the position of the curve along the 058 * ordinate axis. 059 * @param a If {@code b > 0}, value of the function for x going towards -∞. 060 * If {@code b < 0}, value of the function for x going towards +∞. 061 * @param n Parameter that affects near which asymptote the maximum 062 * growth occurs. 063 * @throws NotStrictlyPositiveException if {@code n <= 0}. 064 */ 065 public Logistic(double k, 066 double m, 067 double b, 068 double q, 069 double a, 070 double n) 071 throws NotStrictlyPositiveException { 072 if (n <= 0) { 073 throw new NotStrictlyPositiveException(n); 074 } 075 076 this.k = k; 077 this.m = m; 078 this.b = b; 079 this.q = q; 080 this.a = a; 081 oneOverN = 1 / n; 082 } 083 084 /** {@inheritDoc} */ 085 public double value(double x) { 086 return value(m - x, k, b, q, a, oneOverN); 087 } 088 089 /** {@inheritDoc} 090 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)} 091 */ 092 @Deprecated 093 public UnivariateFunction derivative() { 094 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative(); 095 } 096 097 /** 098 * Parametric function where the input array contains the parameters of 099 * the logit function, ordered as follows: 100 * <ul> 101 * <li>Lower asymptote</li> 102 * <li>Higher asymptote</li> 103 * </ul> 104 */ 105 public static class Parametric implements ParametricUnivariateFunction { 106 /** 107 * Computes the value of the sigmoid at {@code x}. 108 * 109 * @param x Value for which the function must be computed. 110 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 111 * {@code a} and {@code n}. 112 * @return the value of the function. 113 * @throws NullArgumentException if {@code param} is {@code null}. 114 * @throws DimensionMismatchException if the size of {@code param} is 115 * not 6. 116 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 117 */ 118 public double value(double x, double ... param) 119 throws NullArgumentException, 120 DimensionMismatchException, 121 NotStrictlyPositiveException { 122 validateParameters(param); 123 return Logistic.value(param[1] - x, param[0], 124 param[2], param[3], 125 param[4], 1 / param[5]); 126 } 127 128 /** 129 * Computes the value of the gradient at {@code x}. 130 * The components of the gradient vector are the partial 131 * derivatives of the function with respect to each of the 132 * <em>parameters</em>. 133 * 134 * @param x Value at which the gradient must be computed. 135 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 136 * {@code a} and {@code n}. 137 * @return the gradient vector at {@code x}. 138 * @throws NullArgumentException if {@code param} is {@code null}. 139 * @throws DimensionMismatchException if the size of {@code param} is 140 * not 6. 141 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 142 */ 143 public double[] gradient(double x, double ... param) 144 throws NullArgumentException, 145 DimensionMismatchException, 146 NotStrictlyPositiveException { 147 validateParameters(param); 148 149 final double b = param[2]; 150 final double q = param[3]; 151 152 final double mMinusX = param[1] - x; 153 final double oneOverN = 1 / param[5]; 154 final double exp = FastMath.exp(b * mMinusX); 155 final double qExp = q * exp; 156 final double qExp1 = qExp + 1; 157 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN); 158 final double factor2 = -factor1 / qExp1; 159 160 // Components of the gradient. 161 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN); 162 final double gm = factor2 * b * qExp; 163 final double gb = factor2 * mMinusX * qExp; 164 final double gq = factor2 * exp; 165 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN); 166 final double gn = factor1 * Math.log(qExp1) * oneOverN; 167 168 return new double[] { gk, gm, gb, gq, ga, gn }; 169 } 170 171 /** 172 * Validates parameters to ensure they are appropriate for the evaluation of 173 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 174 * methods. 175 * 176 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 177 * {@code a} and {@code n}. 178 * @throws NullArgumentException if {@code param} is {@code null}. 179 * @throws DimensionMismatchException if the size of {@code param} is 180 * not 6. 181 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}. 182 */ 183 private void validateParameters(double[] param) 184 throws NullArgumentException, 185 DimensionMismatchException, 186 NotStrictlyPositiveException { 187 if (param == null) { 188 throw new NullArgumentException(); 189 } 190 if (param.length != 6) { 191 throw new DimensionMismatchException(param.length, 6); 192 } 193 if (param[5] <= 0) { 194 throw new NotStrictlyPositiveException(param[5]); 195 } 196 } 197 } 198 199 /** 200 * @param mMinusX {@code m - x}. 201 * @param k {@code k}. 202 * @param b {@code b}. 203 * @param q {@code q}. 204 * @param a {@code a}. 205 * @param oneOverN {@code 1 / n}. 206 * @return the value of the function. 207 */ 208 private static double value(double mMinusX, 209 double k, 210 double b, 211 double q, 212 double a, 213 double oneOverN) { 214 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN); 215 } 216 217 /** {@inheritDoc} 218 * @since 3.1 219 */ 220 public DerivativeStructure value(final DerivativeStructure t) { 221 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a); 222 } 223 224 }