001/*
002 * Copyright (C) 2011 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.math;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.math.MathPreconditions.checkNonNegative;
022import static com.google.common.math.MathPreconditions.checkPositive;
023import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
024import static java.math.RoundingMode.CEILING;
025import static java.math.RoundingMode.FLOOR;
026import static java.math.RoundingMode.HALF_EVEN;
027
028import com.google.common.annotations.GwtCompatible;
029import com.google.common.annotations.GwtIncompatible;
030import com.google.common.annotations.VisibleForTesting;
031
032import java.math.BigDecimal;
033import java.math.BigInteger;
034import java.math.RoundingMode;
035import java.util.ArrayList;
036import java.util.List;
037
038/**
039 * A class for arithmetic on values of type {@code BigInteger}.
040 *
041 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
042 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
043 *
044 * <p>Similar functionality for {@code int} and for {@code long} can be found in
045 * {@link IntMath} and {@link LongMath} respectively.
046 *
047 * @author Louis Wasserman
048 * @since 11.0
049 */
050@GwtCompatible(emulated = true)
051public final class BigIntegerMath {
052  /**
053   * Returns {@code true} if {@code x} represents a power of two.
054   */
055  public static boolean isPowerOfTwo(BigInteger x) {
056    checkNotNull(x);
057    return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
058  }
059
060  /**
061   * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
062   *
063   * @throws IllegalArgumentException if {@code x <= 0}
064   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
065   *         is not a power of two
066   */
067  @SuppressWarnings("fallthrough")
068  public static int log2(BigInteger x, RoundingMode mode) {
069    checkPositive("x", checkNotNull(x));
070    int logFloor = x.bitLength() - 1;
071    switch (mode) {
072      case UNNECESSARY:
073        checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
074      case DOWN:
075      case FLOOR:
076        return logFloor;
077
078      case UP:
079      case CEILING:
080        return isPowerOfTwo(x) ? logFloor : logFloor + 1;
081
082      case HALF_DOWN:
083      case HALF_UP:
084      case HALF_EVEN:
085        if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
086          BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
087              SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
088          if (x.compareTo(halfPower) <= 0) {
089            return logFloor;
090          } else {
091            return logFloor + 1;
092          }
093        }
094        /*
095         * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
096         *
097         * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
098         * logFloor + 1).
099         */
100        BigInteger x2 = x.pow(2);
101        int logX2Floor = x2.bitLength() - 1;
102        return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
103
104      default:
105        throw new AssertionError();
106    }
107  }
108
  /*
   * The maximum number of bits in a square root for which we'll precompute an explicit half power
   * of two. This can be any value, but higher values incur more class load time and linearly
   * increasing memory consumption.
   *
   * log2 only takes its precomputed path for the HALF_* rounding modes when floor(log2(x)) is
   * strictly below this threshold; larger inputs fall back to squaring x.
   */
  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;

  /*
   * The leading hexadecimal digits of sqrt(2) = 1.6a09e667f3bc... written as an integer: one
   * integer digit followed by 64 fractional hex digits, i.e. approximately
   * sqrt(2) * 2^SQRT2_PRECOMPUTE_THRESHOLD. log2 shifts this right by
   * (SQRT2_PRECOMPUTE_THRESHOLD - logFloor) to obtain the half power 2^(logFloor + 0.5) that x is
   * compared against.
   */
  @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
118
119  /**
120   * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
121   *
122   * @throws IllegalArgumentException if {@code x <= 0}
123   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
124   *         is not a power of ten
125   */
126  @GwtIncompatible("TODO")
127  @SuppressWarnings("fallthrough")
128  public static int log10(BigInteger x, RoundingMode mode) {
129    checkPositive("x", x);
130    if (fitsInLong(x)) {
131      return LongMath.log10(x.longValue(), mode);
132    }
133
134    int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
135    BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
136    int approxCmp = approxPow.compareTo(x);
137
138    /*
139     * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
140     * 10^floor(log10(x)).
141     */
142
143    if (approxCmp > 0) {
144      /*
145       * The code is written so that even completely incorrect approximations will still yield the
146       * correct answer eventually, but in practice this branch should almost never be entered,
147       * and even then the loop should not run more than once.
148       */
149      do {
150        approxLog10--;
151        approxPow = approxPow.divide(BigInteger.TEN);
152        approxCmp = approxPow.compareTo(x);
153      } while (approxCmp > 0);
154    } else {
155      BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
156      int nextCmp = nextPow.compareTo(x);
157      while (nextCmp <= 0) {
158        approxLog10++;
159        approxPow = nextPow;
160        approxCmp = nextCmp;
161        nextPow = BigInteger.TEN.multiply(approxPow);
162        nextCmp = nextPow.compareTo(x);
163      }
164    }
165
166    int floorLog = approxLog10;
167    BigInteger floorPow = approxPow;
168    int floorCmp = approxCmp;
169
170    switch (mode) {
171      case UNNECESSARY:
172        checkRoundingUnnecessary(floorCmp == 0);
173        // fall through
174      case FLOOR:
175      case DOWN:
176        return floorLog;
177
178      case CEILING:
179      case UP:
180        return floorPow.equals(x) ? floorLog : floorLog + 1;
181
182      case HALF_DOWN:
183      case HALF_UP:
184      case HALF_EVEN:
185        // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
186        BigInteger x2 = x.pow(2);
187        BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
188        return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
189      default:
190        throw new AssertionError();
191    }
192  }
193
  // Natural logarithms of 10 and of 2. log10 multiplies a base-2 logarithm by LN_2 / LN_10 to
  // obtain its initial base-10 estimate.
  private static final double LN_10 = Math.log(10);
  private static final double LN_2 = Math.log(2);
