8066842: java.math.BigDecimal.divide(BigDecimal, RoundingMode) produces incorrect result
authorbpb
Wed, 11 Feb 2015 17:20:39 -0800
changeset 28869 9725237b107d
parent 28868 445be2b2eae8
child 28870 cc64f71a38ec
8066842: java.math.BigDecimal.divide(BigDecimal, RoundingMode) produces incorrect result Summary: Replace divWord() with non-truncating alternatives Reviewed-by: psandoz, darcy
jdk/src/java.base/share/classes/java/math/BigDecimal.java
jdk/test/java/math/BigDecimal/DivideTests.java
--- a/jdk/src/java.base/share/classes/java/math/BigDecimal.java	Tue Feb 10 10:44:38 2015 +0100
+++ b/jdk/src/java.base/share/classes/java/math/BigDecimal.java	Wed Feb 11 17:20:39 2015 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -4814,41 +4814,61 @@
         if (dividendHi >= divisor) {
             return null;
         }
+
         final int shift = Long.numberOfLeadingZeros(divisor);
         divisor <<= shift;
 
         final long v1 = divisor >>> 32;
         final long v0 = divisor & LONG_MASK;
 
-        long q1, q0;
-        long r_tmp;
-
         long tmp = dividendLo << shift;
         long u1 = tmp >>> 32;
         long u0 = tmp & LONG_MASK;
 
         tmp = (dividendHi << shift) | (dividendLo >>> 64 - shift);
         long u2 = tmp & LONG_MASK;
-        tmp = divWord(tmp,v1);
-        q1 = tmp & LONG_MASK;
-        r_tmp = tmp >>> 32;
+        long q1, r_tmp;
+        if (v1 == 1) {
+            q1 = tmp;
+            r_tmp = 0;
+        } else if (tmp >= 0) {
+            q1 = tmp / v1;
+            r_tmp = tmp - q1 * v1;
+        } else {
+            long[] rq = divRemNegativeLong(tmp, v1);
+            q1 = rq[1];
+            r_tmp = rq[0];
+        }
+
         while(q1 >= DIV_NUM_BASE || unsignedLongCompare(q1*v0, make64(r_tmp, u1))) {
             q1--;
             r_tmp += v1;
             if (r_tmp >= DIV_NUM_BASE)
                 break;
         }
+
         tmp = mulsub(u2,u1,v1,v0,q1);
         u1 = tmp & LONG_MASK;
-        tmp = divWord(tmp,v1);
-        q0 = tmp & LONG_MASK;
-        r_tmp = tmp >>> 32;
+        long q0;
+        if (v1 == 1) {
+            q0 = tmp;
+            r_tmp = 0;
+        } else if (tmp >= 0) {
+            q0 = tmp / v1;
+            r_tmp = tmp - q0 * v1;
+        } else {
+            long[] rq = divRemNegativeLong(tmp, v1);
+            q0 = rq[1];
+            r_tmp = rq[0];
+        }
+
         while(q0 >= DIV_NUM_BASE || unsignedLongCompare(q0*v0,make64(r_tmp,u0))) {
             q0--;
             r_tmp += v1;
             if (r_tmp >= DIV_NUM_BASE)
                 break;
         }
+
         if((int)q1 < 0) {
             // result (which is positive and unsigned here)
             // can't fit into long due to sign bit is used for value
@@ -4871,10 +4891,13 @@
                 }
             }
         }
+
         long q = make64(q1,q0);
         q*=sign;
+
         if (roundingMode == ROUND_DOWN && scale == preferredScale)
             return valueOf(q, scale);
+
         long r = mulsub(u1, u0, v1, v0, q0) >>> shift;
         if (r != 0) {
             boolean increment = needIncrement(divisor >>> shift, roundingMode, sign, q, r);
@@ -4917,28 +4940,35 @@
         }
     }
 
