src/jdk.dns.client/share/classes/jdk/dns/client/internal/DnsClient.java
author aefimov
Thu, 14 Nov 2019 23:16:40 +0000
branch aefimov-dns-client-branch
changeset 59101 258033faefc9
parent 58971 465a15dd6bed
permissions -rw-r--r--
aefimov-dns-client-branch: Fix timeouts, add platform specific file locations

/*
 * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

package jdk.dns.client.internal;

import jdk.dns.client.ex.DnsCommunicationException;
import jdk.dns.client.ex.DnsNameNotFoundException;
import jdk.dns.client.ex.DnsResolverException;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.PortUnreachableException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.nio.channels.DatagramChannel;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

import sun.net.util.IPAddressUtil;
import sun.security.jca.JCAUtil;

// Some of this code began life as part of sun.javaos.net.DnsClient
// originally by sritchie@eng 1/96.  It was first hacked up for JNDI
// use by caveh@eng 6/97.


/**
 * Performs DNS queries against a fixed list of name servers given as literal
 * IP addresses. Queries are sent over UDP, with a TCP re-query when a UDP
 * response is truncated.
 * <p>
 * Outstanding requests are recorded by transaction ID. A UDP response that
 * carries the ID of a different outstanding request is queued so that the
 * request it belongs to can pick it up later.
 */

public class DnsClient {

    // DNS packet header field offsets
    private static final int IDENT_OFFSET = 0;
    private static final int FLAGS_OFFSET = 2;
    private static final int NUMQ_OFFSET = 4;
    private static final int NUMANS_OFFSET = 6;
    private static final int NUMAUTH_OFFSET = 8;
    private static final int NUMADD_OFFSET = 10;
    private static final int DNS_HDR_SIZE = 12;

    // DNS response codes
    private static final int NO_ERROR = 0;
    private static final int FORMAT_ERROR = 1;
    private static final int SERVER_FAILURE = 2;
    private static final int NAME_ERROR = 3;
    private static final int NOT_IMPL = 4;
    private static final int REFUSED = 5;

    // Human-readable text for the response codes above, indexed by rcode
    private static final String[] rcodeDescription = {
            "No error",
            "DNS format error",
            "DNS server failure",
            "DNS name not found",
            "DNS operation not supported",
            "DNS service refused"
    };

    private static final int DEFAULT_PORT = 53;
    // Transaction IDs are 16-bit values: [0, 0x10000)
    private static final int TRANSACTION_ID_BOUND = 0x10000;
    // Capacity of the buffer that receives UDP responses
    private static final int UDP_RECEIVE_BUFFER_SIZE = 8000;
    // Stop waiting for a UDP response once less than this many ms remain
    private static final int MIN_UDP_TIMEOUT = 50;

    private final List<InetAddress> servers;
    private final List<Integer> serverPorts;  // parallel to 'servers'
    private final int timeout;                // initial timeout on UDP and TCP queries in ms
    private final int retries;                // number of UDP retries

    private static final SecureRandom random;

    static {
        var pa = (PrivilegedAction<SecureRandom>) () -> JCAUtil.getSecureRandom();
        random = System.getSecurityManager() == null ? pa.run()
                : AccessController.doPrivileged(pa);
    }

    private static final DnsDatagramChannelFactory factory =
            new DnsDatagramChannelFactory(random);

    // Requests sent, keyed by transaction ID
    private final Map<Integer, ResourceRecord> reqs;

    // Responses received for other outstanding requests, keyed by transaction ID
    private final Map<Integer, byte[]> resps;

    //-------------------------------------------------------------------------