196
197  /**
198   * Returns the square root of {@code x}, rounded with the specified rounding mode.
199   *
200   * @throws IllegalArgumentException if {@code x < 0}
201   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
202   *         {@code sqrt(x)} is not an integer
203   */
204  @GwtIncompatible("TODO")
205  @SuppressWarnings("fallthrough")
206  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
207    checkNonNegative("x", x);
208    if (fitsInLong(x)) {
209      return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
210    }
211    BigInteger sqrtFloor = sqrtFloor(x);
212    switch (mode) {
213      case UNNECESSARY:
214        checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
215      case FLOOR:
216      case DOWN:
217        return sqrtFloor;
218      case CEILING:
219      case UP:
220        return sqrtFloor.pow(2).equals(x) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
221      case HALF_DOWN:
222      case HALF_UP:
223      case HALF_EVEN:
224        BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
225        /*
226         * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
227         * x and halfSquare are integers, this is equivalent to testing whether or not x <=
228         * halfSquare.
229         */
230        return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
231      default:
232        throw new AssertionError();
233    }
234  }
235
236  @GwtIncompatible("TODO")
237  private static BigInteger sqrtFloor(BigInteger x) {
238    /*
239     * Adapted from Hacker's Delight, Figure 11-1.
240     *
241     * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
242     * then we can get a double approximation of the square root. Then, we iteratively improve this
243     * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
244     * This iteration has the following two properties:
245     *
246     * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
247     * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
248     * and the arithmetic mean is always higher than the geometric mean.
249     *
250     * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
251     * with each iteration, so this algorithm takes O(log(digits)) iterations.
252     *
253     * We start out with a double-precision approximation, which may be higher or lower than the
254     * true value. Therefore, we perform at least one Newton iteration to get a guess that's
255     * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
256     */
257    BigInteger sqrt0;
258    int log2 = log2(x, FLOOR);
259    if(log2 < Double.MAX_EXPONENT) {
260      sqrt0 = sqrtApproxWithDoubles(x);
261    } else {
262      int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
263      /*
264       * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
265       * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
266       */
267      sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
268    }
269    BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
270    if (sqrt0.equals(sqrt1)) {
271      return sqrt0;
272    }
273    do {
274      sqrt0 = sqrt1;
275      sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
276    } while (sqrt1.compareTo(sqrt0) < 0);
277    return sqrt0;
278  }
279
280  @GwtIncompatible("TODO")
281  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
282    return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
283  }
284
285  /**
286   * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
287   * {@code RoundingMode}.
288   *
289   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
290   *         is not an integer multiple of {@code b}
291   */
292  @GwtIncompatible("TODO")
293  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode){
294    BigDecimal pDec = new BigDecimal(p);
295    BigDecimal qDec = new BigDecimal(q);
296    return pDec.divide(qDec, 0, mode).toBigIntegerExact();
297  }
298
299  /**
300   * Returns {@code n!}, that is, the product of the first {@code n} positive
301   * integers, or {@code 1} if {@code n == 0}.
302   *
303   * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
304   *
305   * <p>This uses an efficient binary recursive algorithm to compute the factorial
306   * with balanced multiplies.  It also removes all the 2s from the intermediate
307   * products (shifting them back in at the end).
308   *
309   * @throws IllegalArgumentException if {@code n < 0}
310   */
311  public static BigInteger factorial(int n) {
312    checkNonNegative("n", n);
313
314    // If the factorial is small enough, just use LongMath to do it.
315    if (n < LongMath.factorials.length) {
316      return BigInteger.valueOf(LongMath.factorials[n]);
317    }
318
319    // Pre-allocate space for our list of intermediate BigIntegers.
320    int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
321    ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
322
323    // Start from the pre-computed maximum long factorial.
324    int startingNumber = LongMath.factorials.length;
325    long product = LongMath.factorials[startingNumber - 1];
326    // Strip off 2s from this value.
327    int shift = Long.numberOfTrailingZeros(product);
328    product >>= shift;
329
330    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
331    int productBits = LongMath.log2(product, FLOOR) + 1;
332    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
333    // Check for the next power of two boundary, to save us a CLZ operation.
334    int nextPowerOfTwo = 1 << (bits - 1);
335
336    // Iteratively multiply the longs as big as they can go.
337    for (long num = startingNumber; num <= n; num++) {
338      // Check to see if the floor(log2(num)) + 1 has changed.
339      if ((num & nextPowerOfTwo) != 0) {
340        nextPowerOfTwo <<= 1;
341        bits++;
342      }
343      // Get rid of the 2s in num.
344      int tz = Long.numberOfTrailingZeros(num);
345      long normalizedNum = num >> tz;
346      shift += tz;
347      // Adjust floor(log2(num)) + 1.
348      int normalizedBits = bits - tz;
349      // If it won't fit in a long, then we store off the intermediate product.
350      if (normalizedBits + productBits >= Long.SIZE) {
351        bignums.add(BigInteger.valueOf(product));
352        product = 1;
353        productBits = 0;
354      }
355      product *= normalizedNum;
356      productBits = LongMath.log2(product, FLOOR) + 1;
357    }
358    // Check for leftovers.
359    if (product > 1) {
360      bignums.add(BigInteger.valueOf(product));
361    }
362    // Efficiently multiply all the intermediate products together.
363    return listProduct(bignums).shiftLeft(shift);
364  }
365
366  static BigInteger listProduct(List<BigInteger> nums) {
367    return listProduct(nums, 0, nums.size());
368  }
369
370  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
371    switch (end - start) {
372      case 0:
373        return BigInteger.ONE;
374      case 1:
375        return nums.get(start);
376      case 2:
377        return nums.get(start).multiply(nums.get(start + 1));
378      case 3:
379        return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
380      default:
381        // Otherwise, split the list in half and recursively do this.
382        int m = (end + start) >>> 1;
383        return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
384    }
385  }
386
387 /**
388   * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
389   * {@code k}, that is, {@code n! / (k! (n - k)!)}.
390   *
391   * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
392   *
393   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
394   */
395  public static BigInteger binomial(int n, int k) {
396    checkNonNegative("n", n);
397    checkNonNegative("k", k);
398    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
399    if (k > (n >> 1)) {
400      k = n - k;
401    }
402    if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
403      return BigInteger.valueOf(LongMath.binomial(n, k));
404    }
405
406    BigInteger accum = BigInteger.ONE;
407
408    long numeratorAccum = n;
409    long denominatorAccum = 1;
410
411    int bits = LongMath.log2(n, RoundingMode.CEILING);
412
413    int numeratorBits = bits;
414
415    for (int i = 1; i < k; i++) {
416      int p = n - i;
417      int q = i + 1;
418
419      // log2(p) >= bits - 1, because p >= n/2
420
421      if (numeratorBits + bits >= Long.SIZE - 1) {
422        // The numerator is as big as it can get without risking overflow.
423        // Multiply numeratorAccum / denominatorAccum into accum.
424        accum = accum
425            .multiply(BigInteger.valueOf(numeratorAccum))
426            .divide(BigInteger.valueOf(denominatorAccum));
427        numeratorAccum = p;
428        denominatorAccum = q;
429        numeratorBits = bits;
430      } else {
431        // We can definitely multiply into the long accumulators without overflowing them.
432        numeratorAccum *= p;
433        denominatorAccum *= q;
434        numeratorBits += bits;
435      }
436    }
437    return accum
438        .multiply(BigInteger.valueOf(numeratorAccum))
439        .divide(BigInteger.valueOf(denominatorAccum));
440  }
441
442  // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
443  @GwtIncompatible("TODO")
444  static boolean fitsInLong(BigInteger x) {
445    return x.bitLength() <= Long.SIZE - 1;
446  }
447
  // Private constructor: this final class exposes only static members and is not instantiated.
  private BigIntegerMath() {}
449}