8225189: Multiple JNI calls within critical region in ZIP Library
authorlancea
Tue, 11 Jun 2019 13:04:36 -0400
changeset 55331 dbf5cda9843d
parent 55330 1fef7d9309a9
child 55332 f492567244ab
8225189: Multiple JNI calls within critical region in ZIP Library Reviewed-by: alanb
src/java.base/share/native/libzip/Deflater.c
src/java.base/share/native/libzip/Inflater.c
--- a/src/java.base/share/native/libzip/Deflater.c	Tue Jun 11 15:46:26 2019 +0100
+++ b/src/java.base/share/native/libzip/Deflater.c	Tue Jun 11 13:04:36 2019 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1997, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -76,9 +76,8 @@
     }
 }
 
-static void doSetDictionary(JNIEnv *env, jlong addr, jbyte *buf, jint len)
+static void checkSetDictionaryResult(JNIEnv *env, jlong addr, jint res)
 {
-    int res = deflateSetDictionary(jlong_to_ptr(addr), (Bytef *) buf, len);
     switch (res) {
     case Z_OK:
         break;
@@ -95,30 +94,33 @@
 Java_java_util_zip_Deflater_setDictionary(JNIEnv *env, jclass cls, jlong addr,
                                           jbyteArray b, jint off, jint len)
 {
-    jbyte *buf = (*env)->GetPrimitiveArrayCritical(env, b, 0);
+    int res;
+    Bytef *buf = (*env)->GetPrimitiveArrayCritical(env, b, 0);
     if (buf == NULL) /* out of memory */
         return;
-    doSetDictionary(env, addr, buf + off, len);
+    res = deflateSetDictionary(jlong_to_ptr(addr), buf, len);
     (*env)->ReleasePrimitiveArrayCritical(env, b, buf, 0);
+    checkSetDictionaryResult(env, addr, res);
 }
 
 JNIEXPORT void JNICALL
 Java_java_util_zip_Deflater_setDictionaryBuffer(JNIEnv *env, jclass cls, jlong addr,
                                           jlong bufferAddr, jint len)
 {
-    jbyte *buf = jlong_to_ptr(bufferAddr);
-    doSetDictionary(env, addr, buf, len);
+    int res;
+    Bytef *buf = jlong_to_ptr(bufferAddr);
+    res = deflateSetDictionary(jlong_to_ptr(addr), buf, len);
+    checkSetDictionaryResult(env, addr, res);
 }
 
-static jlong doDeflate(JNIEnv *env, jobject this, jlong addr,
+static jint doDeflate(JNIEnv *env, jlong addr,
                        jbyte *input, jint inputLen,
                        jbyte *output, jint outputLen,
                        jint flush, jint params)
 {
     z_stream *strm = jlong_to_ptr(addr);
-    jint inputUsed = 0, outputUsed = 0;
-    int finished = 0;
     int setParams = params & 1;
+    int res;
 
     strm->next_in  = (Bytef *) input;
     strm->next_out = (Bytef *) output;
@@ -128,7 +130,24 @@
     if (setParams) {
         int strategy = (params >> 1) & 3;
         int level = params >> 3;
-        int res = deflateParams(strm, level, strategy);
+        res = deflateParams(strm, level, strategy);
+    } else {
+        res = deflate(strm, flush);
+    }
+    return res;
+}
+
+static jlong checkDeflateStatus(JNIEnv *env, jlong addr,
+                        jint inputLen,
+                        jint outputLen,
+                        jint params, int res)
+{
+    z_stream *strm = jlong_to_ptr(addr);
+    jint inputUsed = 0, outputUsed = 0;
+    int finished = 0;
+    int setParams = params & 1;
+
+    if (setParams) {
         switch (res) {
         case Z_OK:
             setParams = 0;
@@ -142,7 +161,6 @@
             return 0;
         }
     } else {
-        int res = deflate(strm, flush);
         switch (res) {
         case Z_STREAM_END:
             finished = 1;
@@ -169,6 +187,8 @@
     jbyte *input = (*env)->GetPrimitiveArrayCritical(env, inputArray, 0);
     jbyte *output;
     jlong retVal;
+    jint res;
+
     if (input == NULL) {
         if (inputLen != 0 && (*env)->ExceptionOccurred(env) == NULL)
             JNU_ThrowOutOfMemoryError(env, 0);
@@ -182,14 +202,13 @@
         return 0L;
     }
 
-    retVal = doDeflate(env, this, addr,
-            input + inputOff, inputLen,
-            output + outputOff, outputLen,
-            flush, params);
+     res = doDeflate(env, addr, input + inputOff, inputLen,output + outputOff,
+                     outputLen, flush, params);
 
     (*env)->ReleasePrimitiveArrayCritical(env, outputArray, output, 0);
     (*env)->ReleasePrimitiveArrayCritical(env, inputArray, input, 0);
 
+    retVal = checkDeflateStatus(env, addr, inputLen, outputLen, params, res);
     return retVal;
 }
 
@@ -203,6 +222,7 @@
     jbyte *input = (*env)->GetPrimitiveArrayCritical(env, inputArray, 0);
     jbyte *output;
     jlong retVal;
+    jint res;
     if (input == NULL) {
         if (inputLen != 0 && (*env)->ExceptionOccurred(env) == NULL)
             JNU_ThrowOutOfMemoryError(env, 0);
@@ -210,13 +230,12 @@
     }
     output = jlong_to_ptr(outputBuffer);
 
-    retVal = doDeflate(env, this, addr,
-            input + inputOff, inputLen,
-            output, outputLen,
-            flush, params);
+    res = doDeflate(env, addr, input + inputOff, inputLen, output, outputLen,
+                    flush, params);
 
     (*env)->ReleasePrimitiveArrayCritical(env, inputArray, input, 0);
 
+    retVal = checkDeflateStatus(env, addr, inputLen, outputLen, params, res);
     return retVal;
 }
 
@@ -229,19 +248,18 @@
     jbyte *input = jlong_to_ptr(inputBuffer);
     jbyte *output = (*env)->GetPrimitiveArrayCritical(env, outputArray, 0);
     jlong retVal;
+    jint res;
     if (output == NULL) {
         if (outputLen != 0 && (*env)->ExceptionOccurred(env) == NULL)
             JNU_ThrowOutOfMemoryError(env, 0);
         return 0L;
     }
 
-    retVal = doDeflate(env, this, addr,
-            input, inputLen,
-            output + outputOff, outputLen,
-            flush, params);
-
+    res = doDeflate(env, addr, input, inputLen, output + outputOff, outputLen,
+                    flush, params);
     (*env)->ReleasePrimitiveArrayCritical(env, outputArray, input, 0);
 
+    retVal = checkDeflateStatus(env, addr, inputLen, outputLen, params, res);
     return retVal;
 }
 
@@ -253,11 +271,12 @@
 {
     jbyte *input = jlong_to_ptr(inputBuffer);
     jbyte *output = jlong_to_ptr(outputBuffer);
+    jlong retVal;
+    jint res;
 
-    return doDeflate(env, this, addr,
-            input, inputLen,
-            output, outputLen,
-            flush, params);
+    res = doDeflate(env, addr, input, inputLen, output, outputLen, flush, params);
+    retVal = checkDeflateStatus(env, addr, inputLen, outputLen, params, res);
+    return retVal;
 }
 
 JNIEXPORT jint JNICALL
--- a/src/java.base/share/native/libzip/Inflater.c	Tue Jun 11 15:46:26 2019 +0100
+++ b/src/java.base/share/native/libzip/Inflater.c	Tue Jun 11 13:04:36 2019 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1997, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -87,9 +87,8 @@
     }
 }
 
-static void doSetDictionary(JNIEnv *env, jlong addr, jbyte *buf, jint len)
+static void checkSetDictionaryResult(JNIEnv *env, jlong addr, int res)
 {
-    int res = inflateSetDictionary(jlong_to_ptr(addr), (Bytef *) buf, len);
     switch (res) {
     case Z_OK:
         break;
@@ -107,30 +106,31 @@
 Java_java_util_zip_Inflater_setDictionary(JNIEnv *env, jclass cls, jlong addr,
                                           jbyteArray b, jint off, jint len)
 {
-    jbyte *buf = (*env)->GetPrimitiveArrayCritical(env, b, 0);
+    jint res;
+    Bytef *buf = (*env)->GetPrimitiveArrayCritical(env, b, 0);
     if (buf == NULL) /* out of memory */
         return;
-    doSetDictionary(env, addr, buf + off, len);
+    res = inflateSetDictionary(jlong_to_ptr(addr), buf + off, len);
     (*env)->ReleasePrimitiveArrayCritical(env, b, buf, 0);
+    checkSetDictionaryResult(env, addr, res);
 }
 
 JNIEXPORT void JNICALL
 Java_java_util_zip_Inflater_setDictionaryBuffer(JNIEnv *env, jclass cls, jlong addr,
                                           jlong bufferAddr, jint len)
 {
-    jbyte *buf = jlong_to_ptr(bufferAddr);
-    doSetDictionary(env, addr, buf, len);
+    jint res;
+    Bytef *buf = jlong_to_ptr(bufferAddr);
+    res = inflateSetDictionary(jlong_to_ptr(addr), buf, len);
+    checkSetDictionaryResult(env, addr, res);
 }
 
-static jlong doInflate(JNIEnv *env, jobject this, jlong addr,
+static jint doInflate(jlong addr,
                        jbyte *input, jint inputLen,
                        jbyte *output, jint outputLen)
 {
+    jint ret;
     z_stream *strm = jlong_to_ptr(addr);
-    jint inputUsed = 0, outputUsed = 0;
-    int finished = 0;
-    int needDict = 0;
-    int ret;
 
     strm->next_in  = (Bytef *) input;
     strm->next_out = (Bytef *) output;
@@ -138,6 +138,16 @@
     strm->avail_out = outputLen;
 
     ret = inflate(strm, Z_PARTIAL_FLUSH);
+    return ret;
+}
+
+static jlong checkInflateStatus(JNIEnv *env, jobject this, jlong addr,
+                        jint inputLen, jint outputLen, jint ret )
+{
+    z_stream *strm = jlong_to_ptr(addr);
+    jint inputUsed = 0, outputUsed = 0;
+    int finished = 0;
+    int needDict = 0;
 
     switch (ret) {
     case Z_STREAM_END:
@@ -180,6 +190,7 @@
 {
     jbyte *input = (*env)->GetPrimitiveArrayCritical(env, inputArray, 0);
     jbyte *output;
+    jint ret;
     jlong retVal;
 
     if (input == NULL) {
@@ -195,13 +206,13 @@
         return 0L;
     }
 
-    retVal = doInflate(env, this, addr,
-            input + inputOff, inputLen,
-            output + outputOff, outputLen);
+    ret = doInflate(addr, input + inputOff, inputLen, output + outputOff,
+                    outputLen);
 
     (*env)->ReleasePrimitiveArrayCritical(env, outputArray, output, 0);
     (*env)->ReleasePrimitiveArrayCritical(env, inputArray, input, 0);
 
+    retVal = checkInflateStatus(env, this, addr, inputLen, outputLen, ret );
     return retVal;
 }
 
@@ -212,6 +223,7 @@
 {
     jbyte *input = (*env)->GetPrimitiveArrayCritical(env, inputArray, 0);
     jbyte *output;
+    jint ret;
     jlong retVal;
 
     if (input == NULL) {
@@ -221,11 +233,10 @@
     }
     output = jlong_to_ptr(outputBuffer);
 
-    retVal = doInflate(env, this, addr,
-            input + inputOff, inputLen,
-            output, outputLen);
+    ret = doInflate(addr, input + inputOff, inputLen, output, outputLen);
 
     (*env)->ReleasePrimitiveArrayCritical(env, inputArray, input, 0);
+    retVal = checkInflateStatus(env, this, addr, inputLen, outputLen, ret );
 
     return retVal;
 }
@@ -237,6 +248,7 @@
 {
     jbyte *input = jlong_to_ptr(inputBuffer);
     jbyte *output = (*env)->GetPrimitiveArrayCritical(env, outputArray, 0);
+    jint ret;
     jlong retVal;
 
     if (output == NULL) {
@@ -245,11 +257,10 @@
         return 0L;
     }
 
-    retVal = doInflate(env, this, addr,
-            input, inputLen,
-            output + outputOff, outputLen);
+    ret = doInflate(addr, input, inputLen, output  + outputOff, outputLen);
 
     (*env)->ReleasePrimitiveArrayCritical(env, outputArray, output, 0);
+    retVal = checkInflateStatus(env, this, addr, inputLen, outputLen, ret );
 
     return retVal;
 }
@@ -261,10 +272,12 @@
 {
     jbyte *input = jlong_to_ptr(inputBuffer);
     jbyte *output = jlong_to_ptr(outputBuffer);
+    jint ret;
+    jlong retVal;
 
-    return doInflate(env, this, addr,
-            input, inputLen,
-            output, outputLen);
+    ret = doInflate(addr, input, inputLen, output, outputLen);
+    retVal = checkInflateStatus(env, this, addr, inputLen, outputLen, ret);
+    return retVal;
 }
 
 JNIEXPORT jint JNICALL