org.echocat.jomon.net.dns.DnsServer.java Source code

Java tutorial

Introduction

Here is the source code for org.echocat.jomon.net.dns.DnsServer.java

Source

/*****************************************************************************************
 * *** BEGIN LICENSE BLOCK *****
 *
 * Version: MPL 2.0
 *
 * echocat Jomon, Copyright (c) 2012-2013 echocat
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * *** END LICENSE BLOCK *****
 ****************************************************************************************/

package org.echocat.jomon.net.dns;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.*;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.*;
import java.net.*;
import java.util.*;

import static java.lang.Thread.currentThread;
import static org.apache.commons.io.IOUtils.closeQuietly;

/**
 * @author Brian Wellington <bwelling@xbill.org>
 */

@SuppressWarnings({ "ContinueStatement", "OverlyLongMethod", "MethodWithMultipleReturnPoints",
        "AssignmentToMethodParameter", "InfiniteLoopStatement", "StatementWithEmptyBody" })
public class DnsServer implements AutoCloseable {

    private static final Logger LOG = LoggerFactory.getLogger(DnsServer.class);

    private volatile boolean _closed;
    private final Set<Thread> _threads = new HashSet<>();
    private final Set<Closeable> _closeables = new HashSet<>();

    static final int FLAG_DNSSECOK = 1;
    static final int FLAG_SIGONLY = 2;

    private final Map<Integer, Cache> _caches;
    private final Map<Name, Zone> _znames;
    private final Map<Name, TSIG> _tsigs;

    private static String addrport(InetAddress addr, int port) {
        return addr.getHostAddress() + "#" + port;
    }

    public DnsServer() {
        this(null);
    }