    /*
     * Each server is of the form "server[:port]", where server is a literal
     * IPv4 or IPv6 address. IPv6 literals that carry a port must be enclosed
     * in brackets ("[::1]:53"). A bare IPv6 literal ("::1") is also accepted
     * and always uses the default port. Entries that are not IP literals are
     * ignored.
     * "timeout" is the initial timeout interval (in ms) for queries,
     * and "retries" gives the number of retries per server.
     */
    public DnsClient(List<String> servers, int timeout, int retries) {
        this.timeout = timeout;
        this.retries = retries;
        var serversList = new ArrayList<InetAddress>();
        var serverPortsList = new ArrayList<Integer>();

        if (DEBUG) {
            System.err.println("DNS Client: servers list:" + servers);
        }

        for (String serverString : servers) {

            // Is optional port given?
            int colon = portSeparatorIndex(serverString);

            int serverPort = (colon < 0) ? DEFAULT_PORT
                    : Integer.parseInt(serverString.substring(colon + 1));
            // IPAddressUtil rejects the brackets of an IPv6 literal.
            String server = stripBrackets((colon < 0)
                    ? serverString
                    : serverString.substring(0, colon));

            var pa = (PrivilegedAction<byte[]>) () -> {
                if (IPAddressUtil.isIPv4LiteralAddress(server)) {
                    return IPAddressUtil.textToNumericFormatV4(server);
                } else if (IPAddressUtil.isIPv6LiteralAddress(server)) {
                    return IPAddressUtil.textToNumericFormatV6(server);
                }
                return null;
            };
            byte[] addr = System.getSecurityManager() == null ?
                    pa.run() : AccessController.doPrivileged(pa);
            if (addr != null) {
                try {
                    serversList.add(InetAddress.getByAddress(server, addr));
                    serverPortsList.add(serverPort);
                } catch (UnknownHostException e) {
                    // Malformed IP address is specified - will ignore it
                }
            }
        }
        this.servers = Collections.unmodifiableList(serversList);
        this.serverPorts = Collections.unmodifiableList(serverPortsList);
        reqs = Collections.synchronizedMap(new HashMap<>());
        resps = Collections.synchronizedMap(new HashMap<>());
    }

    /*
     * Returns the index of the ':' separating host and port in a
     * "server[:port]" string, or -1 if no port is given. A port separator is
     * only searched for after the closing bracket of an IPv6 literal. An
     * unbracketed string with several colons is a bare IPv6 literal, which
     * cannot carry a port.
     */
    private static int portSeparatorIndex(String serverString) {
        int closingBracket = serverString.indexOf(']');
        int colon = serverString.indexOf(':', closingBracket + 1);
        if (closingBracket < 0 && colon >= 0
                && serverString.indexOf(':', colon + 1) >= 0) {
            return -1;
        }
        return colon;
    }

    /*
     * Removes the delimiting brackets of an IPv6 literal ("[::1]" -> "::1").
     * Any other string is returned unchanged.
     */
    private static String stripBrackets(String host) {
        if (host.length() >= 2 && host.charAt(0) == '['
                && host.charAt(host.length() - 1) == ']') {
            return host.substring(1, host.length() - 1);
        }
        return host;
    }

    /*
     * Opens a new datagram channel via the shared factory, wrapping a
     * SocketException in a DnsResolverException.
     */
    DatagramChannel getDatagramChannel() throws DnsResolverException {
        try {
            return factory.open();
        } catch (java.net.SocketException e) {
            throw new DnsResolverException("Can't create datagram channel", e);
        }
    }

    @SuppressWarnings("deprecation")
    protected void finalize() {
        close();
    }

    // A lock to access the request and response queues in tandem.
    private final ReentrantLock queuesLock = new ReentrantLock();

    /*
     * Discards all outstanding requests and queued responses.
     */
    public void close() {
        queuesLock.lock();
        try {
            reqs.clear();
            resps.clear();
        } finally {
            queuesLock.unlock();
        }
    }

