8007617: Better validation of images
authorbae
Tue, 19 Feb 2013 11:47:42 +0400
changeset 16880 558619642e45
parent 16879 9d9567cfcb10
child 16881 f105ea43e3aa
8007617: Better validation of images Reviewed-by: prr, mschoene, jgodinez
jdk/src/share/classes/sun/awt/image/ImageRepresentation.java
jdk/src/share/native/sun/awt/image/awt_ImageRep.c
--- a/jdk/src/share/classes/sun/awt/image/ImageRepresentation.java	Fri Feb 08 09:15:01 2013 -0800
+++ b/jdk/src/share/classes/sun/awt/image/ImageRepresentation.java	Tue Feb 19 11:47:42 2013 +0400
@@ -333,10 +333,10 @@
         hints = h;
     }
 
-    private native void setICMpixels(int x, int y, int w, int h, int[] lut,
+    private native boolean setICMpixels(int x, int y, int w, int h, int[] lut,
                                     byte[] pix, int off, int scansize,
                                     IntegerComponentRaster ict);
-    private native int setDiffICM(int x, int y, int w, int h, int[] lut,
+    private native boolean setDiffICM(int x, int y, int w, int h, int[] lut,
                                  int transPix, int numLut, IndexColorModel icm,
                                  byte[] pix, int off, int scansize,
                                  ByteComponentRaster bct, int chanOff);
@@ -426,10 +426,10 @@
                 IndexColorModel icm = (IndexColorModel) model;
                 ByteComponentRaster bct = (ByteComponentRaster) biRaster;
                 int numlut = numSrcLUT;
-                if (setDiffICM(x, y, w, h, srcLUT, srcLUTtransIndex,
+                if (!setDiffICM(x, y, w, h, srcLUT, srcLUTtransIndex,
                                numSrcLUT, icm,
                                pix, off, scansize, bct,
-                               bct.getDataOffset(0)) == 0) {
+                               bct.getDataOffset(0))) {
                     convertToRGB();
                 }
                 else {
@@ -470,9 +470,14 @@
                     if (s_useNative) {
                         // Note that setICMpixels modifies the raster directly
                         // so we must mark it as changed afterwards
-                        setICMpixels(x, y, w, h, srcLUT, pix, off, scansize,
-                                     iraster);
-                        iraster.markDirty();
+                        if (setICMpixels(x, y, w, h, srcLUT, pix, off, scansize,
+                                     iraster))
+                        {
+                            iraster.markDirty();
+                        } else {
+                            abort();
+                            return;
+                        }
                     }
                     else {
                         int[] storage = new int[w*h];
--- a/jdk/src/share/native/sun/awt/image/awt_ImageRep.c	Fri Feb 08 09:15:01 2013 -0800
+++ b/jdk/src/share/native/sun/awt/image/awt_ImageRep.c	Tue Feb 19 11:47:42 2013 +0400
@@ -45,6 +45,53 @@
 #  define TRUE 1
 #endif
 
+#define CHECK_STRIDE(yy, hh, ss)                            \
+    if ((ss) != 0) {                                        \
+        int limit = 0x7fffffff / ((ss) > 0 ? (ss) : -(ss)); \
+        if (limit < (yy) || limit < ((yy) + (hh) - 1)) {    \
+            /* integer oveflow */                           \
+            return JNI_FALSE;                               \
+        }                                                   \
+    }                                                       \
+
+#define CHECK_SRC()                                      \
+    do {                                                 \
+        int pixeloffset;                                 \
+        if (off < 0 || off >= srcDataLength) {           \
+            return JNI_FALSE;                            \
+        }                                                \
+        CHECK_STRIDE(0, h, scansize);                    \
+                                                         \
+        /* check scansize */                             \
+        pixeloffset = scansize * (h - 1);                \
+        if ((w - 1) > (0x7fffffff - pixeloffset)) {      \
+            return JNI_FALSE;                            \
+        }                                                \
+        pixeloffset += (w - 1);                          \
+                                                         \
+        if (off > (0x7fffffff - pixeloffset)) {          \
+            return JNI_FALSE;                            \
+        }                                                \
+    } while (0)                                          \
+
+#define CHECK_DST(xx, yy)                                \
+    do {                                                 \
+        int soffset = (yy) * sStride;                    \
+        int poffset = (xx) * pixelStride;                \
+        if (poffset > (0x7fffffff - soffset)) {          \
+            return JNI_FALSE;                            \
+        }                                                \
+        poffset += soffset;                              \
+        if (dstDataOff > (0x7fffffff - poffset)) {       \
+            return JNI_FALSE;                            \
+        }                                                \
+        poffset += dstDataOff;                           \
+                                                         \
+        if (poffset < 0 || poffset >= dstDataLength) {   \
+            return JNI_FALSE;                            \
+        }                                                \
+    } while (0)                                          \
+
 static jfieldID s_JnumSrcLUTID;
 static jfieldID s_JsrcLUTtransIndexID;
 
@@ -58,7 +105,7 @@
 /*
  * This routine is used to draw ICM pixels into a default color model
  */
-JNIEXPORT void JNICALL
+JNIEXPORT jboolean JNICALL
 Java_sun_awt_image_ImageRepresentation_setICMpixels(JNIEnv *env, jclass cls,
                                                     jint x, jint y, jint w,
                                                     jint h, jintArray jlut,
@@ -67,7 +114,10 @@
                                                     jobject jict)
 {
     unsigned char *srcData = NULL;
+    jint srcDataLength;
     int *dstData;
+    jint dstDataLength;
+    jint dstDataOff;
     int *dstP, *dstyP;
     unsigned char *srcyP, *srcP;
     int *srcLUT = NULL;
@@ -80,12 +130,20 @@
 
     if (JNU_IsNull(env, jlut)) {
         JNU_ThrowNullPointerException(env, "NullPointerException");
-        return;
+        return JNI_FALSE;
     }
 
     if (JNU_IsNull(env, jpix)) {
         JNU_ThrowNullPointerException(env, "NullPointerException");
-        return;
+        return JNI_FALSE;
+    }
+
+    if (x < 0 || w < 1 || (0x7fffffff - x) < w) {
+        return JNI_FALSE;
+    }
+
+    if (y < 0 || h < 1 || (0x7fffffff - y) < h) {
+        return JNI_FALSE;
     }
 
     sStride = (*env)->GetIntField(env, jict, g_ICRscanstrID);
@@ -93,10 +151,47 @@
     joffs = (*env)->GetObjectField(env, jict, g_ICRdataOffsetsID);
     jdata = (*env)->GetObjectField(env, jict, g_ICRdataID);
 
+    if (JNU_IsNull(env, jdata)) {
+        /* no destination buffer */
+        return JNI_FALSE;
+    }
+
+    if (JNU_IsNull(env, joffs) || (*env)->GetArrayLength(env, joffs) < 1) {
+        /* invalid data offstes in raster */
+        return JNI_FALSE;
+    }
+
+    srcDataLength = (*env)->GetArrayLength(env, jpix);
+    dstDataLength = (*env)->GetArrayLength(env, jdata);
+
+    cOffs = (int *) (*env)->GetPrimitiveArrayCritical(env, joffs, NULL);
+    if (cOffs == NULL) {
+        JNU_ThrowNullPointerException(env, "Null channel offset array");
+        return JNI_FALSE;
+    }
+
+    dstDataOff = cOffs[0];
+
+    /* the offset array is not needed anymore and can be released */
+    (*env)->ReleasePrimitiveArrayCritical(env, joffs, cOffs, JNI_ABORT);
+    joffs = NULL;
+    cOffs = NULL;
+
+    /* do basic validation: make sure that offsets for
+    * first pixel and for last pixel are safe to calculate and use */
+    CHECK_STRIDE(y, h, sStride);
+    CHECK_STRIDE(x, w, pixelStride);
+
+    CHECK_DST(x, y);
+    CHECK_DST(x + w -1, y + h - 1);
+
+    /* check source array */
+    CHECK_SRC();
+
     srcLUT = (int *) (*env)->GetPrimitiveArrayCritical(env, jlut, NULL);
     if (srcLUT == NULL) {
         JNU_ThrowNullPointerException(env, "Null IndexColorModel LUT");
-        return;
+        return JNI_FALSE;
     }
 
     srcData = (unsigned char *) (*env)->GetPrimitiveArrayCritical(env, jpix,
@@ -104,27 +199,18 @@
     if (srcData == NULL) {
         (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT, JNI_ABORT);
         JNU_ThrowNullPointerException(env, "Null data array");
-        return;
-    }
-
-    cOffs = (int *) (*env)->GetPrimitiveArrayCritical(env, joffs, NULL);
-    if (cOffs == NULL) {
-        (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT, JNI_ABORT);
-        (*env)->ReleasePrimitiveArrayCritical(env, jpix, srcData, JNI_ABORT);
-        JNU_ThrowNullPointerException(env, "Null channel offset array");
-        return;
+        return JNI_FALSE;
     }
 
     dstData = (int *) (*env)->GetPrimitiveArrayCritical(env, jdata, NULL);
     if (dstData == NULL) {
         (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT, JNI_ABORT);
         (*env)->ReleasePrimitiveArrayCritical(env, jpix, srcData, JNI_ABORT);
-        (*env)->ReleasePrimitiveArrayCritical(env, joffs, cOffs, JNI_ABORT);
         JNU_ThrowNullPointerException(env, "Null tile data array");
-        return;
+        return JNI_FALSE;
     }
 
-    dstyP = dstData + cOffs[0] + y*sStride + x*pixelStride;
+    dstyP = dstData + dstDataOff + y*sStride + x*pixelStride;
     srcyP = srcData + off;
     for (yIdx = 0; yIdx < h; yIdx++, srcyP += scansize, dstyP+=sStride) {
         srcP = srcyP;
@@ -137,12 +223,12 @@
     /* Release the locked arrays */
     (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT,  JNI_ABORT);
     (*env)->ReleasePrimitiveArrayCritical(env, jpix, srcData, JNI_ABORT);
-    (*env)->ReleasePrimitiveArrayCritical(env, joffs, cOffs, JNI_ABORT);
     (*env)->ReleasePrimitiveArrayCritical(env, jdata, dstData, JNI_ABORT);
 
+    return JNI_TRUE;
 }
 
-JNIEXPORT jint JNICALL
+JNIEXPORT jboolean JNICALL
 Java_sun_awt_image_ImageRepresentation_setDiffICM(JNIEnv *env, jclass cls,
                                                   jint x, jint y, jint w,
                                                   jint h, jintArray jlut,
@@ -150,7 +236,7 @@
                                                   jobject jicm,
                                                   jbyteArray jpix, jint off,
                                                   jint scansize,
-                                                  jobject jbct, jint chanOff)
+                                                  jobject jbct, jint dstDataOff)
 {
     unsigned int *srcLUT = NULL;
     unsigned int *newLUT = NULL;
@@ -159,6 +245,8 @@
     int mapSize;
     jobject jdata = NULL;
     jobject jnewlut = NULL;
+    jint srcDataLength;
+    jint dstDataLength;
     unsigned char *srcData;
     unsigned char *dstData;
     unsigned char *dataP;
@@ -174,14 +262,23 @@
 
     if (JNU_IsNull(env, jlut)) {
         JNU_ThrowNullPointerException(env, "NullPointerException");
-        return 0;
+        return JNI_FALSE;
     }
 
     if (JNU_IsNull(env, jpix)) {
         JNU_ThrowNullPointerException(env, "NullPointerException");
-        return 0;
+        return JNI_FALSE;
+    }
+
+    if (x < 0 || w < 1 || (0x7fffffff - x) < w) {
+        return JNI_FALSE;
     }
 
+    if (y < 0 || h < 1 || (0x7fffffff - y) < h) {
+        return JNI_FALSE;
+    }
+
+
     sStride = (*env)->GetIntField(env, jbct, g_BCRscanstrID);
     pixelStride =(*env)->GetIntField(env, jbct, g_BCRpixstrID);
     jdata = (*env)->GetObjectField(env, jbct, g_BCRdataID);
@@ -193,13 +290,31 @@
            of byte data type, so we have to convert the image data
            to default representation.
         */
-        return 0;
+        return JNI_FALSE;
+    }
+
+    if (JNU_IsNull(env, jdata)) {
+        /* no destination buffer */
+        return JNI_FALSE;
     }
+
+    srcDataLength = (*env)->GetArrayLength(env, jpix);
+    dstDataLength = (*env)->GetArrayLength(env, jdata);
+
+    CHECK_STRIDE(y, h, sStride);
+    CHECK_STRIDE(x, w, pixelStride);
+
+    CHECK_DST(x, y);
+    CHECK_DST(x + w -1, y + h - 1);
+
+    /* check source array */
+    CHECK_SRC();
+
     srcLUT = (unsigned int *) (*env)->GetPrimitiveArrayCritical(env, jlut,
                                                                 NULL);
     if (srcLUT == NULL) {
         /* out of memory error already thrown */
-        return 0;
+        return JNI_FALSE;
     }
 
     newLUT = (unsigned int *) (*env)->GetPrimitiveArrayCritical(env, jnewlut,
@@ -208,7 +323,7 @@
         (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT,
                                               JNI_ABORT);
         /* out of memory error already thrown */
-        return 0;
+        return JNI_FALSE;
     }
 
     newNumLut = numLut;
@@ -219,7 +334,7 @@
         (*env)->ReleasePrimitiveArrayCritical(env, jlut, srcLUT,
                                               JNI_ABORT);
         (*env)->ReleasePrimitiveArrayCritical(env, jnewlut, newLUT, JNI_ABORT);
-        return 0;
+        return JNI_FALSE;
     }
 
     /* Don't need these any more */
@@ -239,7 +354,7 @@
                                                                   NULL);
     if (srcData == NULL) {
         /* out of memory error already thrown */
-        return 0;
+        return JNI_FALSE;
     }
 
     dstData = (unsigned char *) (*env)->GetPrimitiveArrayCritical(env, jdata,
@@ -247,10 +362,10 @@
     if (dstData == NULL) {
         (*env)->ReleasePrimitiveArrayCritical(env, jpix, srcData, JNI_ABORT);
         /* out of memory error already thrown */
-        return 0;
+        return JNI_FALSE;
     }
 
-    ydataP = dstData + chanOff + y*sStride + x*pixelStride;
+    ydataP = dstData + dstDataOff + y*sStride + x*pixelStride;
     ypixP  = srcData + off;
 
     for (i=0; i < h; i++) {
@@ -268,7 +383,7 @@
     (*env)->ReleasePrimitiveArrayCritical(env, jpix, srcData, JNI_ABORT);
     (*env)->ReleasePrimitiveArrayCritical(env, jdata, dstData, JNI_ABORT);
 
-    return 1;
+    return JNI_TRUE;
 }
 
 static int compareLUTs(unsigned int *lut1, int numLut1, int transIdx,