    public DnsServer(@Nullable String config) {
        final BufferedReader reader = new BufferedReader(new StringReader(config));
        _caches = new HashMap<>();
        _znames = new HashMap<>();
        _tsigs = new HashMap<>();
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                final StringTokenizer st = new StringTokenizer(line);
                if (!st.hasMoreTokens()) {
                    continue;
                }
                final String keyword = st.nextToken();
                if (!st.hasMoreTokens()) {
                    throw new IllegalArgumentException("Invalid line: " + line);
                }
                if (keyword.charAt(0) == '#') {
                    continue;
                }
                if (keyword.equals("primary")) {
                    addPrimaryZone(st.nextToken(), st.nextToken());
                } else if (keyword.equals("secondary")) {
                    addSecondaryZone(st.nextToken(), st.nextToken());
                } else if (keyword.equals("cache")) {
                    final Cache cache = new Cache(st.nextToken());
                    _caches.put(DClass.IN, cache);
                } else if (keyword.equals("key")) {
                    final String s1 = st.nextToken();
                    final String s2 = st.nextToken();
                    if (st.hasMoreTokens()) {
                        addTSIG(s1, s2, st.nextToken());
                    } else {
                        addTSIG("hmac-md5", s1, s2);
                    }
                } else {
                    throw new IllegalArgumentException("unknown keyword: " + keyword);
                }
            }
        } catch (IOException | ZoneTransferException e) {
            throw new RuntimeException("Could not create server.", e);
        }
        LOG.info("jnamed: running");
    }

    public void addPrimaryZone(String zname, String zonefile) throws IOException {
        Name origin = null;
        if (zname != null) {
            origin = Name.fromString(zname, Name.root);
        }
        final Zone newzone = new Zone(origin, zonefile);
        addPrimaryZone(newzone);
    }

    public void addPrimaryZone(@Nonnull Zone zone) throws IOException {
        _znames.put(zone.getOrigin(), zone);
    }

    public void addSecondaryZone(String zone, String remote) throws IOException, ZoneTransferException {
        final Name zname = Name.fromString(zone, Name.root);
        final Zone newzone = new Zone(zname, DClass.IN, remote);
        _znames.put(zname, newzone);
    }

    public void addSecondaryZone(@Nonnull Zone zone) {
        _znames.put(zone.getOrigin(), zone);
    }

    public void addTSIG(String algstr, String namestr, String key) throws IOException {
        final Name name = Name.fromString(namestr, Name.root);
        _tsigs.put(name, new TSIG(algstr, namestr, key));
    }

    public Cache getCache(int dclass) {
        Cache c = _caches.get(dclass);
        if (c == null) {
            c = new Cache(dclass);
            _caches.put(dclass, c);
        }
        return c;
    }

    public Zone findBestZone(Name name) {
        Zone foundzone;
        foundzone = _znames.get(name);
        if (foundzone != null) {
            return foundzone;
        }
        final int labels = name.labels();
        for (int i = 1; i < labels; i++) {
            final Name tname = new Name(name, i);
            foundzone = _znames.get(tname);
            if (foundzone != null) {
                return foundzone;
            }
        }
        return null;
    }

    public RRset findExactMatch(Name name, int type, int dclass, boolean glue) {
        final Zone zone = findBestZone(name);
        if (zone != null) {
            return zone.findExactMatch(name, type);
        } else {
            final RRset[] rrsets;
            final Cache cache = getCache(dclass);
            if (glue) {
                rrsets = cache.findAnyRecords(name, type);
            } else {
                rrsets = cache.findRecords(name, type);
            }
            if (rrsets == null) {
                return null;
            } else {
                return rrsets[0]; /* not quite right */
            }
        }
    }

    void addRRset(Name name, Message response, RRset rrset, int section, int flags) {
        for (int s = 1; s <= section; s++) {
            if (response.findRRset(name, rrset.getType(), s)) {
                return;
            }
        }
        if ((flags & FLAG_SIGONLY) == 0) {
            final Iterator<?> it = rrset.rrs();
            while (it.hasNext()) {
                Record r = (Record) it.next();
                if (r.getName().isWild() && !name.isWild()) {
                    r = r.withName(name);
                }
                response.addRecord(r, section);
            }
        }
        if ((flags & (FLAG_SIGONLY | FLAG_DNSSECOK)) != 0) {
            final Iterator<?> it = rrset.sigs();
            while (it.hasNext()) {
                Record r = (Record) it.next();
                if (r.getName().isWild() && !name.isWild()) {
                    r = r.withName(name);
                }
                response.addRecord(r, section);
            }
        }
    }

    private void addSOA(Message response, Zone zone) {
        response.addRecord(zone.getSOA(), Section.AUTHORITY);
    }

    private void addNS(Message response, Zone zone, int flags) {
        final RRset nsRecords = zone.getNS();
        addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
    }

    private void addCacheNS(Message response, Cache cache, Name name) {
        final SetResponse sr = cache.lookupRecords(name, Type.NS, Credibility.HINT);
        if (!sr.isDelegation()) {
            return;
        }
        final RRset nsRecords = sr.getNS();
        final Iterator<?> it = nsRecords.rrs();
        while (it.hasNext()) {
            final Record r = (Record) it.next();
            response.addRecord(r, Section.AUTHORITY);
        }
    }

    private void addGlue(Message response, Name name, int flags) {
        final RRset a = findExactMatch(name, Type.A, DClass.IN, true);
        if (a == null) {
            return;
        }
        addRRset(name, response, a, Section.ADDITIONAL, flags);
    }

    private void addAdditional2(Message response, int section, int flags) {
        final Record[] records = response.getSectionArray(section);
        for (int i = 0; i < records.length; i++) {
            final Record r = records[i];
            final Name glueName = r.getAdditionalName();
            if (glueName != null) {
                addGlue(response, glueName, flags);
            }
        }
    }

    private void addAdditional(Message response, int flags) {
        addAdditional2(response, Section.ANSWER, flags);
        addAdditional2(response, Section.AUTHORITY, flags);
    }

    byte addAnswer(Message response, Name name, int type, int dclass, int iterations, int flags) {
        final SetResponse sr;
        byte rcode = Rcode.NOERROR;

        if (iterations > 6) {
            return Rcode.NOERROR;
        }

        if (type == Type.SIG || type == Type.RRSIG) {
            type = Type.ANY;
            flags |= FLAG_SIGONLY;
        }

        final Zone zone = findBestZone(name);
        if (zone != null) {
            sr = zone.findRecords(name, type);
        } else {
            final Cache cache = getCache(dclass);
            sr = cache.lookupRecords(name, type, Credibility.NORMAL);
        }

        if (sr.isUnknown()) {
            addCacheNS(response, getCache(dclass), name);
        }
        if (sr.isNXDOMAIN()) {
            response.getHeader().setRcode(Rcode.NXDOMAIN);
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            }
            rcode = Rcode.NXDOMAIN;
        } else if (sr.isNXRRSET()) {
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            }
        } else if (sr.isDelegation()) {
            final RRset nsRecords = sr.getNS();
            addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
        } else if (sr.isCNAME()) {
            final CNAMERecord cname = sr.getCNAME();
            final RRset rrset = new RRset(cname);
            addRRset(name, response, rrset, Section.ANSWER, flags);
            if (zone != null && iterations == 0) {
                response.getHeader().setFlag(Flags.AA);
            }
            rcode = addAnswer(response, cname.getTarget(), type, dclass, iterations + 1, flags);
        } else if (sr.isDNAME()) {
            final DNAMERecord dname = sr.getDNAME();
            RRset rrset = new RRset(dname);
            addRRset(name, response, rrset, Section.ANSWER, flags);
            final Name newname;
            try {
                newname = name.fromDNAME(dname);
            } catch (final NameTooLongException ignored) {
                return Rcode.YXDOMAIN;
            }
            rrset = new RRset(new CNAMERecord(name, dclass, 0, newname));
            addRRset(name, response, rrset, Section.ANSWER, flags);
            if (zone != null && iterations == 0) {
                response.getHeader().setFlag(Flags.AA);
            }
            rcode = addAnswer(response, newname, type, dclass, iterations + 1, flags);
        } else if (sr.isSuccessful()) {
            final RRset[] rrsets = sr.answers();
            for (int i = 0; i < rrsets.length; i++) {
                addRRset(name, response, rrsets[i], Section.ANSWER, flags);
            }
            if (zone != null) {
                addNS(response, zone, flags);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            } else {
                addCacheNS(response, getCache(dclass), name);
            }
        }
        return rcode;
    }

    byte[] doAXFR(Name name, Message query, TSIG tsig, TSIGRecord qtsig, Socket s) {
        final Zone zone = _znames.get(name);
        boolean first = true;
        if (zone == null) {
            return errorMessage(query, Rcode.REFUSED);
        }
        final Iterator<?> it = zone.AXFR();
        try {
            final DataOutputStream dataOut;
            dataOut = new DataOutputStream(s.getOutputStream());
            final int id = query.getHeader().getID();
            while (it.hasNext()) {
                final RRset rrset = (RRset) it.next();
                final Message response = new Message(id);
                final Header header = response.getHeader();
                header.setFlag(Flags.QR);
                header.setFlag(Flags.AA);
                addRRset(rrset.getName(), response, rrset, Section.ANSWER, FLAG_DNSSECOK);
                if (tsig != null) {
                    tsig.applyStream(response, qtsig, first);
                    qtsig = response.getTSIG();
                }
                first = false;
                final byte[] out = response.toWire();
                dataOut.writeShort(out.length);
                dataOut.write(out);
            }
        } catch (final IOException ignored) {
            LOG.info("AXFR failed");
        }
        closeQuietly(s);
        return null;
    }

    /*
    * Note: a null return value means that the caller doesn't need to do
    * anything.  Currently this only happens if this is an AXFR request over
    * TCP.
    */
    byte[] generateReply(Message query, byte[] in, int length, Socket s) throws IOException {
        final Header header;
        final int maxLength;
        int flags = 0;

        header = query.getHeader();
        if (header.getFlag(Flags.QR)) {
            return null;
        }
        if (header.getRcode() != Rcode.NOERROR) {
            return errorMessage(query, Rcode.FORMERR);
        }
        if (header.getOpcode() != Opcode.QUERY) {
            return errorMessage(query, Rcode.NOTIMP);
        }

        final Record queryRecord = query.getQuestion();

        final TSIGRecord queryTSIG = query.getTSIG();
        TSIG tsig = null;
        if (queryTSIG != null) {
            tsig = _tsigs.get(queryTSIG.getName());
            if (tsig == null || tsig.verify(query, in, length, null) != Rcode.NOERROR) {
                return formerrMessage(in);
            }
        }

        final OPTRecord queryOPT = query.getOPT();
        if (queryOPT != null && queryOPT.getVersion() > 0) {
        }

        if (s != null) {
            maxLength = 65535;
        } else if (queryOPT != null) {
            maxLength = Math.max(queryOPT.getPayloadSize(), 512);
        } else {
            maxLength = 512;
        }

        if (queryOPT != null && (queryOPT.getFlags() & ExtendedFlags.DO) != 0) {
            flags = FLAG_DNSSECOK;
        }

        final Message response = new Message(query.getHeader().getID());
        response.getHeader().setFlag(Flags.QR);
        if (query.getHeader().getFlag(Flags.RD)) {
            response.getHeader().setFlag(Flags.RD);
        }
        response.addRecord(queryRecord, Section.QUESTION);

        final Name name = queryRecord.getName();
        final int type = queryRecord.getType();
        final int dclass = queryRecord.getDClass();
        if (type == Type.AXFR && s != null) {
            return doAXFR(name, query, tsig, queryTSIG, s);
        }
        if (!Type.isRR(type) && type != Type.ANY) {
            return errorMessage(query, Rcode.NOTIMP);
        }

        final byte rcode = addAnswer(response, name, type, dclass, 0, flags);
        if (rcode != Rcode.NOERROR && rcode != Rcode.NXDOMAIN) {
            return errorMessage(query, rcode);
        }

        addAdditional(response, flags);

        if (queryOPT != null) {
            final int optflags = (flags == FLAG_DNSSECOK) ? ExtendedFlags.DO : 0;
            final OPTRecord opt = new OPTRecord((short) 4096, rcode, (byte) 0, optflags);
            response.addRecord(opt, Section.ADDITIONAL);
        }

        response.setTSIG(tsig, Rcode.NOERROR, queryTSIG);
        return response.toWire(maxLength);
    }

    byte[] buildErrorMessage(Header header, int rcode, Record question) {
        final Message response = new Message();
        response.setHeader(header);
        for (int i = 0; i < 4; i++) {
            response.removeAllRecords(i);
        }
        if (rcode == Rcode.SERVFAIL) {
            response.addRecord(question, Section.QUESTION);
        }
        header.setRcode(rcode);
        return response.toWire();
    }

    public byte[] formerrMessage(byte[] in) {
        final Header header;
        try {
            header = new Header(in);
        } catch (final IOException ignored) {
            return null;
        }
        return buildErrorMessage(header, Rcode.FORMERR, null);
    }

    public byte[] errorMessage(Message query, int rcode) {
        return buildErrorMessage(query.getHeader(), rcode, query.getQuestion());
    }

    public void TCPclient(Socket s) {
        try {
            final int inLength;
            final DataInputStream dataIn;
            final DataOutputStream dataOut;
            final byte[] in;

            final InputStream is = s.getInputStream();
            dataIn = new DataInputStream(is);
            inLength = dataIn.readUnsignedShort();
            in = new byte[inLength];
            dataIn.readFully(in);

            final Message query;
            byte[] response;
            try {
                query = new Message(in);
                response = generateReply(query, in, in.length, s);
                if (response == null) {
                    return;
                }
            } catch (final IOException ignored) {
                response = formerrMessage(in);
            }
            dataOut = new DataOutputStream(s.getOutputStream());
            dataOut.writeShort(response.length);
            dataOut.write(response);
        } catch (final IOException e) {
            LOG.warn("TCPclient(" + addrport(s.getLocalAddress(), s.getLocalPort()) + ").", e);
        } finally {
            try {
                s.close();
            } catch (final IOException ignored) {
            }
        }
    }

    public void serveTCP(InetSocketAddress address) {
        try {
            final ServerSocket sock = new ServerSocket(address.getPort(), 128, address.getAddress());
            synchronized (_closeables) {
                _closeables.add(sock);
            }
            while (!currentThread().isInterrupted()) {
                final Socket s = accept(sock);
                final Thread thread;
                thread = new Thread(new Runnable() {
                    @Override
                    public void run() {
                        TCPclient(s);
                    }
                });
                _threads.add(thread);
                thread.start();
            }
        } catch (final InterruptedIOException ignored) {
            currentThread().interrupt();
        } catch (final IOException e) {
            LOG.warn("serveTCP(" + addrport(address.getAddress(), address.getPort()) + ")", e);
        }
    }

    @Nonnull
    private Socket accept(@Nonnull ServerSocket sock) throws IOException {
        try {
            return sock.accept();
        } catch (final SocketException e) {
            if (sock.isClosed()) {
                final InterruptedIOException toThrow = new InterruptedIOException();
                toThrow.initCause(e);
                throw toThrow;
            } else {
                throw e;
            }
        }
    }

    public void serveUDP(InetSocketAddress address) {
        try {
            final DatagramSocket sock = new DatagramSocket(address.getPort(), address.getAddress());
            synchronized (_closeables) {
                _closeables.add(sock);
            }
            final short udpLength = 512;
            final byte[] in = new byte[udpLength];
            final DatagramPacket indp = new DatagramPacket(in, in.length);
            DatagramPacket outdp = null;
            while (!currentThread().isInterrupted()) {
                indp.setLength(in.length);
                receive(sock, indp);
                final Message query;
                byte[] response;
                try {
                    query = new Message(in);
                    response = generateReply(query, in, indp.getLength(), null);
                    if (response == null) {
                        continue;
                    }
                } catch (final IOException ignored) {
                    response = formerrMessage(in);
                }
                if (outdp == null) {
                    outdp = new DatagramPacket(response, response.length, indp.getAddress(), indp.getPort());
                } else {
                    outdp.setData(response);
                    outdp.setLength(response.length);
                    outdp.setAddress(indp.getAddress());
                    outdp.setPort(indp.getPort());
                }
                sock.send(outdp);
            }
        } catch (final InterruptedIOException ignored) {
            currentThread().interrupt();
        } catch (final IOException e) {
            LOG.warn("serveUDP(" + addrport(address.getAddress(), address.getPort()) + ")", e);
        }
    }

    private static void receive(@Nonnull DatagramSocket sock, @Nonnull DatagramPacket indp) throws IOException {
        try {
            sock.receive(indp);
        } catch (final SocketException e) {
            if (sock.isClosed()) {
                final InterruptedIOException toThrow = new InterruptedIOException();
                toThrow.initCause(e);
                throw toThrow;
            } else {
                throw e;
            }
        }
    }

    public void addTCP(final InetSocketAddress address) {
        synchronized (this) {
            assertNotClosed();
            final Thread t;
            t = new Thread(new Runnable() {
                @Override
                public void run() {
                    serveTCP(address);
                }
            });
            _threads.add(t);
            t.start();
        }
    }

    public void addUDP(final InetSocketAddress address) {
        synchronized (this) {
            assertNotClosed();
            final Thread t;
            t = new Thread(new Runnable() {
                @Override
                public void run() {
                    serveUDP(address);
                }
            });
            _threads.add(t);
            t.start();
        }
    }

    protected void assertNotClosed() {
        if (_closed) {
            throw new IllegalStateException("This server was already closed.");
        }
    }

    @Override
    public void close() throws Exception {
        synchronized (this) {
            _closed = true;
            synchronized (_closeables) {
                for (final Closeable closeable : _closeables) {
                    closeQuietly(closeable);
                }
            }
            for (final Thread thread : _threads) {
                do {
                    thread.interrupt();
                    try {
                        thread.join(10);
                    } catch (final InterruptedException ignored) {
                        LOG.info("Got interrupted and could not wait for end of '" + thread + "'.");
                        currentThread().interrupt();
                    }
                } while (!currentThread().isInterrupted() && thread.isAlive());
            }
        }
    }
}