src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java
changeset 58308 b7192797f434
parent 48568 0255315ac8d4
--- a/src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java	Tue Sep 24 10:36:35 2019 -0700
+++ b/src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java	Tue Sep 24 22:57:28 2019 +0100
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2000, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2000, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -29,7 +29,9 @@
 import java.net.DatagramSocket;
 import java.net.DatagramPacket;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.Socket;
+import java.net.SocketTimeoutException;
 import java.security.SecureRandom;
 import javax.naming.*;
 
@@ -82,7 +84,7 @@
     private static final SecureRandom random = JCAUtil.getSecureRandom();
     private InetAddress[] servers;
     private int[] serverPorts;
-    private int timeout;                // initial timeout on UDP queries in ms
+    private int timeout;                // initial timeout on UDP and TCP queries in ms
     private int retries;                // number of UDP retries
 
     private final Object udpSocketLock = new Object();
@@ -100,7 +102,7 @@
     /*
      * Each server is of the form "server[:port]".  IPv6 literal host names
      * include delimiting brackets.
-     * "timeout" is the initial timeout interval (in ms) for UDP queries,
+     * "timeout" is the initial timeout interval (in ms) for queries,
      * and "retries" gives the number of retries per server.
      */
     public DnsClient(String[] servers, int timeout, int retries)
@@ -237,6 +239,7 @@
 
                             // Try each server, starting with the one that just
                             // provided the truncated message.
+                            int retryTimeout = (timeout * (1 << retry));
                             for (int j = 0; j < servers.length; j++) {
                                 int ij = (i + j) % servers.length;
                                 if (doNotRetry[ij]) {
@@ -244,7 +247,7 @@
                                 }
                                 try {
                                     Tcp tcp =
-                                        new Tcp(servers[ij], serverPorts[ij]);
+                                        new Tcp(servers[ij], serverPorts[ij], retryTimeout);
                                     byte[] msg2;
                                     try {
                                         msg2 = doTcpQuery(tcp, pkt);
@@ -327,7 +330,7 @@
         // Try each name server.
         for (int i = 0; i < servers.length; i++) {
             try {
-                Tcp tcp = new Tcp(servers[i], serverPorts[i]);
+                Tcp tcp = new Tcp(servers[i], serverPorts[i], timeout);
                 byte[] msg;
                 try {
                     msg = doTcpQuery(tcp, pkt);
@@ -462,11 +465,11 @@
      */
     private byte[] continueTcpQuery(Tcp tcp) throws IOException {
 
-        int lenHi = tcp.in.read();      // high-order byte of response length
+        int lenHi = tcp.read();      // high-order byte of response length
         if (lenHi == -1) {
             return null;        // EOF
         }
-        int lenLo = tcp.in.read();      // low-order byte of response length
+        int lenLo = tcp.read();      // low-order byte of response length
         if (lenLo == -1) {
             throw new IOException("Corrupted DNS response: bad length");
         }
@@ -474,7 +477,7 @@
         byte[] msg = new byte[len];
         int pos = 0;                    // next unfilled position in msg
         while (len > 0) {
-            int n = tcp.in.read(msg, pos, len);
+            int n = tcp.read(msg, pos, len);
             if (n == -1) {
                 throw new IOException(
                         "Corrupted DNS response: too little data");
@@ -682,20 +685,62 @@
 
 class Tcp {
 
-    private Socket sock;
-    java.io.InputStream in;
-    java.io.OutputStream out;
+    private final Socket sock;
+    private final java.io.InputStream in;
+    final java.io.OutputStream out;
+    private int timeoutLeft;
 
-    Tcp(InetAddress server, int port) throws IOException {
-        sock = new Socket(server, port);
-        sock.setTcpNoDelay(true);
-        out = new java.io.BufferedOutputStream(sock.getOutputStream());
-        in = new java.io.BufferedInputStream(sock.getInputStream());
+    Tcp(InetAddress server, int port, int timeout) throws IOException {
+        sock = new Socket();
+        try {
+            long start = System.currentTimeMillis();
+            sock.connect(new InetSocketAddress(server, port), timeout);
+            timeoutLeft = (int) (timeout - (System.currentTimeMillis() - start));
+            if (timeoutLeft <= 0)
+                throw new SocketTimeoutException();
+
+            sock.setTcpNoDelay(true);
+            out = new java.io.BufferedOutputStream(sock.getOutputStream());
+            in = new java.io.BufferedInputStream(sock.getInputStream());
+        } catch (Exception e) {
+            try {
+                sock.close();
+            } catch (IOException ex) {
+                e.addSuppressed(ex);
+            }
+            throw e;
+        }
     }
 
     void close() throws IOException {
         sock.close();
     }
+
+    private interface SocketReadOp {
+        int read() throws IOException;
+    }
+
+    private int readWithTimeout(SocketReadOp reader) throws IOException {
+        if (timeoutLeft <= 0)
+            throw new SocketTimeoutException();
+
+        sock.setSoTimeout(timeoutLeft);
+        long start = System.currentTimeMillis();
+        try {
+            return reader.read();
+        }
+        finally {
+            timeoutLeft -= System.currentTimeMillis() - start;
+        }
+    }
+
+    int read() throws IOException {
+        return readWithTimeout(() -> in.read());
+    }
+
+    int read(byte b[], int off, int len) throws IOException {
+        return readWithTimeout(() -> in.read(b, off, len));
+    }
 }
 
 /*