    /*
     * Sends a query and returns the records of the response.
     *
     * UDP is tried first. Each retry round walks the server list in order,
     * with a per-attempt timeout of timeout * 2^retry ms. A truncated UDP
     * response triggers a TCP re-query, starting with the server that sent
     * the truncated message. If every TCP attempt fails, the truncated UDP
     * response is used.
     *
     * If recursion is true, recursion is requested on the query.
     * If auth is true, only authoritative responses are accepted. A server
     * that returns a non-authoritative response is not asked again.
     *
     * Throws DnsResolverException if no server produced an acceptable
     * response. An error encountered along the way is thrown directly, or
     * attached as the cause.
     */
    ResourceRecords query(DnsName fqdn, int qclass, int qtype,
                          boolean recursion, boolean auth)
            throws DnsResolverException {

        int xid;
        Packet pkt;
        ResourceRecord collision;

        do {
            // Generate a random transaction ID that is not already in use
            xid = random.nextInt(TRANSACTION_ID_BOUND);
            pkt = makeQueryPacket(fqdn, xid, qclass, qtype, recursion);

            // enqueue the outstanding request
            collision = reqs.putIfAbsent(xid, new ResourceRecord(pkt.getData(),
                    pkt.length(), Header.HEADER_SIZE, true, false));

        } while (collision != null);

        Exception caughtException = null;
        boolean[] doNotRetry = new boolean[servers.size()];

        try {
            //
            // The UDP retry strategy is to try the 1st server, and then
            // each server in order. If no answer, double the timeout
            // and try each server again.
            //
            for (int retry = 0; retry < retries; retry++) {

                // Try each name server.
                for (int i = 0; i < servers.size(); i++) {
                    if (doNotRetry[i]) {
                        continue;
                    }

                    // send the request packet and wait for a response.
                    try {
                        if (DEBUG) {
                            dprint("SEND ID (" + (retry + 1) + "): " + xid);
                        }

                        byte[] msg = doUdpQuery(pkt, servers.get(i), serverPorts.get(i), retry, xid);
                        //
                        // If the matching response is not got within the
                        // given timeout, check if the response was enqueued
                        // by some other thread, if not proceed with the next
                        // server or retry.
                        //
                        if (msg == null) {
                            if (resps.size() > 0) {
                                msg = lookupResponse(xid);
                            }
                            if (msg == null) { // try next server or retry
                                continue;
                            }
                        }
                        Header hdr = new Header(msg, msg.length);

                        if (auth && !hdr.authoritative) {
                            caughtException = new DnsResolverException("DNS response not authoritative");
                            doNotRetry[i] = true;
                            continue;
                        }
                        if (hdr.truncated) {  // message is truncated -- try TCP

                            // Try each server, starting with the one that just
                            // provided the truncated message.
                            int retryTimeout = (timeout * (1 << retry));
                            for (int j = 0; j < servers.size(); j++) {
                                int ij = (i + j) % servers.size();
                                if (doNotRetry[ij]) {
                                    continue;
                                }
                                try {
                                    Tcp tcp =
                                            new Tcp(servers.get(ij), serverPorts.get(ij), retryTimeout);
                                    byte[] msg2;
                                    try {
                                        msg2 = doTcpQuery(tcp, pkt);
                                    } finally {
                                        tcp.close();
                                    }
                                    Header hdr2 = new Header(msg2, msg2.length);
                                    if (hdr2.query) {
                                        throw new DnsResolverException(
                                                "DNS error: expecting response");
                                    }
                                    checkResponseCode(hdr2);

                                    if (!auth || hdr2.authoritative) {
                                        // Got a valid response
                                        hdr = hdr2;
                                        msg = msg2;
                                        break;
                                    } else {
                                        doNotRetry[ij] = true;
                                    }
                                } catch (Exception e) {
                                    // Best effort: try next server, or fall
                                    // back to the truncated UDP response
                                }
                            } // servers
                        }
                        return new ResourceRecords(msg, msg.length, hdr, false);

                    } catch (PortUnreachableException e) {
                        if (caughtException == null) {
                            caughtException = e;
                        }
                        doNotRetry[i] = true;
                    } catch (IOException e) {
                        if (DEBUG) {
                            dprint("Caught IOException:" + e);
                        }
                        if (caughtException == null) {
                            caughtException = e;
                        }
                    } catch (DnsNameNotFoundException e) {
                        // This is authoritative, so return immediately.
                        // NOTE(review): checkResponseCode in this class reports
                        // NAME_ERROR as a plain DnsResolverException; confirm
                        // whether another callee throws this subtype.
                        throw e;
                    } catch (DnsCommunicationException e) {
                        if (caughtException == null) {
                            caughtException = e;
                        }
                    } catch (DnsResolverException e) {
                        if (caughtException == null) {
                            caughtException = e;
                        }
                        doNotRetry[i] = true;
                    }
                } // servers
            } // retries

        } finally {
            reqs.remove(xid); // cleanup
        }

        if (caughtException instanceof DnsResolverException) {
            throw (DnsResolverException) caughtException;
        }
        // A network timeout or other error occurred.
        throw new DnsResolverException("DNS error", caughtException);
    }

