001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math3.analysis.function; 019 020 import java.util.Arrays; 021 022 import org.apache.commons.math3.analysis.FunctionUtils; 023 import org.apache.commons.math3.analysis.UnivariateFunction; 024 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; 025 import org.apache.commons.math3.analysis.ParametricUnivariateFunction; 026 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure; 027 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction; 028 import org.apache.commons.math3.exception.NullArgumentException; 029 import org.apache.commons.math3.exception.DimensionMismatchException; 030 import org.apache.commons.math3.util.FastMath; 031 032 /** 033 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function"> 034 * Sigmoid</a> function. 035 * It is the inverse of the {@link Logit logit} function. 036 * A more flexible version, the generalised logistic, is implemented 037 * by the {@link Logistic} class. 038 * 039 * @since 3.0 040 * @version $Id: Sigmoid.java 1391927 2012-09-30 00:03:30Z erans $ 041 */ 042 public class Sigmoid implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction { 043 /** Lower asymptote. */ 044 private final double lo; 045 /** Higher asymptote. */ 046 private final double hi; 047 048 /** 049 * Usual sigmoid function, where the lower asymptote is 0 and the higher 050 * asymptote is 1. 051 */ 052 public Sigmoid() { 053 this(0, 1); 054 } 055 056 /** 057 * Sigmoid function. 058 * 059 * @param lo Lower asymptote. 060 * @param hi Higher asymptote. 061 */ 062 public Sigmoid(double lo, 063 double hi) { 064 this.lo = lo; 065 this.hi = hi; 066 } 067 068 /** {@inheritDoc} 069 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)} 070 */ 071 @Deprecated 072 public UnivariateFunction derivative() { 073 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative(); 074 } 075 076 /** {@inheritDoc} */ 077 public double value(double x) { 078 return value(x, lo, hi); 079 } 080 081 /** 082 * Parametric function where the input array contains the parameters of 083 * the logit function, ordered as follows: 084 * <ul> 085 * <li>Lower asymptote</li> 086 * <li>Higher asymptote</li> 087 * </ul> 088 */ 089 public static class Parametric implements ParametricUnivariateFunction { 090 /** 091 * Computes the value of the sigmoid at {@code x}. 092 * 093 * @param x Value for which the function must be computed. 094 * @param param Values of lower asymptote and higher asymptote. 095 * @return the value of the function. 096 * @throws NullArgumentException if {@code param} is {@code null}. 097 * @throws DimensionMismatchException if the size of {@code param} is 098 * not 2. 099 */ 100 public double value(double x, double ... param) 101 throws NullArgumentException, 102 DimensionMismatchException { 103 validateParameters(param); 104 return Sigmoid.value(x, param[0], param[1]); 105 } 106 107 /** 108 * Computes the value of the gradient at {@code x}. 109 * The components of the gradient vector are the partial 110 * derivatives of the function with respect to each of the 111 * <em>parameters</em> (lower asymptote and higher asymptote). 112 * 113 * @param x Value at which the gradient must be computed. 114 * @param param Values for lower asymptote and higher asymptote. 115 * @return the gradient vector at {@code x}. 116 * @throws NullArgumentException if {@code param} is {@code null}. 117 * @throws DimensionMismatchException if the size of {@code param} is 118 * not 2. 119 */ 120 public double[] gradient(double x, double ... param) 121 throws NullArgumentException, 122 DimensionMismatchException { 123 validateParameters(param); 124 125 final double invExp1 = 1 / (1 + FastMath.exp(-x)); 126 127 return new double[] { 1 - invExp1, invExp1 }; 128 } 129 130 /** 131 * Validates parameters to ensure they are appropriate for the evaluation of 132 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 133 * methods. 134 * 135 * @param param Values for lower and higher asymptotes. 136 * @throws NullArgumentException if {@code param} is {@code null}. 137 * @throws DimensionMismatchException if the size of {@code param} is 138 * not 2. 139 */ 140 private void validateParameters(double[] param) 141 throws NullArgumentException, 142 DimensionMismatchException { 143 if (param == null) { 144 throw new NullArgumentException(); 145 } 146 if (param.length != 2) { 147 throw new DimensionMismatchException(param.length, 2); 148 } 149 } 150 } 151 152 /** 153 * @param x Value at which to compute the sigmoid. 154 * @param lo Lower asymptote. 155 * @param hi Higher asymptote. 156 * @return the value of the sigmoid function at {@code x}. 157 */ 158 private static double value(double x, 159 double lo, 160 double hi) { 161 return lo + (hi - lo) / (1 + FastMath.exp(-x)); 162 } 163 164 /** {@inheritDoc} 165 * @since 3.1 166 */ 167 public DerivativeStructure value(final DerivativeStructure t) { 168 169 double[] f = new double[t.getOrder() + 1]; 170 final double exp = FastMath.exp(-t.getValue()); 171 if (Double.isInfinite(exp)) { 172 173 // special handling near lower boundary, to avoid NaN 174 f[0] = lo; 175 Arrays.fill(f, 1, f.length, 0.0); 176 177 } else { 178 179 // the nth order derivative of sigmoid has the form: 180 // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1) 181 // where P_n(t) is a degree n polynomial with normalized higher term 182 // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t... 183 // the general recurrence relation for P_n is: 184 // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t) 185 final double[] p = new double[f.length]; 186 187 final double inv = 1 / (1 + exp); 188 double coeff = hi - lo; 189 for (int n = 0; n < f.length; ++n) { 190 191 // update and evaluate polynomial P_n(t) 192 double v = 0; 193 p[n] = 1; 194 for (int k = n; k >= 0; --k) { 195 v = v * exp + p[k]; 196 if (k > 1) { 197 p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1]; 198 } else { 199 p[0] = 0; 200 } 201 } 202 203 coeff *= inv; 204 f[n] = coeff * v; 205 206 } 207 208 // fix function value 209 f[0] += lo; 210 211 } 212 213 return t.compose(f); 214 215 } 216 217 }