-    private static long divWord(long n, long dLong) {
-        long r;
-        long q;
-        if (dLong == 1) {
-            q = (int)n;
-            return (q & LONG_MASK);
-        }
+    /**
+     * Calculate the quotient and remainder of dividing a negative long by
+     * another long.
+     *
+     * @param n the numerator; must be negative
+     * @param d the denominator; must not be unity
+     * @return a two-element {@long} array with the remainder and quotient in
+     *         the initial and final elements, respectively
+     */
+    private static long[] divRemNegativeLong(long n, long d) {
+        assert n < 0 : "Non-negative numerator " + n;
+        assert d != 1 : "Unity denominator";
+
         // Approximate the quotient and remainder
-        q = (n >>> 1) / (dLong >>> 1);
-        r = n - q*dLong;
+        long q = (n >>> 1) / (d >>> 1);
+        long r = n - q * d;
 
         // Correct the approximation
         while (r < 0) {
-            r += dLong;
+            r += d;
             q--;
         }
-        while (r >= dLong) {
-            r -= dLong;
+        while (r >= d) {
+            r -= d;
             q++;
         }
-        // n - q*dlong == r && 0 <= r <dLong, hence we're done.
-        return (r << 32) | (q & LONG_MASK);
+
+        // n - q*d == r && 0 <= r < d, hence we're done.
+        return new long[] {r, q};
     }
 
     private static long make64(long hi, long lo) {
--- a/jdk/test/java/math/BigDecimal/DivideTests.java	Tue Feb 10 10:44:38 2015 +0100
+++ b/jdk/test/java/math/BigDecimal/DivideTests.java	Wed Feb 11 17:20:39 2015 -0800
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2003, 2005, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2003, 2015, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -23,7 +23,7 @@
 
 /*
  * @test
- * @bug 4851776 4907265 6177836 6876282
+ * @bug 4851776 4907265 6177836 6876282 8066842
  * @summary Some tests for the divide methods.
  * @author Joseph D. Darcy
  */
@@ -358,6 +358,57 @@
         return failures;
     }
 
+    private static int divideByOneTests() {
+        int failures = 0;
+
+        //problematic divisor: one with scale 17
+        BigDecimal one = BigDecimal.ONE.setScale(17);
+        RoundingMode rounding = RoundingMode.UNNECESSARY;
+
+        long[][] unscaledAndScale = new long[][] {
+            { Long.MAX_VALUE,  17},
+            {-Long.MAX_VALUE,  17},
+            { Long.MAX_VALUE,   0},
+            {-Long.MAX_VALUE,   0},
+            { Long.MAX_VALUE, 100},
+            {-Long.MAX_VALUE, 100}
+        };
+
+        for (long[] uas : unscaledAndScale) {
+            long unscaled = uas[0];
+            int scale = (int)uas[1];
+
+            BigDecimal noRound = null;
+            try {
+                noRound = BigDecimal.valueOf(unscaled, scale).
+                    divide(one, RoundingMode.UNNECESSARY);
+            } catch (ArithmeticException e) {
+                failures++;
+                System.err.println("ArithmeticException for value " + unscaled
+                    + " and scale " + scale + " without rounding");
+            }
+
+            BigDecimal roundDown = null;
+            try {
+                roundDown = BigDecimal.valueOf(unscaled, scale).
+                        divide(one, RoundingMode.DOWN);
+            } catch (ArithmeticException e) {
+                failures++;
+                System.err.println("ArithmeticException for value " + unscaled
+                    + " and scale " + scale + " with rounding down");
+            }
+
+            if (noRound != null && roundDown != null
+                && noRound.compareTo(roundDown) != 0) {
+                failures++;
+                System.err.println("Equality failure for value " + unscaled
+                        + " and scale " + scale);
+            }
+        }
+
+        return failures;
+    }
+
     public static void main(String argv[]) {
         int failures = 0;
 
@@ -366,10 +417,11 @@
         failures += properScaleTests();
         failures += trailingZeroTests();
         failures += scaledRoundedDivideTests();
+        failures += divideByOneTests();
 
         if (failures > 0) {
             throw new RuntimeException("Incurred " + failures +
-                                       " failures while testing exact divide.");
+                                       " failures while testing division.");
         }
     }
 }