    /**
     * Tries to retrieve a UDP packet matching the given xid received within
     * {@code timeout * 2^retry} ms.
     * <p>
     * That time budget is shared by every receive in this attempt, so
     * non-matching packets cannot extend it. If a packet with a different
     * xid is received, it is enqueued with the corresponding xid in 'resps'
     * by {@link #isMatchResponse}.
     *
     * @return the matching response, trimmed to the received length, or
     *         {@code null} if the budget ran out after non-matching packets
     * @throws java.net.SocketTimeoutException if a single receive times out
     */
    private byte[] doUdpQuery(Packet pkt, InetAddress server,
                              int port, int retry, int xid)
            throws IOException, DnsResolverException {

        try (DatagramChannel dc = getDatagramChannel()) {
            DatagramPacket opkt = new DatagramPacket(pkt.getData(), pkt.length(), server, port);
            byte[] ibuf = new byte[UDP_RECEIVE_BUFFER_SIZE];
            DatagramPacket ipkt = new DatagramPacket(ibuf, ibuf.length);
            // Packets may only be sent to or received from this server address
            // TODO: Revisit
            var pa = (PrivilegedAction<Void>) () -> {
                dc.socket().connect(server, port);
                return null;
            };
            if (System.getSecurityManager() == null) {
                pa.run();
            } else {
                AccessController.doPrivileged(pa);
            }

            int pktTimeout = (timeout * (1 << retry));
            try {
                dc.socket().send(opkt);

                // timeout remaining after successive 'receive()'
                int timeoutLeft = pktTimeout;
                int cnt = 0;
                do {
                    if (DEBUG) {
                        cnt++;
                        dprint("Trying RECEIVE(" +
                                cnt + ") retry(" + (retry + 1) +
                                ") for:" + xid + "    sock-timeout:" +
                                timeoutLeft + " ms.");
                    }
                    dc.socket().setSoTimeout(timeoutLeft);
                    // receive() may shrink the packet length to the previous
                    // datagram's size; restore full capacity before reusing it.
                    ipkt.setLength(ibuf.length);
                    long start = System.nanoTime();

                    dc.socket().receive(ipkt);
                    long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
                    byte[] data = ipkt.getData();
                    int length = ipkt.getLength();

                    if (isMatchResponse(data, length, xid)) {
                        // Hand back only the bytes actually received, so that
                        // msg.length is the real message length for callers.
                        return Arrays.copyOf(data, length);
                    }
                    // Charge this receive against what is left, not against
                    // the full budget.
                    timeoutLeft -= (int) elapsedMillis;
                } while (timeoutLeft > MIN_UDP_TIMEOUT);

            } finally {
                dc.disconnect();
            }
            return null; // no matching packet received within the timeout
        }
    }

    /*
     * Sends a TCP query, and returns the first DNS message in the response.
     * Throws IOException if the server closes the connection without replying.
     */
    private byte[] doTcpQuery(Tcp tcp, Packet pkt) throws IOException {

        int len = pkt.length();
        // Send 2-byte message length, then send message.
        tcp.out.write(len >> 8);
        tcp.out.write(len);
        tcp.out.write(pkt.getData(), 0, len);
        tcp.out.flush();

        byte[] msg = continueTcpQuery(tcp);
        if (msg == null) {
            throw new IOException("DNS error: no response");
        }
        return msg;
    }

