001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.analysis.function; 019 020 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; 021 import org.apache.commons.math3.analysis.FunctionUtils; 022 import org.apache.commons.math3.analysis.ParametricUnivariateFunction; 023 import org.apache.commons.math3.analysis.UnivariateFunction; 024 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 026 import org.apache.commons.math3.exception.DimensionMismatchException; 027 import org.apache.commons.math3.exception.NullArgumentException; 028 import org.apache.commons.math3.exception.OutOfRangeException; 029 import org.apache.commons.math3.util.FastMath; 030 031 /** 032 * <a href="http://en.wikipedia.org/wiki/Logit"> 033 * Logit</a> function. 034 * It is the inverse of the {@link Sigmoid sigmoid} function. 035 * 036 * @since 3.0 037 * @version $Id: Logit.java 1391927 2012-09-30 00:03:30Z erans $ 038 */ 039 public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction { 040 /** Lower bound. */ 041 private final double lo; 042 /** Higher bound. */ 043 private final double hi; 044 045 /** 046 * Usual logit function, where the lower bound is 0 and the higher 047 * bound is 1. 048 */ 049 public Logit() { 050 this(0, 1); 051 } 052 053 /** 054 * Logit function. 055 * 056 * @param lo Lower bound of the function domain. 057 * @param hi Higher bound of the function domain. 058 */ 059 public Logit(double lo, 060 double hi) { 061 this.lo = lo; 062 this.hi = hi; 063 } 064 065 /** {@inheritDoc} */ 066 public double value(double x) 067 throws OutOfRangeException { 068 return value(x, lo, hi); 069 } 070 071 /** {@inheritDoc} 072 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)} 073 */ 074 @Deprecated 075 public UnivariateFunction derivative() { 076 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative(); 077 } 078 079 /** 080 * Parametric function where the input array contains the parameters of 081 * the logit function, ordered as follows: 082 * <ul> 083 * <li>Lower bound</li> 084 * <li>Higher bound</li> 085 * </ul> 086 */ 087 public static class Parametric implements ParametricUnivariateFunction { 088 /** 089 * Computes the value of the logit at {@code x}. 090 * 091 * @param x Value for which the function must be computed. 092 * @param param Values of lower bound and higher bounds. 093 * @return the value of the function. 094 * @throws NullArgumentException if {@code param} is {@code null}. 095 * @throws DimensionMismatchException if the size of {@code param} is 096 * not 2. 097 */ 098 public double value(double x, double ... param) 099 throws NullArgumentException, 100 DimensionMismatchException { 101 validateParameters(param); 102 return Logit.value(x, param[0], param[1]); 103 } 104 105 /** 106 * Computes the value of the gradient at {@code x}. 107 * The components of the gradient vector are the partial 108 * derivatives of the function with respect to each of the 109 * <em>parameters</em> (lower bound and higher bound). 110 * 111 * @param x Value at which the gradient must be computed. 112 * @param param Values for lower and higher bounds. 113 * @return the gradient vector at {@code x}. 114 * @throws NullArgumentException if {@code param} is {@code null}. 115 * @throws DimensionMismatchException if the size of {@code param} is 116 * not 2. 117 */ 118 public double[] gradient(double x, double ... param) 119 throws NullArgumentException, 120 DimensionMismatchException { 121 validateParameters(param); 122 123 final double lo = param[0]; 124 final double hi = param[1]; 125 126 return new double[] { 1 / (lo - x), 1 / (hi - x) }; 127 } 128 129 /** 130 * Validates parameters to ensure they are appropriate for the evaluation of 131 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 132 * methods. 133 * 134 * @param param Values for lower and higher bounds. 135 * @throws NullArgumentException if {@code param} is {@code null}. 136 * @throws DimensionMismatchException if the size of {@code param} is 137 * not 2. 138 */ 139 private void validateParameters(double[] param) 140 throws NullArgumentException, 141 DimensionMismatchException { 142 if (param == null) { 143 throw new NullArgumentException(); 144 } 145 if (param.length != 2) { 146 throw new DimensionMismatchException(param.length, 2); 147 } 148 } 149 } 150 151 /** 152 * @param x Value at which to compute the logit. 153 * @param lo Lower bound. 154 * @param hi Higher bound. 155 * @return the value of the logit function at {@code x}. 156 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}. 157 */ 158 private static double value(double x, 159 double lo, 160 double hi) 161 throws OutOfRangeException { 162 if (x < lo || x > hi) { 163 throw new OutOfRangeException(x, lo, hi); 164 } 165 return FastMath.log((x - lo) / (hi - x)); 166 } 167 168 /** {@inheritDoc} 169 * @since 3.1 170 * @exception OutOfRangeException if parameter is outside of function domain 171 */ 172 public DerivativeStructure value(final DerivativeStructure t) 173 throws OutOfRangeException { 174 final double x = t.getValue(); 175 if (x < lo || x > hi) { 176 throw new OutOfRangeException(x, lo, hi); 177 } 178 double[] f = new double[t.getOrder() + 1]; 179 180 // function value 181 f[0] = FastMath.log((x - lo) / (hi - x)); 182 183 if (Double.isInfinite(f[0])) { 184 185 if (f.length > 1) { 186 f[1] = Double.POSITIVE_INFINITY; 187 } 188 // fill the array with infinities 189 // (for x close to lo the signs will flip between -inf and +inf, 190 // for x close to hi the signs will always be +inf) 191 // this is probably overkill, since the call to compose at the end 192 // of the method will transform most infinities into NaN ... 193 for (int i = 2; i < f.length; ++i) { 194 f[i] = f[i - 2]; 195 } 196 197 } else { 198 199 // function derivatives 200 final double invL = 1.0 / (x - lo); 201 double xL = invL; 202 final double invH = 1.0 / (hi - x); 203 double xH = invH; 204 for (int i = 1; i < f.length; ++i) { 205 f[i] = xL + xH; 206 xL *= -i * invL; 207 xH *= i * invH; 208 } 209 } 210 211 return t.compose(f); 212 } 213 }