8032027: Add BigInteger square root methods
authorbpb
Thu, 10 Dec 2015 17:47:26 -0800
changeset 34538 b0705127fbba
parent 34537 ca5ca0e04c96
child 34539 4f2243ba7257
8032027: Add BigInteger square root methods Summary: Add sqrt() and sqrtAndReminder() using Newton iteration Reviewed-by: darcy, lowasser
jdk/src/java.base/share/classes/java/math/BigInteger.java
jdk/src/java.base/share/classes/java/math/MutableBigInteger.java
jdk/test/java/math/BigInteger/BigIntegerTest.java
--- a/jdk/src/java.base/share/classes/java/math/BigInteger.java	Thu Dec 10 15:57:27 2015 -0800
+++ b/jdk/src/java.base/share/classes/java/math/BigInteger.java	Thu Dec 10 17:47:26 2015 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -2410,6 +2410,53 @@
     }
 
     /**
+     * Returns the integer square root of this BigInteger.  The integer square
+     * root of the corresponding mathematical integer {@code n} is the largest
+     * mathematical integer {@code s} such that {@code s*s <= n}.  It is equal
+     * to the value of {@code floor(sqrt(n))}, where {@code sqrt(n)} denotes the
+     * real square root of {@code n} treated as a real.  Note that the integer
+     * square root will be less than the real square root if the latter is not
+     * representable as an integral value.
+     *
+     * @return the integer square root of {@code this}
+     * @throws ArithmeticException if {@code this} is negative.  (The square
+     *         root of a negative integer {@code val} is
+     *         {@code (i * sqrt(-val))} where <i>i</i> is the
+     *         <i>imaginary unit</i> and is equal to
+     *         {@code sqrt(-1)}.)
+     * @since  1.9
+     */
+    public BigInteger sqrt() {
+        if (this.signum < 0) {
+            throw new ArithmeticException("Negative BigInteger");
+        }
+
+        return new MutableBigInteger(this.mag).sqrt().toBigInteger();
+    }
+
+    /**
+     * Returns an array of two BigIntegers containing the integer square root
+     * {@code s} of {@code this} and its remainder {@code this - s*s},
+     * respectively.
+     *
+     * @return an array of two BigIntegers with the integer square root at
+     *         offset 0 and the remainder at offset 1
+     * @throws ArithmeticException if {@code this} is negative.  (The square
+     *         root of a negative integer {@code val} is
+     *         {@code (i * sqrt(-val))} where <i>i</i> is the
+     *         <i>imaginary unit</i> and is equal to
+     *         {@code sqrt(-1)}.)
+     * @see #sqrt()
+     * @since  1.9
+     */
+    public BigInteger[] sqrtAndRemainder() {
+        BigInteger s = sqrt();
+        BigInteger r = this.subtract(s.square());
+        assert r.compareTo(BigInteger.ZERO) >= 0;
+        return new BigInteger[] {s, r};
+    }
+
+    /**
      * Returns a BigInteger whose value is the greatest common divisor of
      * {@code abs(this)} and {@code abs(val)}.  Returns 0 if
      * {@code this == 0 && val == 0}.
--- a/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java	Thu Dec 10 15:57:27 2015 -0800
+++ b/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java	Thu Dec 10 17:47:26 2015 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1999, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1999, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -1867,6 +1867,96 @@
     }
 
     /**
+     * Calculate the integer square root {@code floor(sqrt(this))} where
+     * {@code sqrt(.)} denotes the mathematical square root. The contents of
+     * {@code this} are <b>not</b> changed. The value of {@code this} is assumed
+     * to be non-negative.
+     *
+     * @implNote The implementation is based on the material in Henry S. Warren,
+     * Jr., <i>Hacker's Delight (2nd ed.)</i> (Addison Wesley, 2013), 279-282.
+     *
+     * @throws ArithmeticException if the value returned by {@code bitLength()}
+     * overflows the range of {@code int}.
+     * @return the integer square root of {@code this}
+     * @since 1.9
+     */
+    MutableBigInteger sqrt() {
+        // Special cases.
+        if (this.isZero()) {
+            return new MutableBigInteger(0);
+        } else if (this.value.length == 1
+                && (this.value[0] & LONG_MASK) < 4) { // result is unity
+            return ONE;
+        }
+
+        if (bitLength() <= 63) {
+            // Initial estimate is the square root of the positive long value.
+            long v = new BigInteger(this.value, 1).longValueExact();
+            long xk = (long)Math.floor(Math.sqrt(v));
+
+            // Refine the estimate.
+            do {
+                long xk1 = (xk + v/xk)/2;
+
+                // Terminate when non-decreasing.
+                if (xk1 >= xk) {
+                    return new MutableBigInteger(new int[] {
+                        (int)(xk >>> 32), (int)(xk & LONG_MASK)
+                    });
+                }
+
+                xk = xk1;
+            } while (true);
+        } else {
+            // Set up the initial estimate of the iteration.
+
+            // Obtain the bitLength > 63.
+            int bitLength = (int) this.bitLength();
+            if (bitLength != this.bitLength()) {
+                throw new ArithmeticException("bitLength() integer overflow");
+            }
+
+            // Determine an even valued right shift into positive long range.
+            int shift = bitLength - 63;
+            if (shift % 2 == 1) {
+                shift++;
+            }
+
+            // Shift the value into positive long range.
+            MutableBigInteger xk = new MutableBigInteger(this);
+            xk.rightShift(shift);
+            xk.normalize();
+
+            // Use the square root of the shifted value as an approximation.
+            double d = new BigInteger(xk.value, 1).doubleValue();
+            BigInteger bi = BigInteger.valueOf((long)Math.ceil(Math.sqrt(d)));
+            xk = new MutableBigInteger(bi.mag);
+
+            // Shift the approximate square root back into the original range.
+            xk.leftShift(shift / 2);
+
+            // Refine the estimate.
+            MutableBigInteger xk1 = new MutableBigInteger();
+            do {
+                // xk1 = (xk + n/xk)/2
+                this.divide(xk, xk1, false);
+                xk1.add(xk);
+                xk1.rightShift(1);
+
+                // Terminate when non-decreasing.
+                if (xk1.compare(xk) >= 0) {
+                    return xk;
+                }
+
+                // xk = xk1
+                xk.copyValue(xk1);
+
+                xk1.reset();
+            } while (true);
+        }
+    }
+
+    /**
      * Calculate GCD of this and b. This and b are changed by the computation.
      */
     MutableBigInteger hybridGCD(MutableBigInteger b) {
--- a/jdk/test/java/math/BigInteger/BigIntegerTest.java	Thu Dec 10 15:57:27 2015 -0800
+++ b/jdk/test/java/math/BigInteger/BigIntegerTest.java	Thu Dec 10 17:47:26 2015 -0800
@@ -26,7 +26,7 @@
  * @library /lib/testlibrary/
  * @build jdk.testlibrary.*
  * @run main BigIntegerTest
- * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225 4837946 4026465 8074460 8078672
+ * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225 4837946 4026465 8074460 8078672 8032027
  * @summary tests methods in BigInteger (use -Dseed=X to set PRNG seed)
  * @run main/timeout=400 BigIntegerTest
  * @author madbot
@@ -38,8 +38,15 @@
 import java.io.FileOutputStream;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.util.Random;
+import java.util.function.ToIntFunction;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
 import jdk.testlibrary.RandomFactory;
 
 /**
@@ -243,6 +250,146 @@
         report("square for " + order + " bits", failCount1);
     }
 
+    private static void printErr(String msg) {
+        System.err.println(msg);
+    }
+
+    private static int checkResult(BigInteger expected, BigInteger actual,
+        String failureMessage) {
+        if (expected.compareTo(actual) != 0) {
+            printErr(failureMessage + " - expected: " + expected
+                + ", actual: " + actual);
+            return 1;
+        }
+        return 0;
+    }
+
+    private static void squareRootSmall() {
+        int failCount = 0;
+
+        // A negative value should cause an exception.
+        BigInteger n = BigInteger.ONE.negate();
+        BigInteger s;
+        try {
+            s = n.sqrt();
+            // If sqrt() does not throw an exception that is a failure.
+            failCount++;
+            printErr("sqrt() of negative number did not throw an exception");
+        } catch (ArithmeticException expected) {
+            // A negative value should cause an exception and is not a failure.
+        }
+
+        // A zero value should return BigInteger.ZERO.
+        failCount += checkResult(BigInteger.ZERO, BigInteger.ZERO.sqrt(),
+            "sqrt(0) != BigInteger.ZERO");
+
+        // 1 <= value < 4 should return BigInteger.ONE.
+        long[] smalls = new long[] {1, 2, 3};
+        for (long small : smalls) {
+            failCount += checkResult(BigInteger.ONE,
+                BigInteger.valueOf(small).sqrt(), "sqrt("+small+") != 1");
+        }
+
+        report("squareRootSmall", failCount);
+    }
+
+    public static void squareRoot() {
+        squareRootSmall();
+
+        ToIntFunction<BigInteger> f = (n) -> {
+            int failCount = 0;
+
+            // square root of n^2 -> n
+            BigInteger n2 = n.pow(2);
+            failCount += checkResult(n, n2.sqrt(), "sqrt() n^2 -> n");
+
+            // square root of n^2 + 1 -> n
+            BigInteger n2up = n2.add(BigInteger.ONE);
+            failCount += checkResult(n, n2up.sqrt(), "sqrt() n^2 + 1 -> n");
+
+            // square root of (n + 1)^2 - 1 -> n
+            BigInteger up =
+                n.add(BigInteger.ONE).pow(2).subtract(BigInteger.ONE);
+            failCount += checkResult(n, up.sqrt(), "sqrt() (n + 1)^2 - 1 -> n");
+
+            // sqrt(n)^2 <= n
+            BigInteger s = n.sqrt();
+            if (s.multiply(s).compareTo(n) > 0) {
+                failCount++;
+                printErr("sqrt(n)^2 > n for n = " + n);
+            }
+
+            // (sqrt(n) + 1)^2 > n
+            if (s.add(BigInteger.ONE).pow(2).compareTo(n) <= 0) {
+                failCount++;
+                printErr("(sqrt(n) + 1)^2 <= n for n = " + n);
+            }
+
+            return failCount;
+        };
+
+        Stream.Builder<BigInteger> sb = Stream.builder();
+        int maxExponent = Double.MAX_EXPONENT + 1;
+        for (int i = 1; i <= maxExponent; i++) {
+            BigInteger p2 = BigInteger.ONE.shiftLeft(i);
+            sb.add(p2.subtract(BigInteger.ONE));
+            sb.add(p2);
+            sb.add(p2.add(BigInteger.ONE));
+        }
+        sb.add((new BigDecimal(Double.MAX_VALUE)).toBigInteger());
+        sb.add((new BigDecimal(Double.MAX_VALUE)).toBigInteger().add(BigInteger.ONE));
+        report("squareRoot for 2^N and 2^N - 1, 1 <= N <= Double.MAX_EXPONENT",
+            sb.build().collect(Collectors.summingInt(f)));
+
+        IntStream ints = random.ints(SIZE, 4, Integer.MAX_VALUE);
+        report("squareRoot for int", ints.mapToObj(x ->
+            BigInteger.valueOf(x)).collect(Collectors.summingInt(f)));
+
+        LongStream longs = random.longs(SIZE, (long)Integer.MAX_VALUE + 1L,
+            Long.MAX_VALUE);
+        report("squareRoot for long", longs.mapToObj(x ->
+            BigInteger.valueOf(x)).collect(Collectors.summingInt(f)));
+
+        DoubleStream doubles = random.doubles(SIZE,
+            (double) Long.MAX_VALUE + 1.0, Math.sqrt(Double.MAX_VALUE));
+        report("squareRoot for double", doubles.mapToObj(x ->
+            BigDecimal.valueOf(x).toBigInteger()).collect(Collectors.summingInt(f)));
+    }
+
+    public static void squareRootAndRemainder() {
+        ToIntFunction<BigInteger> g = (n) -> {
+            int failCount = 0;
+            BigInteger n2 = n.pow(2);
+
+            // square root of n^2 -> n
+            BigInteger[] actual = n2.sqrtAndRemainder();
+            failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]");
+            failCount += checkResult(BigInteger.ZERO, actual[1],
+                "sqrtAndRemainder()[1]");
+
+            // square root of n^2 + 1 -> n
+            BigInteger n2up = n2.add(BigInteger.ONE);
+            actual = n2up.sqrtAndRemainder();
+            failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]");
+            failCount += checkResult(BigInteger.ONE, actual[1],
+                "sqrtAndRemainder()[1]");
+
+            // square root of (n + 1)^2 - 1 -> n
+            BigInteger up =
+                n.add(BigInteger.ONE).pow(2).subtract(BigInteger.ONE);
+            actual = up.sqrtAndRemainder();
+            failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]");
+            BigInteger r = up.subtract(n2);
+            failCount += checkResult(r, actual[1], "sqrtAndRemainder()[1]");
+
+            return failCount;
+        };
+
+        IntStream bits = random.ints(SIZE, 3, Short.MAX_VALUE);
+        report("sqrtAndRemainder", bits.mapToObj(x ->
+            BigInteger.valueOf(x)).collect(Collectors.summingInt(g)));
+    }
+
     public static void arithmetic(int order) {
         int failCount = 0;
 
@@ -1101,6 +1248,9 @@
         square(ORDER_KARATSUBA_SQUARE);
         square(ORDER_TOOM_COOK_SQUARE);
 
+        squareRoot();
+        squareRootAndRemainder();
+
         bitCount();
         bitLength();
         bitOps(order1);