    /*
     * Returns the next DNS message from the TCP socket, or null on EOF.
     * Each message is preceded by its length as a 2-byte big-endian value.
     */
    private byte[] continueTcpQuery(Tcp tcp) throws IOException {

        int lenHi = tcp.read();      // high-order byte of response length
        if (lenHi == -1) {
            return null;        // EOF
        }
        int lenLo = tcp.read();      // low-order byte of response length
        if (lenLo == -1) {
            throw new IOException("Corrupted DNS response: bad length");
        }
        int len = (lenHi << 8) | lenLo;
        byte[] msg = new byte[len];
        int pos = 0;                    // next unfilled position in msg
        while (len > 0) {
            int n = tcp.read(msg, pos, len);
            if (n == -1) {
                throw new IOException(
                        "Corrupted DNS response: too little data");
            }
            len -= n;
            pos += n;
        }
        return msg;
    }

    /*
     * Builds a query message with a single question for (fqdn, qtype, qclass).
     * The header sets only the ID, the optional RD flag, and QDCOUNT = 1.
     */
    private Packet makeQueryPacket(DnsName fqdn, int xid,
                                   int qclass, int qtype, boolean recursion) {
        int qnameLen = fqdn.getOctets();
        int pktLen = DNS_HDR_SIZE + qnameLen + 4;  // + QTYPE + QCLASS
        Packet pkt = new Packet(pktLen);

        short flags = recursion ? Header.RD_BIT : 0;
        // flags = (short) (flags | Header.CD_BIT | Header.AD_BIT);

        pkt.putShort(xid, IDENT_OFFSET);
        pkt.putShort(flags, FLAGS_OFFSET);
        pkt.putShort(1, NUMQ_OFFSET);
        pkt.putShort(0, NUMANS_OFFSET);
        pkt.putInt(0, NUMAUTH_OFFSET);  // clears NUMAUTH and NUMADD together

        makeQueryName(fqdn, pkt, DNS_HDR_SIZE);
        pkt.putShort(qtype, DNS_HDR_SIZE + qnameLen);
        pkt.putShort(qclass, DNS_HDR_SIZE + qnameLen + 2);

        return pkt;
    }

    // Builds a query name in pkt according to the RFC spec.
    // NOTE(review): each label char is written as one byte, which assumes
    // labels are ASCII (e.g. already IDN-encoded) - confirm in DnsName.
    private void makeQueryName(DnsName fqdn, Packet pkt, int off) {

        // Loop through labels, least-significant first.
        for (int i = fqdn.size() - 1; i >= 0; i--) {
            String label = fqdn.get(i);
            int len = label.length();

            pkt.putByte(len, off++);
            for (int j = 0; j < len; j++) {
                pkt.putByte(label.charAt(j), off++);
            }
        }
        if (!fqdn.hasRootLabel()) {
            pkt.putByte(0, off);
        }
    }

    //-------------------------------------------------------------------------

    /*
     * Returns the response for xid if one was queued by isMatchResponse,
     * otherwise null. A returned response is dequeued together with its
     * request. This happens before its response code is checked, so an
     * error response does not linger in 'resps'.
     */
    private byte[] lookupResponse(Integer xid) throws DnsResolverException {
        //
        // Check the queued responses: some other thread in between
        // received the response for this request.
        //
        if (DEBUG) {
            dprint("LOOKUP for: " + xid +
                    "\tResponse Q:" + resps);
        }
        byte[] pkt;
        queuesLock.lock();
        try {
            pkt = resps.remove(xid);
            if (pkt != null) {
                reqs.remove(xid);
            }
        } finally {
            queuesLock.unlock();
        }

        if (pkt != null) {
            checkResponseCode(new Header(pkt, pkt.length));

            if (DEBUG) {
                dprint("FOUND (" + Thread.currentThread() +
                        ") for:" + xid);
            }
        }
        return pkt;
    }

