jdk/src/share/classes/com/sun/jndi/dns/DnsClient.java
changeset 2 90ce3da70b43
child 5506 202f599c92aa
child 5624 302c8bd0d942
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/jdk/src/share/classes/com/sun/jndi/dns/DnsClient.java	Sat Dec 01 00:00:00 2007 +0000
@@ -0,0 +1,699 @@
+/*
+ * Copyright 2000-2004 Sun Microsystems, Inc.  All Rights Reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Sun designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Sun in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Sun Microsystems, Inc., 4150 Network Circle, Santa Clara,
+ * CA 95054 USA or visit www.sun.com if you need additional information or
+ * have any questions.
+ */
+
+package com.sun.jndi.dns;
+
+import java.io.IOException;
+import java.net.DatagramSocket;
+import java.net.DatagramPacket;
+import java.net.InetAddress;
+import java.net.Socket;
+import javax.naming.*;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.Set;
+import java.util.HashSet;
+
+// Some of this code began life as part of sun.javaos.net.DnsClient
+// originally by sritchie@eng 1/96.  It was first hacked up for JNDI
+// use by caveh@eng 6/97.
+
+
+/**
+ * The DnsClient class performs DNS client operations in support of DnsContext.
+ *
+ */
+
+public class DnsClient {
+
+    // DNS packet header field offsets
+    private static final int IDENT_OFFSET = 0;
+    private static final int FLAGS_OFFSET = 2;
+    private static final int NUMQ_OFFSET  = 4;
+    private static final int NUMANS_OFFSET = 6;
+    private static final int NUMAUTH_OFFSET = 8;
+    private static final int NUMADD_OFFSET = 10;
+    private static final int DNS_HDR_SIZE = 12;
+
+    // DNS response codes
+    private static final int NO_ERROR       = 0;
+    private static final int FORMAT_ERROR   = 1;
+    private static final int SERVER_FAILURE = 2;
+    private static final int NAME_ERROR     = 3;
+    private static final int NOT_IMPL       = 4;
+    private static final int REFUSED        = 5;
+
+    private static final String[] rcodeDescription = {
+        "No error",
+        "DNS format error",
+        "DNS server failure",
+        "DNS name not found",
+        "DNS operation not supported",
+        "DNS service refused"
+    };
+
+    private static final int DEFAULT_PORT = 53;
+    private InetAddress[] servers;
+    private int[] serverPorts;
+    private int timeout;                // initial timeout on UDP queries in ms
+    private int retries;                // number of UDP retries
+
+    private DatagramSocket udpSocket;
+
+    // Requests sent
+    private Set<Integer> reqs;
+
+    // Responses received
+    private Map<Integer, byte[]> resps;
+
+    //-------------------------------------------------------------------------
+
+    /*
+     * Each server is of the form "server[:port]".  IPv6 literal host names
+     * include delimiting brackets.
+     * "timeout" is the initial timeout interval (in ms) for UDP queries,
+     * and "retries" gives the number of retries per server.
+     */
+    public DnsClient(String[] servers, int timeout, int retries)
+            throws NamingException {
+        this.timeout = timeout;
+        this.retries = retries;
+        try {
+            udpSocket = new DatagramSocket();
+        } catch (java.net.SocketException e) {
+            NamingException ne = new ConfigurationException();
+            ne.setRootCause(e);
+            throw ne;
+        }
+
+        this.servers = new InetAddress[servers.length];
+        serverPorts = new int[servers.length];
+
+        for (int i = 0; i < servers.length; i++) {
+
+            // Is optional port given?
+            int colon = servers[i].indexOf(':',
+                                           servers[i].indexOf(']') + 1);
+
+            serverPorts[i] = (colon < 0)
+                ? DEFAULT_PORT
+                : Integer.parseInt(servers[i].substring(colon + 1));
+            String server = (colon < 0)
+                ? servers[i]
+                : servers[i].substring(0, colon);
+            try {
+                this.servers[i] = InetAddress.getByName(server);
+            } catch (java.net.UnknownHostException e) {
+                NamingException ne = new ConfigurationException(
+                        "Unknown DNS server: " + server);
+                ne.setRootCause(e);
+                throw ne;
+            }
+        }
+        reqs = Collections.synchronizedSet(new HashSet<Integer>());
+        resps = Collections.synchronizedMap(new HashMap<Integer, byte[]>());
+    }
+
+    protected void finalize() {
+        close();
+    }
+
+    // A lock to access the request and response queues in tandem.
+    private Object queuesLock = new Object();
+
+    public void close() {
+        udpSocket.close();
+        synchronized (queuesLock) {
+            reqs.clear();
+            resps.clear();
+        }
+    }
+
+
+    private int ident = 0;              // used to set the msg ID field
+    private Object identLock = new Object();
+
+    /*
+     * If recursion is true, recursion is requested on the query.
+     * If auth is true, only authoritative responses are accepted; other
+     * responses throw NameNotFoundException.
+     */
+    ResourceRecords query(DnsName fqdn, int qclass, int qtype,
+                          boolean recursion, boolean auth)
+            throws NamingException {
+
+        int xid;
+        synchronized (identLock) {
+            ident = 0xFFFF & (ident + 1);
+            xid = ident;
+        }
+
+        // enqueue the outstanding request
+        reqs.add(xid);
+
+        Packet pkt = makeQueryPacket(fqdn, xid, qclass, qtype, recursion);
+
+        Exception caughtException = null;
+        boolean[] doNotRetry = new boolean[servers.length];
+
+        //
+        // The UDP retry strategy is to try the 1st server, and then
+        // each server in order. If no answer, double the timeout
+        // and try each server again.
+        //
+        for (int retry = 0; retry < retries; retry++) {
+
+            // Try each name server.
+            for (int i = 0; i < servers.length; i++) {
+                if (doNotRetry[i]) {
+                    continue;
+                }
+
+                // send the request packet and wait for a response.
+                try {
+                    if (debug) {
+                        dprint("SEND ID (" + (retry + 1) + "): " + xid);
+                    }
+
+                    byte[] msg = null;
+                    msg = doUdpQuery(pkt, servers[i], serverPorts[i],
+                                        retry, xid);
+                    //
+                    // If the matching response is not got within the
+                    // given timeout, check if the response was enqueued
+                    // by some other thread, if not proceed with the next
+                    // server or retry.
+                    //
+                    if (msg == null) {
+                        if (resps.size() > 0) {
+                            msg = lookupResponse(xid);
+                        }
+                        if (msg == null) { // try next server or retry
+                            continue;
+                        }
+                    }
+                    Header hdr = new Header(msg, msg.length);
+
+                    if (auth && !hdr.authoritative) {
+                        caughtException = new NameNotFoundException(
+                                "DNS response not authoritative");
+                        doNotRetry[i] = true;
+                        continue;
+                    }
+                    if (hdr.truncated) {    // message is truncated -- try TCP
+
+                        // Try each server, starting with the one that just
+                        // provided the truncated message.
+                        for (int j = 0; j < servers.length; j++) {
+                            int ij = (i + j) % servers.length;
+                            if (doNotRetry[ij]) {
+                                continue;
+                            }
+                            try {
+                                Tcp tcp =
+                                    new Tcp(servers[ij], serverPorts[ij]);
+                                byte[] msg2;
+                                try {
+                                    msg2 = doTcpQuery(tcp, pkt);
+                                } finally {
+                                    tcp.close();
+                                }
+                                Header hdr2 = new Header(msg2, msg2.length);
+                                if (hdr2.query) {
+                                    throw new CommunicationException(
+                                        "DNS error: expecting response");
+                                }
+                                checkResponseCode(hdr2);
+
+                                if (!auth || hdr2.authoritative) {
+                                    // Got a valid response
+                                    hdr = hdr2;
+                                    msg = msg2;
+                                    break;
+                                } else {
+                                    doNotRetry[ij] = true;
+                                }
+                            } catch (Exception e) {
+                                // Try next server, or use UDP response
+                            }
+                        } // servers
+                    }
+                    return new ResourceRecords(msg, msg.length, hdr, false);
+
+                } catch (IOException e) {
+                    if (debug) {
+                        dprint("Caught IOException:" + e);
+                    }
+                    if (caughtException == null) {
+                        caughtException = e;
+                    }
+                    // Use reflection to allow pre-1.4 compilation.
+                    // This won't be needed much longer.
+                    if (e.getClass().getName().equals(
+                            "java.net.PortUnreachableException")) {
+                        doNotRetry[i] = true;
+                    }
+                } catch (NameNotFoundException e) {
+                    throw e;
+                } catch (CommunicationException e) {
+                    if (caughtException == null) {
+                        caughtException = e;
+                    }
+                } catch (NamingException e) {
+                    if (caughtException == null) {
+                        caughtException = e;
+                    }
+                    doNotRetry[i] = true;
+                }
+            } // servers
+        } // retries
+
+        reqs.remove(xid);
+        if (caughtException instanceof NamingException) {
+            throw (NamingException) caughtException;
+        }
+        // A network timeout or other error occurred.
+        NamingException ne = new CommunicationException("DNS error");
+        ne.setRootCause(caughtException);
+        throw ne;
+    }
+
+    ResourceRecords queryZone(DnsName zone, int qclass, boolean recursion)
+            throws NamingException {
+
+        int xid;
+        synchronized (identLock) {
+            ident = 0xFFFF & (ident + 1);
+            xid = ident;
+        }
+        Packet pkt = makeQueryPacket(zone, xid, qclass,
+                                     ResourceRecord.QTYPE_AXFR, recursion);
+        Exception caughtException = null;
+
+        // Try each name server.
+        for (int i = 0; i < servers.length; i++) {
+            try {
+                Tcp tcp = new Tcp(servers[i], serverPorts[i]);
+                byte[] msg;
+                try {
+                    msg = doTcpQuery(tcp, pkt);
+                    Header hdr = new Header(msg, msg.length);
+                    // Check only rcode as per
+                    // draft-ietf-dnsext-axfr-clarify-04
+                    checkResponseCode(hdr);
+                    ResourceRecords rrs =
+                        new ResourceRecords(msg, msg.length, hdr, true);
+                    if (rrs.getFirstAnsType() != ResourceRecord.TYPE_SOA) {
+                        throw new CommunicationException(
+                                "DNS error: zone xfer doesn't begin with SOA");
+                    }
+
+                    if (rrs.answer.size() == 1 ||
+                            rrs.getLastAnsType() != ResourceRecord.TYPE_SOA) {
+                        // The response is split into multiple DNS messages.
+                        do {
+                            msg = continueTcpQuery(tcp);
+                            if (msg == null) {
+                                throw new CommunicationException(
+                                        "DNS error: incomplete zone transfer");
+                            }
+                            hdr = new Header(msg, msg.length);
+                            checkResponseCode(hdr);
+                            rrs.add(msg, msg.length, hdr);
+                        } while (rrs.getLastAnsType() !=
+                                 ResourceRecord.TYPE_SOA);
+                    }
+
+                    // Delete the duplicate SOA record.
+                    rrs.answer.removeElementAt(rrs.answer.size() - 1);
+                    return rrs;
+
+                } finally {
+                    tcp.close();
+                }
+
+            } catch (IOException e) {
+                caughtException = e;
+            } catch (NameNotFoundException e) {
+                throw e;
+            } catch (NamingException e) {
+                caughtException = e;
+            }
+        }
+        if (caughtException instanceof NamingException) {
+            throw (NamingException) caughtException;
+        }
+        NamingException ne = new CommunicationException(
+                "DNS error during zone transfer");
+        ne.setRootCause(caughtException);
+        throw ne;
+    }
+
+
+    /**
+     * Tries to retreive an UDP packet matching the given xid
+     * received within the timeout.
+     * If a packet with different xid is received, the received packet
+     * is enqueued with the corresponding xid in 'resps'.
+     */
+    private byte[] doUdpQuery(Packet pkt, InetAddress server,
+                                     int port, int retry, int xid)
+            throws IOException, NamingException {
+
+        int minTimeout = 50; // msec after which there are no retries.
+
+        synchronized (udpSocket) {
+            DatagramPacket opkt = new DatagramPacket(
+                    pkt.getData(), pkt.length(), server, port);
+            DatagramPacket ipkt = new DatagramPacket(new byte[8000], 8000);
+            udpSocket.connect(server, port);
+            int pktTimeout = (timeout * (1 << retry));
+            try {
+                udpSocket.send(opkt);
+
+                // timeout remaining after successive 'receive()'
+                int timeoutLeft = pktTimeout;
+                int cnt = 0;
+                do {
+                    if (debug) {
+                       cnt++;
+                        dprint("Trying RECEIVE(" +
+                                cnt + ") retry(" + (retry + 1) +
+                                ") for:" + xid  + "    sock-timeout:" +
+                                timeoutLeft + " ms.");
+                    }
+                    udpSocket.setSoTimeout(timeoutLeft);
+                    long start = System.currentTimeMillis();
+                    udpSocket.receive(ipkt);
+                    long end = System.currentTimeMillis();
+
+                    byte[] data = new byte[ipkt.getLength()];
+                    data = ipkt.getData();
+                    if (isMatchResponse(data, xid)) {
+                        return data;
+                    }
+                    timeoutLeft = pktTimeout - ((int) (end - start));
+                } while (timeoutLeft > minTimeout);
+
+            } finally {
+                udpSocket.disconnect();
+            }
+            return null; // no matching packet received within the timeout
+        }
+    }
+
+    /*
+     * Sends a TCP query, and returns the first DNS message in the response.
+     */
+    private byte[] doTcpQuery(Tcp tcp, Packet pkt) throws IOException {
+
+        int len = pkt.length();
+        // Send 2-byte message length, then send message.
+        tcp.out.write(len >> 8);
+        tcp.out.write(len);
+        tcp.out.write(pkt.getData(), 0, len);
+        tcp.out.flush();
+
+        byte[] msg = continueTcpQuery(tcp);
+        if (msg == null) {
+            throw new IOException("DNS error: no response");
+        }
+        return msg;
+    }
+
+    /*
+     * Returns the next DNS message from the TCP socket, or null on EOF.
+     */
+    private byte[] continueTcpQuery(Tcp tcp) throws IOException {
+
+        int lenHi = tcp.in.read();      // high-order byte of response length
+        if (lenHi == -1) {
+            return null;        // EOF
+        }
+        int lenLo = tcp.in.read();      // low-order byte of response length
+        if (lenLo == -1) {
+            throw new IOException("Corrupted DNS response: bad length");
+        }
+        int len = (lenHi << 8) | lenLo;
+        byte[] msg = new byte[len];
+        int pos = 0;                    // next unfilled position in msg
+        while (len > 0) {
+            int n = tcp.in.read(msg, pos, len);
+            if (n == -1) {
+                throw new IOException(
+                        "Corrupted DNS response: too little data");
+            }
+            len -= n;
+            pos += n;
+        }
+        return msg;
+    }
+
+    private Packet makeQueryPacket(DnsName fqdn, int xid,
+                                   int qclass, int qtype, boolean recursion) {
+        int qnameLen = fqdn.getOctets();
+        int pktLen = DNS_HDR_SIZE + qnameLen + 4;
+        Packet pkt = new Packet(pktLen);
+
+        short flags = recursion ? Header.RD_BIT : 0;
+
+        pkt.putShort(xid, IDENT_OFFSET);
+        pkt.putShort(flags, FLAGS_OFFSET);
+        pkt.putShort(1, NUMQ_OFFSET);
+        pkt.putShort(0, NUMANS_OFFSET);
+        pkt.putInt(0, NUMAUTH_OFFSET);
+
+        makeQueryName(fqdn, pkt, DNS_HDR_SIZE);
+        pkt.putShort(qtype, DNS_HDR_SIZE + qnameLen);
+        pkt.putShort(qclass, DNS_HDR_SIZE + qnameLen + 2);
+
+        return pkt;
+    }
+
+    // Builds a query name in pkt according to the RFC spec.
+    private void makeQueryName(DnsName fqdn, Packet pkt, int off) {
+
+        // Loop through labels, least-significant first.
+        for (int i = fqdn.size() - 1; i >= 0; i--) {
+            String label = fqdn.get(i);
+            int len = label.length();
+
+            pkt.putByte(len, off++);
+            for (int j = 0; j < len; j++) {
+                pkt.putByte(label.charAt(j), off++);
+            }
+        }
+        if (!fqdn.hasRootLabel()) {
+            pkt.putByte(0, off);
+        }
+    }
+
+    //-------------------------------------------------------------------------
+
+    private byte[] lookupResponse(Integer xid) throws NamingException {
+        //
+        // Check the queued responses: some other thread in between
+        // received the response for this request.
+        //
+        if (debug) {
+            dprint("LOOKUP for: " + xid +
+                "\tResponse Q:" + resps);
+        }
+        byte[] pkt;
+        if ((pkt = (byte[]) resps.get(xid)) != null) {
+            synchronized (queuesLock) {
+                resps.remove(xid);
+                reqs.remove(xid);
+            }
+            checkResponseCode(new Header(pkt, pkt.length));
+
+            if (debug) {
+                dprint("FOUND (" + Thread.currentThread() +
+                    ") for:" + xid);
+            }
+        }
+        return pkt;
+    }
+
+    /*
+     * Checks the header of an incoming DNS response.
+     * Returns true if it matches the given xid and throws a naming
+     * exception, if appropriate, based on the response code.
+     */
+    private boolean isMatchResponse(byte[] pkt, int xid)
+                throws NamingException {
+
+        Header hdr = new Header(pkt, pkt.length);
+        if (hdr.query) {
+            throw new CommunicationException("DNS error: expecting response");
+        }
+
+        if (!reqs.contains(xid)) { // already received, ignore the response
+            return false;
+        }
+
+        // common case- the request sent matches the subsequent response read
+        if (hdr.xid == xid) {
+            if (debug) {
+                dprint("XID MATCH:" + xid);
+            }
+
+            // remove the response for the xid if received by some other thread.
+            synchronized (queuesLock) {
+                resps.remove(xid);
+                reqs.remove(xid);
+            }
+            checkResponseCode(hdr);
+            return true;
+        }
+
+        //
+        // xid mis-match: enqueue the response, it may belong to some other
+        // thread that has not yet had a chance to read its response.
+        // enqueue only the first response, responses for retries are ignored.
+        //
+        synchronized (queuesLock) {
+            if (reqs.contains(xid)) { // enqueue only the first response
+                resps.put(xid, pkt);
+            }
+        }
+
+        if (debug) {
+            dprint("NO-MATCH SEND ID:" +
+                                xid + " RECVD ID:" + hdr.xid +
+                                "    Response Q:" + resps +
+                                "    Reqs size:" + reqs.size());
+        }
+        return false;
+    }
+
+    /*
+     * Throws an exception if appropriate for the response code of a
+     * given header.
+     */
+    private void checkResponseCode(Header hdr) throws NamingException {
+
+        int rcode = hdr.rcode;
+        if (rcode == NO_ERROR) {
+            return;
+        }
+        String msg = (rcode < rcodeDescription.length)
+            ? rcodeDescription[rcode]
+            : "DNS error";
+        msg += " [response code " + rcode + "]";
+
+        switch (rcode) {
+        case SERVER_FAILURE:
+            throw new ServiceUnavailableException(msg);
+        case NAME_ERROR:
+            throw new NameNotFoundException(msg);
+        case NOT_IMPL:
+        case REFUSED:
+            throw new OperationNotSupportedException(msg);
+        case FORMAT_ERROR:
+        default:
+            throw new NamingException(msg);
+        }
+    }
+
+    //-------------------------------------------------------------------------
+
+    private static boolean debug = false;
+
+    public static void setDebug(boolean flag) {
+        debug = flag;
+    }
+
+    private static void dprint(String mess) {
+        if (debug) {
+            System.err.println("DNS: " + mess);
+        }
+    }
+
+}
+
+class Tcp {
+
+    private Socket sock;
+    java.io.InputStream in;
+    java.io.OutputStream out;
+
+    Tcp(InetAddress server, int port) throws IOException {
+        sock = new Socket(server, port);
+        sock.setTcpNoDelay(true);
+        out = new java.io.BufferedOutputStream(sock.getOutputStream());
+        in = new java.io.BufferedInputStream(sock.getInputStream());
+    }
+
+    void close() throws IOException {
+        sock.close();
+    }
+}
+
+/*
+ * javaos emulation -cj
+ */
+class Packet {
+        byte buf[];
+
+        Packet(int len) {
+                buf = new byte[len];
+        }
+
+        Packet(byte data[], int len) {
+                buf = new byte[len];
+                System.arraycopy(data, 0, buf, 0, len);
+        }
+
+        void putInt(int x, int off) {
+                buf[off + 0] = (byte)(x >> 24);
+                buf[off + 1] = (byte)(x >> 16);
+                buf[off + 2] = (byte)(x >> 8);
+                buf[off + 3] = (byte)x;
+        }
+
+        void putShort(int x, int off) {
+                buf[off + 0] = (byte)(x >> 8);
+                buf[off + 1] = (byte)x;
+        }
+
+        void putByte(int x, int off) {
+                buf[off] = (byte)x;
+        }
+
+        void putBytes(byte src[], int src_offset, int dst_offset, int len) {
+                System.arraycopy(src, src_offset, buf, dst_offset, len);
+        }
+
+        int length() {
+                return buf.length;
+        }
+
+        byte[] getData() {
+                return buf;
+        }
+}