# HG changeset patch # User bpb # Date 1449798446 28800 # Node ID b0705127fbbab01ad631fde433621b72f239f407 # Parent ca5ca0e04c961b48bf6f7cd638ba6241921d86b2 8032027: Add BigInteger square root methods Summary: Add sqrt() and sqrtAndReminder() using Newton iteration Reviewed-by: darcy, lowasser diff -r ca5ca0e04c96 -r b0705127fbba jdk/src/java.base/share/classes/java/math/BigInteger.java --- a/jdk/src/java.base/share/classes/java/math/BigInteger.java Thu Dec 10 15:57:27 2015 -0800 +++ b/jdk/src/java.base/share/classes/java/math/BigInteger.java Thu Dec 10 17:47:26 2015 -0800 @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2015, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -2410,6 +2410,53 @@ } /** + * Returns the integer square root of this BigInteger. The integer square + * root of the corresponding mathematical integer {@code n} is the largest + * mathematical integer {@code s} such that {@code s*s <= n}. It is equal + * to the value of {@code floor(sqrt(n))}, where {@code sqrt(n)} denotes the + * real square root of {@code n} treated as a real. Note that the integer + * square root will be less than the real square root if the latter is not + * representable as an integral value. + * + * @return the integer square root of {@code this} + * @throws ArithmeticException if {@code this} is negative. (The square + * root of a negative integer {@code val} is + * {@code (i * sqrt(-val))} where i is the + * imaginary unit and is equal to + * {@code sqrt(-1)}.) + * @since 1.9 + */ + public BigInteger sqrt() { + if (this.signum < 0) { + throw new ArithmeticException("Negative BigInteger"); + } + + return new MutableBigInteger(this.mag).sqrt().toBigInteger(); + } + + /** + * Returns an array of two BigIntegers containing the integer square root + * {@code s} of {@code this} and its remainder {@code this - s*s}, + * respectively. + * + * @return an array of two BigIntegers with the integer square root at + * offset 0 and the remainder at offset 1 + * @throws ArithmeticException if {@code this} is negative. (The square + * root of a negative integer {@code val} is + * {@code (i * sqrt(-val))} where i is the + * imaginary unit and is equal to + * {@code sqrt(-1)}.) + * @see #sqrt() + * @since 1.9 + */ + public BigInteger[] sqrtAndRemainder() { + BigInteger s = sqrt(); + BigInteger r = this.subtract(s.square()); + assert r.compareTo(BigInteger.ZERO) >= 0; + return new BigInteger[] {s, r}; + } + + /** * Returns a BigInteger whose value is the greatest common divisor of * {@code abs(this)} and {@code abs(val)}. Returns 0 if * {@code this == 0 && val == 0}. diff -r ca5ca0e04c96 -r b0705127fbba jdk/src/java.base/share/classes/java/math/MutableBigInteger.java --- a/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java Thu Dec 10 15:57:27 2015 -0800 +++ b/jdk/src/java.base/share/classes/java/math/MutableBigInteger.java Thu Dec 10 17:47:26 2015 -0800 @@ -1,5 +1,5 @@ /* - * Copyright (c) 1999, 2013, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1999, 2015, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -1867,6 +1867,96 @@ } /** + * Calculate the integer square root {@code floor(sqrt(this))} where + * {@code sqrt(.)} denotes the mathematical square root. The contents of + * {@code this} are not changed. The value of {@code this} is assumed + * to be non-negative. + * + * @implNote The implementation is based on the material in Henry S. Warren, + * Jr., Hacker's Delight (2nd ed.) (Addison Wesley, 2013), 279-282. + * + * @throws ArithmeticException if the value returned by {@code bitLength()} + * overflows the range of {@code int}. + * @return the integer square root of {@code this} + * @since 1.9 + */ + MutableBigInteger sqrt() { + // Special cases. + if (this.isZero()) { + return new MutableBigInteger(0); + } else if (this.value.length == 1 + && (this.value[0] & LONG_MASK) < 4) { // result is unity + return ONE; + } + + if (bitLength() <= 63) { + // Initial estimate is the square root of the positive long value. + long v = new BigInteger(this.value, 1).longValueExact(); + long xk = (long)Math.floor(Math.sqrt(v)); + + // Refine the estimate. + do { + long xk1 = (xk + v/xk)/2; + + // Terminate when non-decreasing. + if (xk1 >= xk) { + return new MutableBigInteger(new int[] { + (int)(xk >>> 32), (int)(xk & LONG_MASK) + }); + } + + xk = xk1; + } while (true); + } else { + // Set up the initial estimate of the iteration. + + // Obtain the bitLength > 63. + int bitLength = (int) this.bitLength(); + if (bitLength != this.bitLength()) { + throw new ArithmeticException("bitLength() integer overflow"); + } + + // Determine an even valued right shift into positive long range. + int shift = bitLength - 63; + if (shift % 2 == 1) { + shift++; + } + + // Shift the value into positive long range. + MutableBigInteger xk = new MutableBigInteger(this); + xk.rightShift(shift); + xk.normalize(); + + // Use the square root of the shifted value as an approximation. + double d = new BigInteger(xk.value, 1).doubleValue(); + BigInteger bi = BigInteger.valueOf((long)Math.ceil(Math.sqrt(d))); + xk = new MutableBigInteger(bi.mag); + + // Shift the approximate square root back into the original range. + xk.leftShift(shift / 2); + + // Refine the estimate. + MutableBigInteger xk1 = new MutableBigInteger(); + do { + // xk1 = (xk + n/xk)/2 + this.divide(xk, xk1, false); + xk1.add(xk); + xk1.rightShift(1); + + // Terminate when non-decreasing. + if (xk1.compare(xk) >= 0) { + return xk; + } + + // xk = xk1 + xk.copyValue(xk1); + + xk1.reset(); + } while (true); + } + } + + /** * Calculate GCD of this and b. This and b are changed by the computation. */ MutableBigInteger hybridGCD(MutableBigInteger b) { diff -r ca5ca0e04c96 -r b0705127fbba jdk/test/java/math/BigInteger/BigIntegerTest.java --- a/jdk/test/java/math/BigInteger/BigIntegerTest.java Thu Dec 10 15:57:27 2015 -0800 +++ b/jdk/test/java/math/BigInteger/BigIntegerTest.java Thu Dec 10 17:47:26 2015 -0800 @@ -26,7 +26,7 @@ * @library /lib/testlibrary/ * @build jdk.testlibrary.* * @run main BigIntegerTest - * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225 4837946 4026465 8074460 8078672 + * @bug 4181191 4161971 4227146 4194389 4823171 4624738 4812225 4837946 4026465 8074460 8078672 8032027 * @summary tests methods in BigInteger (use -Dseed=X to set PRNG seed) * @run main/timeout=400 BigIntegerTest * @author madbot @@ -38,8 +38,15 @@ import java.io.FileOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.math.BigDecimal; import java.math.BigInteger; import java.util.Random; +import java.util.function.ToIntFunction; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; import jdk.testlibrary.RandomFactory; /** @@ -243,6 +250,146 @@ report("square for " + order + " bits", failCount1); } + private static void printErr(String msg) { + System.err.println(msg); + } + + private static int checkResult(BigInteger expected, BigInteger actual, + String failureMessage) { + if (expected.compareTo(actual) != 0) { + printErr(failureMessage + " - expected: " + expected + + ", actual: " + actual); + return 1; + } + return 0; + } + + private static void squareRootSmall() { + int failCount = 0; + + // A negative value should cause an exception. + BigInteger n = BigInteger.ONE.negate(); + BigInteger s; + try { + s = n.sqrt(); + // If sqrt() does not throw an exception that is a failure. + failCount++; + printErr("sqrt() of negative number did not throw an exception"); + } catch (ArithmeticException expected) { + // A negative value should cause an exception and is not a failure. + } + + // A zero value should return BigInteger.ZERO. + failCount += checkResult(BigInteger.ZERO, BigInteger.ZERO.sqrt(), + "sqrt(0) != BigInteger.ZERO"); + + // 1 <= value < 4 should return BigInteger.ONE. + long[] smalls = new long[] {1, 2, 3}; + for (long small : smalls) { + failCount += checkResult(BigInteger.ONE, + BigInteger.valueOf(small).sqrt(), "sqrt("+small+") != 1"); + } + + report("squareRootSmall", failCount); + } + + public static void squareRoot() { + squareRootSmall(); + + ToIntFunction f = (n) -> { + int failCount = 0; + + // square root of n^2 -> n + BigInteger n2 = n.pow(2); + failCount += checkResult(n, n2.sqrt(), "sqrt() n^2 -> n"); + + // square root of n^2 + 1 -> n + BigInteger n2up = n2.add(BigInteger.ONE); + failCount += checkResult(n, n2up.sqrt(), "sqrt() n^2 + 1 -> n"); + + // square root of (n + 1)^2 - 1 -> n + BigInteger up = + n.add(BigInteger.ONE).pow(2).subtract(BigInteger.ONE); + failCount += checkResult(n, up.sqrt(), "sqrt() (n + 1)^2 - 1 -> n"); + + // sqrt(n)^2 <= n + BigInteger s = n.sqrt(); + if (s.multiply(s).compareTo(n) > 0) { + failCount++; + printErr("sqrt(n)^2 > n for n = " + n); + } + + // (sqrt(n) + 1)^2 > n + if (s.add(BigInteger.ONE).pow(2).compareTo(n) <= 0) { + failCount++; + printErr("(sqrt(n) + 1)^2 <= n for n = " + n); + } + + return failCount; + }; + + Stream.Builder sb = Stream.builder(); + int maxExponent = Double.MAX_EXPONENT + 1; + for (int i = 1; i <= maxExponent; i++) { + BigInteger p2 = BigInteger.ONE.shiftLeft(i); + sb.add(p2.subtract(BigInteger.ONE)); + sb.add(p2); + sb.add(p2.add(BigInteger.ONE)); + } + sb.add((new BigDecimal(Double.MAX_VALUE)).toBigInteger()); + sb.add((new BigDecimal(Double.MAX_VALUE)).toBigInteger().add(BigInteger.ONE)); + report("squareRoot for 2^N and 2^N - 1, 1 <= N <= Double.MAX_EXPONENT", + sb.build().collect(Collectors.summingInt(f))); + + IntStream ints = random.ints(SIZE, 4, Integer.MAX_VALUE); + report("squareRoot for int", ints.mapToObj(x -> + BigInteger.valueOf(x)).collect(Collectors.summingInt(f))); + + LongStream longs = random.longs(SIZE, (long)Integer.MAX_VALUE + 1L, + Long.MAX_VALUE); + report("squareRoot for long", longs.mapToObj(x -> + BigInteger.valueOf(x)).collect(Collectors.summingInt(f))); + + DoubleStream doubles = random.doubles(SIZE, + (double) Long.MAX_VALUE + 1.0, Math.sqrt(Double.MAX_VALUE)); + report("squareRoot for double", doubles.mapToObj(x -> + BigDecimal.valueOf(x).toBigInteger()).collect(Collectors.summingInt(f))); + } + + public static void squareRootAndRemainder() { + ToIntFunction g = (n) -> { + int failCount = 0; + BigInteger n2 = n.pow(2); + + // square root of n^2 -> n + BigInteger[] actual = n2.sqrtAndRemainder(); + failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]"); + failCount += checkResult(BigInteger.ZERO, actual[1], + "sqrtAndRemainder()[1]"); + + // square root of n^2 + 1 -> n + BigInteger n2up = n2.add(BigInteger.ONE); + actual = n2up.sqrtAndRemainder(); + failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]"); + failCount += checkResult(BigInteger.ONE, actual[1], + "sqrtAndRemainder()[1]"); + + // square root of (n + 1)^2 - 1 -> n + BigInteger up = + n.add(BigInteger.ONE).pow(2).subtract(BigInteger.ONE); + actual = up.sqrtAndRemainder(); + failCount += checkResult(n, actual[0], "sqrtAndRemainder()[0]"); + BigInteger r = up.subtract(n2); + failCount += checkResult(r, actual[1], "sqrtAndRemainder()[1]"); + + return failCount; + }; + + IntStream bits = random.ints(SIZE, 3, Short.MAX_VALUE); + report("sqrtAndRemainder", bits.mapToObj(x -> + BigInteger.valueOf(x)).collect(Collectors.summingInt(g))); + } + public static void arithmetic(int order) { int failCount = 0; @@ -1101,6 +1248,9 @@ square(ORDER_KARATSUBA_SQUARE); square(ORDER_TOOM_COOK_SQUARE); + squareRoot(); + squareRootAndRemainder(); + bitCount(); bitLength(); bitOps(order1);