    /*
     * Checks the header of an incoming DNS response.
     * Returns true if it matches the given xid, and throws a
     * DnsResolverException based on the response code if appropriate.
     *
     * Also checks that the domain name, type and class in the response
     * match those in the original query.
     *
     * A response for a different, still outstanding xid is copied into
     * 'resps' (first response only) and false is returned.
     */
    private boolean isMatchResponse(byte[] pkt, int length, int xid)
            throws DnsResolverException {

        Header hdr = new Header(pkt, length);
        if (hdr.query) {
            throw new DnsResolverException("DNS error: expecting response");
        }

        if (!reqs.containsKey(xid)) { // already received, ignore the response
            return false;
        }

        // common case- the request sent matches the subsequent response read
        if (hdr.xid == xid) {
            if (DEBUG) {
                dprint("XID MATCH:" + xid);
            }
            checkResponseCode(hdr);
            if (!hdr.query && hdr.numQuestions == 1) {

                ResourceRecord rr = new ResourceRecord(pkt, length,
                        Header.HEADER_SIZE, true, false);

                // Retrieve the original query; it may have been discarded
                // concurrently by close().
                ResourceRecord query = reqs.get(xid);
                if (query == null) {
                    return false;
                }
                int qtype = query.getType();
                int qclass = query.getRrclass();
                DnsName qname = query.getName();

                // Check that the type/class/name in the query section of the
                // response match those in the original query
                if ((qtype == ResourceRecord.TYPE_ANY ||
                        qtype == rr.getType()) &&
                        (qclass == ResourceRecord.QCLASS_STAR ||
                                qclass == rr.getRrclass()) &&
                        qname.equals(rr.getName())) {

                    if (DEBUG) {
                        dprint("MATCH NAME:" + qname + " QTYPE:" + qtype +
                                " QCLASS:" + qclass);
                    }

                    // Remove the response for the xid if received by some other
                    // thread.
                    queuesLock.lock();
                    try {
                        resps.remove(xid);
                        reqs.remove(xid);
                    } finally {
                        queuesLock.unlock();
                    }
                    return true;

                } else {
                    if (DEBUG) {
                        dprint("NO-MATCH NAME:" + qname + " QTYPE:" + qtype +
                                " QCLASS:" + qclass);
                    }
                }
            }
            return false;
        }

        //
        // xid mis-match: enqueue the response, it may belong to some other
        // thread that has not yet had a chance to read its response.
        // enqueue only the first response, responses for retries are ignored.
        //
        queuesLock.lock();
        try {
            if (reqs.containsKey(hdr.xid)) { // enqueue only the first response
                resps.put(hdr.xid, Arrays.copyOf(pkt, length));
            }
        } finally {
            queuesLock.unlock();
        }

        if (DEBUG) {
            dprint("NO-MATCH SEND ID:" +
                    xid + " RECVD ID:" + hdr.xid +
                    "    Response Q:" + resps +
                    "    Reqs size:" + reqs.size());
        }
        return false;
    }

    /*
     * Throws a DnsResolverException describing the response code of the
     * given header, unless the code is NO_ERROR.
     */
    private void checkResponseCode(Header hdr) throws DnsResolverException {

        int rcode = hdr.rcode;
        if (rcode == NO_ERROR) {
            return;
        }
        String msg = (rcode >= 0 && rcode < rcodeDescription.length)
                ? rcodeDescription[rcode]
                : "DNS error";

        msg += " [response code " + rcode + "]";
        throw new DnsResolverException(msg);
    }

    //-------------------------------------------------------------------------

    // Debug tracing, enabled with -Djdk.dns.client.debug=true
    private static final boolean DEBUG = java.security.AccessController.doPrivileged(
            (PrivilegedAction<Boolean>) () -> Boolean.getBoolean("jdk.dns.client.debug"));

    private static void dprint(String mess) {
        if (DEBUG) {
            System.err.println("DNS: " + mess);
        }
    }

}

