1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
import sun.security.util.HexDumpEncoder;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.nio.ByteBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.regex.MatchResult;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
  36 
  37 /*
  38  * A dummy DNS server.
  39  *
  40  * Loads a sequence of DNS messages from a capture file into its cache.
  41  * It listens for DNS UDP requests, finds match request in cache and sends the
  42  * corresponding DNS responses.
  43  *
  44  * The capture file contains an DNS protocol exchange in the hexadecimal
  45  * dump format emitted by HexDumpEncoder:
  46  *
  47  * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
  48  *
  49  * Typically, DNS protocol exchange is generated by DNSTracer who captures
  50  * communication messages between DNS application program and real DNS server
  51  */
  52 public class DNSServer implements Runnable {
  53 
  54     public class Pair<F, S> {
  55         private F first;
  56         private S second;
  57 
  58         public Pair(F first, S second) {
  59             this.first = first;
  60             this.second = second;
  61         }
  62 
  63         public void setFirst(F first) {
  64             this.first = first;
  65         }
  66 
  67         public void setSecond(S second) {
  68             this.second = second;
  69         }
  70 
  71         public F getFirst() {
  72             return first;
  73         }
  74 
  75         public S getSecond() {
  76             return second;
  77         }
  78     }
  79 
  80     public static final int DNS_HEADER_SIZE = 12;
  81     public static final int DNS_PACKET_SIZE = 512;
  82 
  83     static HexDumpEncoder encoder = new HexDumpEncoder();
  84 
  85     private DatagramSocket socket;
  86     private String filename;
  87     private boolean loop;
  88     private final List<Pair<byte[], byte[]>> cache = new ArrayList<>();
  89     private ByteBuffer reqBuffer = ByteBuffer.allocate(DNS_PACKET_SIZE);
  90 
  91     public DNSServer(DatagramSocket socket, String filename) {
  92         this(socket, filename, false);
  93     }
  94 
  95     public DNSServer(DatagramSocket socket, String filename, boolean loop) {
  96         this.socket = socket;
  97         this.filename = filename;
  98         this.loop = loop;
  99     }
 100 
 101     public void run() {
 102         try {
 103             System.out.println(
 104                     "DNSServer: Loading DNS cache data from : " + filename);
 105             loadCaptureFile(filename);
 106 
 107             System.out.println(
 108                     "DNSServer: listening on port " + socket.getLocalPort());
 109 
 110             System.out.println("DNSServer: loop playback: " + loop);
 111 
 112             int playbackIndex = 0;
 113 
 114             while (playbackIndex < cache.size()) {
 115                 DatagramPacket reqPacket = new DatagramPacket(reqBuffer.array(),
 116                         reqBuffer.array().length);
 117                 socket.receive(reqPacket);
 118 
 119                 System.out.println(
 120                         "DNSServer: received query message from " + reqPacket
 121                                 .getSocketAddress());
 122 
 123                 if (!verifyRequestMsg(reqPacket, playbackIndex)) {
 124                     throw new RuntimeException(
 125                             "DNSServer: Error: Failed to verify DNS request. "
 126                                     + "Not identical request message : \n"
 127                                     + encoder.encodeBuffer(
 128                                     Arrays.copyOf(reqPacket.getData(),
 129                                             reqPacket.getLength())));
 130                 }
 131 
 132                 byte[] payload = generateResponsePayload(reqPacket,
 133                         playbackIndex);
 134                 socket.send(new DatagramPacket(payload, payload.length,
 135                         reqPacket.getSocketAddress()));
 136                 System.out.println(
 137                         "DNSServer: send response message to " + reqPacket
 138                                 .getSocketAddress());
 139 
 140                 playbackIndex++;
 141                 if (loop && playbackIndex >= cache.size()) {
 142                     playbackIndex = 0;
 143                 }
 144             }
 145 
 146             System.out.println(
 147                     "DNSServer: Done for all cached messages playback");
 148         } catch (Exception e) {
 149             System.err.println("DNSServer: Error: " + e);
 150         }
 151     }
 152 
 153     /*
 154      * Load a capture file containing an DNS protocol exchange in the
 155      * hexadecimal dump format emitted by sun.misc.HexDumpEncoder:
 156      *
 157      * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
 158      */
 159     private void loadCaptureFile(String filename) throws IOException {
 160         StringBuilder hexString = new StringBuilder();
 161         String pattern = "(....): (..) (..) (..) (..) (..) (..) (..) (..)   "
 162                 + "(..) (..) (..) (..) (..) (..) (..) (..).*";
 163 
 164         try (Scanner fileScanner = new Scanner(Paths.get(filename))) {
 165             while (fileScanner.hasNextLine()) {
 166 
 167                 try (Scanner lineScanner = new Scanner(
 168                         fileScanner.nextLine())) {
 169                     if (lineScanner.findInLine(pattern) == null) {
 170                         continue;
 171                     }
 172                     MatchResult result = lineScanner.match();
 173                     for (int i = 1; i <= result.groupCount(); i++) {
 174                         String digits = result.group(i);
 175                         if (digits.length() == 4) {
 176                             if (digits.equals("0000")) { // start-of-message
 177                                 if (hexString.length() > 0) {
 178                                     addToCache(hexString.toString());
 179                                     hexString.delete(0, hexString.length());
 180                                 }
 181                             }
 182                             continue;
 183                         } else if (digits.equals("  ")) { // short message
 184                             continue;
 185                         }
 186                         hexString.append(digits);
 187                     }
 188                 }
 189             }
 190         }
 191         addToCache(hexString.toString());
 192     }
 193 
 194     /*
 195      * Add an DNS encoding to the cache (by request message key).
 196      */
 197     private void addToCache(String hexString) {
 198         byte[] encoding = parseHexBinary(hexString);
 199         if (encoding.length < DNS_HEADER_SIZE) {
 200             throw new RuntimeException("Invalid DNS message : " + hexString);
 201         }
 202 
 203         if (getQR(encoding) == 0) {
 204             // a query message, create entry in cache
 205             cache.add(new Pair<>(encoding, null));
 206             System.out.println(
 207                     "    adding DNS query message with ID " + getID(encoding)
 208                             + " to the cache");
 209         } else {
 210             // a response message, attach it to the query entry
 211             if (!cache.isEmpty() && (getID(getLatestCacheEntry().getFirst())
 212                     == getID(encoding))) {
 213                 getLatestCacheEntry().setSecond(encoding);
 214                 System.out.println(
 215                         "    adding DNS response message associated to ID "
 216                                 + getID(encoding) + " in the cache");
 217             } else {
 218                 throw new RuntimeException(
 219                         "Invalid DNS message : " + hexString);
 220             }
 221         }
 222     }
 223 
 224     /*
 225      * ID: A 16 bit identifier assigned by the program that generates any
 226      * kind of query. This identifier is copied the corresponding reply and
 227      * can be used by the requester to match up replies to outstanding queries.
 228      */
 229     private static int getID(byte[] encoding) {
 230         return ByteBuffer.wrap(encoding, 0, 2).getShort();
 231     }
 232 
 233     /*
 234      * QR: A one bit field that specifies whether this message is
 235      * a query (0), or a response (1) after ID
 236      */
 237     private static int getQR(byte[] encoding) {
 238         return encoding[2] & (0x01 << 7);
 239     }
 240 
 241     private Pair<byte[], byte[]> getLatestCacheEntry() {
 242         return cache.get(cache.size() - 1);
 243     }
 244 
 245     private boolean verifyRequestMsg(DatagramPacket packet, int playbackIndex) {
 246         byte[] cachedRequest = cache.get(playbackIndex).getFirst();
 247         return Arrays.equals(Arrays
 248                         .copyOfRange(packet.getData(), 2, packet.getLength()),
 249                 Arrays.copyOfRange(cachedRequest, 2, cachedRequest.length));
 250     }
 251 
 252     private byte[] generateResponsePayload(DatagramPacket packet,
 253             int playbackIndex) {
 254         byte[] resMsg = cache.get(playbackIndex).getSecond();
 255         byte[] payload = Arrays.copyOf(resMsg, resMsg.length);
 256 
 257         // replace the ID with same with real request
 258         payload[0] = packet.getData()[0];
 259         payload[1] = packet.getData()[1];
 260 
 261         return payload;
 262     }
 263 
 264     public static byte[] parseHexBinary(String s) {
 265 
 266         final int len = s.length();
 267 
 268         // "111" is not a valid hex encoding.
 269         if (len % 2 != 0) {
 270             throw new IllegalArgumentException(
 271                     "hexBinary needs to be even-length: " + s);
 272         }
 273 
 274         byte[] out = new byte[len / 2];
 275 
 276         for (int i = 0; i < len; i += 2) {
 277             int h = hexToBin(s.charAt(i));
 278             int l = hexToBin(s.charAt(i + 1));
 279             if (h == -1 || l == -1) {
 280                 throw new IllegalArgumentException(
 281                         "contains illegal character for hexBinary: " + s);
 282             }
 283 
 284             out[i / 2] = (byte) (h * 16 + l);
 285         }
 286 
 287         return out;
 288     }
 289 
 290     private static int hexToBin(char ch) {
 291         if ('0' <= ch && ch <= '9') {
 292             return ch - '0';
 293         }
 294         if ('A' <= ch && ch <= 'F') {
 295             return ch - 'A' + 10;
 296         }
 297         if ('a' <= ch && ch <= 'f') {
 298             return ch - 'a' + 10;
 299         }
 300         return -1;
 301     }
 302 }