001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.analysis.function;
019    
020    import org.apache.commons.math3.analysis.FunctionUtils;
021    import org.apache.commons.math3.analysis.UnivariateFunction;
022    import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
023    import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
024    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
025    import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026    import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027    import org.apache.commons.math3.exception.NullArgumentException;
028    import org.apache.commons.math3.exception.DimensionMismatchException;
029    import org.apache.commons.math3.util.FastMath;
030    
031    /**
032     * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
033     *  Generalised logistic</a> function.
034     *
035     * @since 3.0
036     * @version $Id: Logistic.java 1391927 2012-09-30 00:03:30Z erans $
037     */
038    public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
039        /** Lower asymptote. */
040        private final double a;
041        /** Upper asymptote. */
042        private final double k;
043        /** Growth rate. */
044        private final double b;
045        /** Parameter that affects near which asymptote maximum growth occurs. */
046        private final double oneOverN;
047        /** Parameter that affects the position of the curve along the ordinate axis. */
048        private final double q;
049        /** Abscissa of maximum growth. */
050        private final double m;
051    
052        /**
053         * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
054         * If {@code b < 0}, value of the function for x going towards -&infin;.
055         * @param m Abscissa of maximum growth.
056         * @param b Growth rate.
057         * @param q Parameter that affects the position of the curve along the
058         * ordinate axis.
059         * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
060         * If {@code b < 0}, value of the function for x going towards +&infin;.
061         * @param n Parameter that affects near which asymptote the maximum
062         * growth occurs.
063         * @throws NotStrictlyPositiveException if {@code n <= 0}.
064         */
065        public Logistic(double k,
066                        double m,
067                        double b,
068                        double q,
069                        double a,
070                        double n)
071            throws NotStrictlyPositiveException {
072            if (n <= 0) {
073                throw new NotStrictlyPositiveException(n);
074            }
075    
076            this.k = k;
077            this.m = m;
078            this.b = b;
079            this.q = q;
080            this.a = a;
081            oneOverN = 1 / n;
082        }
083    
084        /** {@inheritDoc} */
085        public double value(double x) {
086            return value(m - x, k, b, q, a, oneOverN);
087        }
088    
089        /** {@inheritDoc}
090         * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
091         */
092        @Deprecated
093        public UnivariateFunction derivative() {
094            return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
095        }
096    
097        /**
098         * Parametric function where the input array contains the parameters of
099         * the logit function, ordered as follows:
100         * <ul>
101         *  <li>Lower asymptote</li>
102         *  <li>Higher asymptote</li>
103         * </ul>
104         */
105        public static class Parametric implements ParametricUnivariateFunction {
106            /**
107             * Computes the value of the sigmoid at {@code x}.
108             *
109             * @param x Value for which the function must be computed.
110             * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
111             * {@code a} and  {@code n}.
112             * @return the value of the function.
113             * @throws NullArgumentException if {@code param} is {@code null}.
114             * @throws DimensionMismatchException if the size of {@code param} is
115             * not 6.
116             * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
117             */
118            public double value(double x, double ... param)
119                throws NullArgumentException,
120                       DimensionMismatchException,
121                       NotStrictlyPositiveException {
122                validateParameters(param);
123                return Logistic.value(param[1] - x, param[0],
124                                      param[2], param[3],
125                                      param[4], 1 / param[5]);
126            }
127    
128            /**
129             * Computes the value of the gradient at {@code x}.
130             * The components of the gradient vector are the partial
131             * derivatives of the function with respect to each of the
132             * <em>parameters</em>.
133             *
134             * @param x Value at which the gradient must be computed.
135             * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
136             * {@code a} and  {@code n}.
137             * @return the gradient vector at {@code x}.
138             * @throws NullArgumentException if {@code param} is {@code null}.
139             * @throws DimensionMismatchException if the size of {@code param} is
140             * not 6.
141             * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
142             */
143            public double[] gradient(double x, double ... param)
144                throws NullArgumentException,
145                       DimensionMismatchException,
146                       NotStrictlyPositiveException {
147                validateParameters(param);
148    
149                final double b = param[2];
150                final double q = param[3];
151    
152                final double mMinusX = param[1] - x;
153                final double oneOverN = 1 / param[5];
154                final double exp = FastMath.exp(b * mMinusX);
155                final double qExp = q * exp;
156                final double qExp1 = qExp + 1;
157                final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
158                final double factor2 = -factor1 / qExp1;
159    
160                // Components of the gradient.
161                final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
162                final double gm = factor2 * b * qExp;
163                final double gb = factor2 * mMinusX * qExp;
164                final double gq = factor2 * exp;
165                final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
166                final double gn = factor1 * Math.log(qExp1) * oneOverN;
167    
168                return new double[] { gk, gm, gb, gq, ga, gn };
169            }
170    
171            /**
172             * Validates parameters to ensure they are appropriate for the evaluation of
173             * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
174             * methods.
175             *
176             * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
177             * {@code a} and {@code n}.
178             * @throws NullArgumentException if {@code param} is {@code null}.
179             * @throws DimensionMismatchException if the size of {@code param} is
180             * not 6.
181             * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
182             */
183            private void validateParameters(double[] param)
184                throws NullArgumentException,
185                       DimensionMismatchException,
186                       NotStrictlyPositiveException {
187                if (param == null) {
188                    throw new NullArgumentException();
189                }
190                if (param.length != 6) {
191                    throw new DimensionMismatchException(param.length, 6);
192                }
193                if (param[5] <= 0) {
194                    throw new NotStrictlyPositiveException(param[5]);
195                }
196            }
197        }
198    
199        /**
200         * @param mMinusX {@code m - x}.
201         * @param k {@code k}.
202         * @param b {@code b}.
203         * @param q {@code q}.
204         * @param a {@code a}.
205         * @param oneOverN {@code 1 / n}.
206         * @return the value of the function.
207         */
208        private static double value(double mMinusX,
209                                    double k,
210                                    double b,
211                                    double q,
212                                    double a,
213                                    double oneOverN) {
214            return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
215        }
216    
217        /** {@inheritDoc}
218         * @since 3.1
219         */
220        public DerivativeStructure value(final DerivativeStructure t) {
221            return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
222        }
223    
224    }