/*
 * A TCP connection to a DNS server. Connecting and all subsequent reads share
 * one time budget: the timeout given to the constructor. Every read is bounded
 * by whatever part of that budget is left, and a read attempted after the
 * budget is exhausted throws SocketTimeoutException. Elapsed time is measured
 * with the monotonic System.nanoTime(), so wall-clock adjustments cannot
 * shrink or extend the budget.
 */
class Tcp {

    private final Socket sock;
    private final InputStream in;
    final OutputStream out;
    // Remaining time budget in milliseconds
    private int timeoutLeft;

    /*
     * Connects to server:port within 'timeout' ms.
     * Throws SocketTimeoutException if no budget remains after connecting.
     * The socket is closed if construction fails.
     */
    Tcp(InetAddress server, int port, int timeout) throws IOException {
        sock = new Socket();
        try {
            long start = System.nanoTime();
            sock.connect(new InetSocketAddress(server, port), timeout);
            timeoutLeft = (int) (timeout - elapsedMillis(start));
            if (timeoutLeft <= 0) {
                throw new SocketTimeoutException();
            }

            sock.setTcpNoDelay(true);
            out = new BufferedOutputStream(sock.getOutputStream());
            in = new BufferedInputStream(sock.getInputStream());
        } catch (Exception e) {
            try {
                sock.close();
            } catch (IOException ex) {
                e.addSuppressed(ex);
            }
            throw e;
        }
    }

    // Milliseconds elapsed since the given System.nanoTime() reading
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    void close() throws IOException {
        sock.close();
    }

    private interface SocketReadOp {
        int read() throws IOException;
    }

    /*
     * Runs one read with the socket timeout set to the remaining budget, then
     * charges the time spent against the budget. Once the budget is exhausted,
     * throws SocketTimeoutException: a socket timeout of 0 would mean
     * "wait forever".
     */
    private int readWithTimeout(SocketReadOp reader) throws IOException {
        if (timeoutLeft <= 0) {
            throw new SocketTimeoutException();
        }

        sock.setSoTimeout(timeoutLeft);
        long start = System.nanoTime();
        try {
            return reader.read();
        } finally {
            timeoutLeft -= (int) elapsedMillis(start);
        }
    }

    // Reads one byte (0-255), or returns -1 on end of stream
    int read() throws IOException {
        return readWithTimeout(in::read);
    }

    // Reads up to len bytes into b[off..]; returns the count, or -1 on end of stream
    int read(byte[] b, int off, int len) throws IOException {
        return readWithTimeout(() -> in.read(b, off, len));
    }
}

/*
 * Fixed-size byte buffer used to assemble outgoing DNS messages. Multi-byte
 * values are written big-endian (network byte order), and each value is
 * truncated to the width being written. (Originally a javaos emulation
 * class.)
 */
class Packet {
    byte[] buf;

    // Creates a zero-filled packet of 'len' bytes.
    Packet(int len) {
        buf = new byte[len];
    }

    // Creates a packet holding a copy of the first 'len' bytes of 'data'.
    Packet(byte[] data, int len) {
        this(len);
        System.arraycopy(data, 0, buf, 0, len);
    }

    // Writes the 32-bit value 'x' at buf[off..off+3], most significant byte first.
    void putInt(int x, int off) {
        int pos = off;
        for (int shift = 24; shift >= 0; shift -= 8) {
            buf[pos++] = (byte) (x >>> shift);
        }
    }

    // Writes the low 16 bits of 'x' at buf[off..off+1], most significant byte first.
    void putShort(int x, int off) {
        putByte(x >>> 8, off);
        putByte(x, off + 1);
    }

    // Writes the low 8 bits of 'x' at buf[off].
    void putByte(int x, int off) {
        buf[off] = (byte) x;
    }

    // Copies 'len' bytes of 'src', starting at srcOffset, into the packet at dstOffset.
    void putBytes(byte[] src, int srcOffset, int dstOffset, int len) {
        System.arraycopy(src, srcOffset, buf, dstOffset, len);
    }

    // Size of the packet in bytes (the full buffer length).
    int length() {
        return buf.length;
    }

    // The backing array itself, not a copy.
    byte[] getData() {
        return buf;
    }
}