001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.analysis.function;
019    
020    import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
021    import org.apache.commons.math3.analysis.FunctionUtils;
022    import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
023    import org.apache.commons.math3.analysis.UnivariateFunction;
024    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
025    import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026    import org.apache.commons.math3.exception.DimensionMismatchException;
027    import org.apache.commons.math3.exception.NullArgumentException;
028    import org.apache.commons.math3.exception.OutOfRangeException;
029    import org.apache.commons.math3.util.FastMath;
030    
031    /**
032     * <a href="http://en.wikipedia.org/wiki/Logit">
033     *  Logit</a> function.
034     * It is the inverse of the {@link Sigmoid sigmoid} function.
035     *
036     * @since 3.0
037     * @version $Id: Logit.java 1391927 2012-09-30 00:03:30Z erans $
038     */
039    public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
040        /** Lower bound. */
041        private final double lo;
042        /** Higher bound. */
043        private final double hi;
044    
045        /**
046         * Usual logit function, where the lower bound is 0 and the higher
047         * bound is 1.
048         */
049        public Logit() {
050            this(0, 1);
051        }
052    
053        /**
054         * Logit function.
055         *
056         * @param lo Lower bound of the function domain.
057         * @param hi Higher bound of the function domain.
058         */
059        public Logit(double lo,
060                     double hi) {
061            this.lo = lo;
062            this.hi = hi;
063        }
064    
065        /** {@inheritDoc} */
066        public double value(double x)
067            throws OutOfRangeException {
068            return value(x, lo, hi);
069        }
070    
071        /** {@inheritDoc}
072         * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
073         */
074        @Deprecated
075        public UnivariateFunction derivative() {
076            return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
077        }
078    
079        /**
080         * Parametric function where the input array contains the parameters of
081         * the logit function, ordered as follows:
082         * <ul>
083         *  <li>Lower bound</li>
084         *  <li>Higher bound</li>
085         * </ul>
086         */
087        public static class Parametric implements ParametricUnivariateFunction {
088            /**
089             * Computes the value of the logit at {@code x}.
090             *
091             * @param x Value for which the function must be computed.
092             * @param param Values of lower bound and higher bounds.
093             * @return the value of the function.
094             * @throws NullArgumentException if {@code param} is {@code null}.
095             * @throws DimensionMismatchException if the size of {@code param} is
096             * not 2.
097             */
098            public double value(double x, double ... param)
099                throws NullArgumentException,
100                       DimensionMismatchException {
101                validateParameters(param);
102                return Logit.value(x, param[0], param[1]);
103            }
104    
105            /**
106             * Computes the value of the gradient at {@code x}.
107             * The components of the gradient vector are the partial
108             * derivatives of the function with respect to each of the
109             * <em>parameters</em> (lower bound and higher bound).
110             *
111             * @param x Value at which the gradient must be computed.
112             * @param param Values for lower and higher bounds.
113             * @return the gradient vector at {@code x}.
114             * @throws NullArgumentException if {@code param} is {@code null}.
115             * @throws DimensionMismatchException if the size of {@code param} is
116             * not 2.
117             */
118            public double[] gradient(double x, double ... param)
119                throws NullArgumentException,
120                       DimensionMismatchException {
121                validateParameters(param);
122    
123                final double lo = param[0];
124                final double hi = param[1];
125    
126                return new double[] { 1 / (lo - x), 1 / (hi - x) };
127            }
128    
129            /**
130             * Validates parameters to ensure they are appropriate for the evaluation of
131             * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
132             * methods.
133             *
134             * @param param Values for lower and higher bounds.
135             * @throws NullArgumentException if {@code param} is {@code null}.
136             * @throws DimensionMismatchException if the size of {@code param} is
137             * not 2.
138             */
139            private void validateParameters(double[] param)
140                throws NullArgumentException,
141                       DimensionMismatchException {
142                if (param == null) {
143                    throw new NullArgumentException();
144                }
145                if (param.length != 2) {
146                    throw new DimensionMismatchException(param.length, 2);
147                }
148            }
149        }
150    
151        /**
152         * @param x Value at which to compute the logit.
153         * @param lo Lower bound.
154         * @param hi Higher bound.
155         * @return the value of the logit function at {@code x}.
156         * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
157         */
158        private static double value(double x,
159                                    double lo,
160                                    double hi)
161            throws OutOfRangeException {
162            if (x < lo || x > hi) {
163                throw new OutOfRangeException(x, lo, hi);
164            }
165            return FastMath.log((x - lo) / (hi - x));
166        }
167    
168        /** {@inheritDoc}
169         * @since 3.1
170         * @exception OutOfRangeException if parameter is outside of function domain
171         */
172        public DerivativeStructure value(final DerivativeStructure t)
173            throws OutOfRangeException {
174            final double x = t.getValue();
175            if (x < lo || x > hi) {
176                throw new OutOfRangeException(x, lo, hi);
177            }
178            double[] f = new double[t.getOrder() + 1];
179    
180            // function value
181            f[0] = FastMath.log((x - lo) / (hi - x));
182    
183            if (Double.isInfinite(f[0])) {
184    
185                if (f.length > 1) {
186                    f[1] = Double.POSITIVE_INFINITY;
187                }
188                // fill the array with infinities
189                // (for x close to lo the signs will flip between -inf and +inf,
190                //  for x close to hi the signs will always be +inf)
191                // this is probably overkill, since the call to compose at the end
192                // of the method will transform most infinities into NaN ...
193                for (int i = 2; i < f.length; ++i) {
194                    f[i] = f[i - 2];
195                }
196    
197            } else {
198    
199                // function derivatives
200                final double invL = 1.0 / (x - lo);
201                double xL = invL;
202                final double invH = 1.0 / (hi - x);
203                double xH = invH;
204                for (int i = 1; i < f.length; ++i) {
205                    f[i] = xL + xH;
206                    xL  *= -i * invL;
207                    xH  *=  i * invH;
208                }
209            }
210    
211            return t.compose(f);
212        }
213    }