6916890: (sctp) SctpChannel.send may cause IAE if given a heap buffer with an offset
authorchegar
Mon, 18 Jan 2010 14:01:07 +0000
changeset 4677 1b6ce3fbc01b
parent 4676 4cdcce877ba7
child 4678 99fdf34405de
6916890: (sctp) SctpChannel.send may cause IAE if given a heap buffer with an offset Reviewed-by: alanb
jdk/src/solaris/classes/sun/nio/ch/SctpChannelImpl.java
jdk/src/solaris/classes/sun/nio/ch/SctpMultiChannelImpl.java
jdk/test/com/sun/nio/sctp/SctpChannel/Send.java
jdk/test/com/sun/nio/sctp/SctpMultiChannel/Send.java
--- a/jdk/src/solaris/classes/sun/nio/ch/SctpChannelImpl.java	Fri Jan 15 15:36:54 2010 -0800
+++ b/jdk/src/solaris/classes/sun/nio/ch/SctpChannelImpl.java	Mon Jan 18 14:01:07 2010 +0000
@@ -869,8 +869,8 @@
         public HandlerResult handleNotification(
                 AssociationChangeNotification not, T unused) {
             if (not.event().equals(
-                    AssociationChangeNotification.AssocChangeEvent.COMM_UP)) {
-                assert association == null;
+                    AssociationChangeNotification.AssocChangeEvent.COMM_UP) &&
+                    association == null) {
                 SctpAssocChange sac = (SctpAssocChange) not;
                 association = new SctpAssociationImpl
                        (sac.assocId(), sac.maxInStreams(), sac.maxOutStreams());
@@ -982,17 +982,17 @@
         SocketAddress target = messageInfo.address();
         boolean unordered = messageInfo.isUnordered();
         int ppid = messageInfo.payloadProtocolID();
-        int pos = src.position();
-        int lim = src.limit();
-
-        assert (pos <= lim && streamNumber >= 0);
-        int rem = (pos <= lim ? lim - pos : 0);
 
         if (src instanceof DirectBuffer)
-            return sendFromNativeBuffer(fd, src, rem, pos, target, streamNumber,
+            return sendFromNativeBuffer(fd, src, target, streamNumber,
                     unordered, ppid);
 
         /* Substitute a native buffer */
+        int pos = src.position();
+        int lim = src.limit();
+        assert (pos <= lim && streamNumber >= 0);
+
+        int rem = (pos <= lim ? lim - pos : 0);
         ByteBuffer bb = Util.getTemporaryDirectBuffer(rem);
         try {
             bb.put(src);
@@ -1000,7 +1000,7 @@
             /* Do not update src until we see how many bytes were written */
             src.position(pos);
 
-            int n = sendFromNativeBuffer(fd, bb, rem, pos, target, streamNumber,
+            int n = sendFromNativeBuffer(fd, bb, target, streamNumber,
                     unordered, ppid);
             if (n > 0) {
                 /* now update src */
@@ -1014,13 +1014,16 @@
 
     private int sendFromNativeBuffer(int fd,
                                      ByteBuffer bb,
-                                     int rem,
-                                     int pos,
                                      SocketAddress target,
                                      int streamNumber,
                                      boolean unordered,
                                      int ppid)
             throws IOException {
+        int pos = bb.position();
+        int lim = bb.limit();
+        assert (pos <= lim);
+        int rem = (pos <= lim ? lim - pos : 0);
+
         int written = send0(fd, ((DirectBuffer)bb).address() + pos,
                             rem, target, -1 /*121*/, streamNumber, unordered, ppid);
         if (written > 0)
--- a/jdk/src/solaris/classes/sun/nio/ch/SctpMultiChannelImpl.java	Fri Jan 15 15:36:54 2010 -0800
+++ b/jdk/src/solaris/classes/sun/nio/ch/SctpMultiChannelImpl.java	Mon Jan 18 14:01:07 2010 +0000
@@ -842,16 +842,17 @@
         int streamNumber = messageInfo.streamNumber();
         boolean unordered = messageInfo.isUnordered();
         int ppid = messageInfo.payloadProtocolID();
+
+        if (src instanceof DirectBuffer)
+            return sendFromNativeBuffer(fd, src, target, assocId,
+                    streamNumber, unordered, ppid);
+
+        /* Substitute a native buffer */
         int pos = src.position();
         int lim = src.limit();
         assert (pos <= lim && streamNumber >= 0);
+
         int rem = (pos <= lim ? lim - pos : 0);
-
-        if (src instanceof DirectBuffer)
-            return sendFromNativeBuffer(fd, src, rem, pos, target, assocId,
-                    streamNumber, unordered, ppid);
-
-        /* Substitute a native buffer */
         ByteBuffer bb = Util.getTemporaryDirectBuffer(rem);
         try {
             bb.put(src);
@@ -859,7 +860,7 @@
             /* Do not update src until we see how many bytes were written */
             src.position(pos);
 
-            int n = sendFromNativeBuffer(fd, bb, rem, pos, target, assocId,
+            int n = sendFromNativeBuffer(fd, bb, target, assocId,
                     streamNumber, unordered, ppid);
             if (n > 0) {
                 /* now update src */
@@ -873,14 +874,17 @@
 
     private int sendFromNativeBuffer(int fd,
                                      ByteBuffer bb,
-                                     int rem,
-                                     int pos,
                                      SocketAddress target,
                                      int assocId,
                                      int streamNumber,
                                      boolean unordered,
                                      int ppid)
             throws IOException {
+        int pos = bb.position();
+        int lim = bb.limit();
+        assert (pos <= lim);
+        int rem = (pos <= lim ? lim - pos : 0);
+
         int written = send0(fd, ((DirectBuffer)bb).address() + pos,
                             rem, target, assocId, streamNumber, unordered, ppid);
         if (written > 0)
--- a/jdk/test/com/sun/nio/sctp/SctpChannel/Send.java	Fri Jan 15 15:36:54 2010 -0800
+++ b/jdk/test/com/sun/nio/sctp/SctpChannel/Send.java	Mon Jan 18 14:01:07 2010 +0000
@@ -112,9 +112,6 @@
             /* Receive CommUp */
             channel.receive(buffer, null, handler);
 
-            /* save for TEST 8 */
-            Association association = channel.association();
-
             /* TEST 2: send small message */
             int streamNumber = 0;
             debug("sending on stream number: " + streamNumber);
@@ -250,6 +247,29 @@
                 pass();
                 debug("OK, caught " + e);
             }
+
+            /* TEST 9: Send from heap buffer to force implementation to
+             * substitute with a native buffer, then check that its position
+             * is updated correctly */
+            buffer.clear();
+            info = MessageInfo.createOutgoing(null, 0);
+            buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
+            buffer.flip();
+            final int offset = 1;
+            buffer.position(offset);
+            remaining = buffer.remaining();
+
+            debug("sending small message: " + buffer);
+            try {
+                sent = channel.send(buffer, info);
+
+                check(sent == remaining, "sent should be equal to remaining");
+                check(buffer.position() == (offset + sent),
+                        "buffers position should have been incremented by sent");
+            } catch (IllegalArgumentException iae) {
+                fail(iae + ", Error updating buffers position");
+            }
+
         } catch (IOException ioe) {
             unexpected(ioe);
         } finally {
@@ -335,6 +355,30 @@
                 /* TEST 7 ++ */
                 sc2 = ssc.accept();
 
+                /* TEST 9 */
+                ByteBuffer expected = ByteBuffer.allocate(Util.SMALL_BUFFER);
+                expected.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
+                expected.flip();
+                final int offset = 1;
+                expected.position(offset);
+                buffer.clear();
+                do {
+                    info = sc2.receive(buffer, null, null);
+                    if (info == null) {
+                        fail("Server: unexpected null from receive");
+                        return;
+                    }
+                } while (!info.isComplete());
+
+                buffer.flip();
+                check(info != null, "info is null");
+                check(info.streamNumber() == 0, "message not sent on the correct stream");
+                check(info.bytes() == expected.remaining(),
+                      "bytes received not equal to message length");
+                check(info.bytes() == buffer.remaining(), "bytes != remaining");
+                check(expected.equals(buffer),
+                    "received message not the same as sent message");
+
                 clientFinishedLatch.await(10L, TimeUnit.SECONDS);
                 serverFinishedLatch.countDown();
             } catch (IOException ioe) {
--- a/jdk/test/com/sun/nio/sctp/SctpMultiChannel/Send.java	Fri Jan 15 15:36:54 2010 -0800
+++ b/jdk/test/com/sun/nio/sctp/SctpMultiChannel/Send.java	Mon Jan 18 14:01:07 2010 +0000
@@ -185,6 +185,27 @@
             /* TEST 5: getRemoteAddresses(Association) */
             channel.getRemoteAddresses(assoc);
 
+            /* TEST 6: Send from heap buffer to force implementation to
+             * substitute with a native buffer, then check that its position
+             * is updated correctly */
+            info = MessageInfo.createOutgoing(assoc, null, 0);
+            buffer.clear();
+            buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
+            buffer.flip();
+            final int offset = 1;
+            buffer.position(offset);
+            remaining = buffer.remaining();
+
+            try {
+                sent = channel.send(buffer, info);
+
+                check(sent == remaining, "sent should be equal to remaining");
+                check(buffer.position() == (offset + sent),
+                        "buffers position should have been incremented by sent");
+            } catch (IllegalArgumentException iae) {
+                fail(iae + ", Error updating buffers position");
+            }
+
         } catch (IOException ioe) {
             unexpected(ioe);
         } finally {
@@ -284,6 +305,30 @@
                 bytes = serverChannel.send(buffer, info);
                 debug("Server: sent " + bytes + "bytes");
 
+                /* TEST 6 */
+                ByteBuffer expected = ByteBuffer.allocate(Util.SMALL_BUFFER);
+                expected.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
+                expected.flip();
+                final int offset = 1;
+                expected.position(offset);
+                buffer.clear();
+                do {
+                    info = serverChannel.receive(buffer, null, null);
+                    if (info == null) {
+                        fail("Server: unexpected null from receive");
+                        return;
+                    }
+                } while (!info.isComplete());
+
+                buffer.flip();
+                check(info != null, "info is null");
+                check(info.streamNumber() == 0, "message not sent on the correct stream");
+                check(info.bytes() == expected.remaining(),
+                    "bytes received not equal to message length");
+                check(info.bytes() == buffer.remaining(), "bytes != remaining");
+                check(expected.equals(buffer),
+                    "received message not the same as sent message");
+
                 clientFinishedLatch.await(10L, TimeUnit.SECONDS);
                 serverFinishedLatch.countDown();
             } catch (IOException ioe) {