1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
import sun.security.util.HexDumpEncoder;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.regex.MatchResult;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
  38 
  39 /*
  40  * A dummy DNS server.
  41  *
  42  * Loads a sequence of DNS messages from a capture file into its cache.
  43  * It listens for DNS UDP requests, finds match request in cache and sends the
  44  * corresponding DNS responses.
  45  *
  46  * The capture file contains an DNS protocol exchange in the hexadecimal
  47  * dump format emitted by HexDumpEncoder:
  48  *
  49  * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
  50  *
  51  * Typically, DNS protocol exchange is generated by DNSTracer who captures
  52  * communication messages between DNS application program and real DNS server
  53  */
  54 public class DNSServer extends Thread implements Server {
  55 
  56     public class Pair<F, S> {
  57         private F first;
  58         private S second;
  59 
  60         public Pair(F first, S second) {
  61             this.first = first;
  62             this.second = second;
  63         }
  64 
  65         public void setFirst(F first) {
  66             this.first = first;
  67         }
  68 
  69         public void setSecond(S second) {
  70             this.second = second;
  71         }
  72 
  73         public F getFirst() {
  74             return first;
  75         }
  76 
  77         public S getSecond() {
  78             return second;
  79         }
  80     }
  81 
  82     public static final int DNS_HEADER_SIZE = 12;
  83     public static final int DNS_PACKET_SIZE = 512;
  84 
  85     static HexDumpEncoder encoder = new HexDumpEncoder();
  86 
  87     private DatagramSocket socket;
  88     private String filename;
  89     private boolean loop;
  90     private final List<Pair<byte[], byte[]>> cache = new ArrayList<>();
  91     private ByteBuffer reqBuffer = ByteBuffer.allocate(DNS_PACKET_SIZE);
  92     private volatile boolean isRunning;
  93 
  94     public DNSServer(String filename) throws SocketException {
  95         this(filename, false);
  96     }
  97 
  98     public DNSServer(String filename, boolean loop) throws SocketException {
  99         this.socket = new DatagramSocket(0, InetAddress.getLoopbackAddress());
 100         this.filename = filename;
 101         this.loop = loop;
 102     }
 103 
 104     public void run() {
 105         try {
 106             isRunning = true;
 107             System.out.println(
 108                     "DNSServer: Loading DNS cache data from : " + filename);
 109             loadCaptureFile(filename);
 110 
 111             System.out.println(
 112                     "DNSServer: listening on port " + socket.getLocalPort());
 113 
 114             System.out.println("DNSServer: loop playback: " + loop);
 115 
 116             int playbackIndex = 0;
 117 
 118             while (playbackIndex < cache.size()) {
 119                 DatagramPacket reqPacket = receiveQuery();
 120 
 121                 if (!verifyRequestMsg(reqPacket, playbackIndex)) {
 122                     if (playbackIndex > 0 && verifyRequestMsg(reqPacket,
 123                             playbackIndex - 1)) {
 124                         System.out.println(
 125                                 "DNSServer: received retry query, resend");
 126                         playbackIndex--;
 127                     } else {
 128                         throw new RuntimeException(
 129                                 "DNSServer: Error: Failed to verify DNS request. "
 130                                         + "Not identical request message : \n"
 131                                         + encoder.encodeBuffer(
 132                                         Arrays.copyOf(reqPacket.getData(),
 133                                                 reqPacket.getLength())));
 134                     }
 135                 }
 136 
 137                 sendResponse(reqPacket, playbackIndex);
 138 
 139                 playbackIndex++;
 140                 if (loop && playbackIndex >= cache.size()) {
 141                     playbackIndex = 0;
 142                 }
 143             }
 144 
 145             System.out.println(
 146                     "DNSServer: Done for all cached messages playback");
 147 
 148             System.out.println(
 149                     "DNSServer: Still listening for possible retry query");
 150             while (true) {
 151                 DatagramPacket reqPacket = receiveQuery();
 152 
 153                 // here we only handle the retry query for last one
 154                 if (!verifyRequestMsg(reqPacket, playbackIndex - 1)) {
 155                     throw new RuntimeException(
 156                             "DNSServer: Error: Failed to verify DNS request. "
 157                                     + "Not identical request message : \n"
 158                                     + encoder.encodeBuffer(
 159                                     Arrays.copyOf(reqPacket.getData(),
 160                                             reqPacket.getLength())));
 161                 }
 162 
 163                 sendResponse(reqPacket, playbackIndex - 1);
 164             }
 165         } catch (Exception e) {
 166             if (isRunning) {
 167                 System.err.println("DNSServer: Error: " + e);
 168                 e.printStackTrace();
 169             } else {
 170                 System.out.println("DNSServer: Exit");
 171             }
 172         }
 173     }
 174 
 175     private DatagramPacket receiveQuery() throws IOException {
 176         DatagramPacket reqPacket = new DatagramPacket(reqBuffer.array(),
 177                 reqBuffer.array().length);
 178         socket.receive(reqPacket);
 179 
 180         System.out.println("DNSServer: received query message from " + reqPacket
 181                 .getSocketAddress());
 182 
 183         return reqPacket;
 184     }
 185 
 186     private void sendResponse(DatagramPacket reqPacket, int playbackIndex)
 187             throws IOException {
 188         byte[] payload = generateResponsePayload(reqPacket, playbackIndex);
 189         socket.send(new DatagramPacket(payload, payload.length,
 190                 reqPacket.getSocketAddress()));
 191         System.out.println("DNSServer: send response message to " + reqPacket
 192                 .getSocketAddress());
 193     }
 194 
 195     /*
 196      * Load a capture file containing an DNS protocol exchange in the
 197      * hexadecimal dump format emitted by sun.misc.HexDumpEncoder:
 198      *
 199      * xxxx: 00 11 22 33 44 55 66 77   88 99 aa bb cc dd ee ff  ................
 200      */
 201     private void loadCaptureFile(String filename) throws IOException {
 202         StringBuilder hexString = new StringBuilder();
 203         String pattern = "(....): (..) (..) (..) (..) (..) (..) (..) (..)   "
 204                 + "(..) (..) (..) (..) (..) (..) (..) (..).*";
 205 
 206         try (Scanner fileScanner = new Scanner(Paths.get(filename))) {
 207             while (fileScanner.hasNextLine()) {
 208 
 209                 try (Scanner lineScanner = new Scanner(
 210                         fileScanner.nextLine())) {
 211                     if (lineScanner.findInLine(pattern) == null) {
 212                         continue;
 213                     }
 214                     MatchResult result = lineScanner.match();
 215                     for (int i = 1; i <= result.groupCount(); i++) {
 216                         String digits = result.group(i);
 217                         if (digits.length() == 4) {
 218                             if (digits.equals("0000")) { // start-of-message
 219                                 if (hexString.length() > 0) {
 220                                     addToCache(hexString.toString());
 221                                     hexString.delete(0, hexString.length());
 222                                 }
 223                             }
 224                             continue;
 225                         } else if (digits.equals("  ")) { // short message
 226                             continue;
 227                         }
 228                         hexString.append(digits);
 229                     }
 230                 }
 231             }
 232         }
 233         addToCache(hexString.toString());
 234     }
 235 
 236     /*
 237      * Add an DNS encoding to the cache (by request message key).
 238      */
 239     private void addToCache(String hexString) {
 240         byte[] encoding = parseHexBinary(hexString);
 241         if (encoding.length < DNS_HEADER_SIZE) {
 242             throw new RuntimeException("Invalid DNS message : " + hexString);
 243         }
 244 
 245         if (getQR(encoding) == 0) {
 246             // a query message, create entry in cache
 247             cache.add(new Pair<>(encoding, null));
 248             System.out.println(
 249                     "    adding DNS query message with ID " + getID(encoding)
 250                             + " to the cache");
 251         } else {
 252             // a response message, attach it to the query entry
 253             if (!cache.isEmpty() && (getID(getLatestCacheEntry().getFirst())
 254                     == getID(encoding))) {
 255                 getLatestCacheEntry().setSecond(encoding);
 256                 System.out.println(
 257                         "    adding DNS response message associated to ID "
 258                                 + getID(encoding) + " in the cache");
 259             } else {
 260                 throw new RuntimeException(
 261                         "Invalid DNS message : " + hexString);
 262             }
 263         }
 264     }
 265 
 266     /*
 267      * ID: A 16 bit identifier assigned by the program that generates any
 268      * kind of query. This identifier is copied the corresponding reply and
 269      * can be used by the requester to match up replies to outstanding queries.
 270      */
 271     private static int getID(byte[] encoding) {
 272         return ByteBuffer.wrap(encoding, 0, 2).getShort();
 273     }
 274 
 275     /*
 276      * QR: A one bit field that specifies whether this message is
 277      * a query (0), or a response (1) after ID
 278      */
 279     private static int getQR(byte[] encoding) {
 280         return encoding[2] & (0x01 << 7);
 281     }
 282 
 283     private Pair<byte[], byte[]> getLatestCacheEntry() {
 284         return cache.get(cache.size() - 1);
 285     }
 286 
 287     private boolean verifyRequestMsg(DatagramPacket packet, int playbackIndex) {
 288         byte[] cachedRequest = cache.get(playbackIndex).getFirst();
 289         return Arrays.equals(Arrays
 290                         .copyOfRange(packet.getData(), 2, packet.getLength()),
 291                 Arrays.copyOfRange(cachedRequest, 2, cachedRequest.length));
 292     }
 293 
 294     private byte[] generateResponsePayload(DatagramPacket packet,
 295             int playbackIndex) {
 296         byte[] resMsg = cache.get(playbackIndex).getSecond();
 297         byte[] payload = Arrays.copyOf(resMsg, resMsg.length);
 298 
 299         // replace the ID with same with real request
 300         payload[0] = packet.getData()[0];
 301         payload[1] = packet.getData()[1];
 302 
 303         return payload;
 304     }
 305 
 306     public static byte[] parseHexBinary(String s) {
 307 
 308         final int len = s.length();
 309 
 310         // "111" is not a valid hex encoding.
 311         if (len % 2 != 0) {
 312             throw new IllegalArgumentException(
 313                     "hexBinary needs to be even-length: " + s);
 314         }
 315 
 316         byte[] out = new byte[len / 2];
 317 
 318         for (int i = 0; i < len; i += 2) {
 319             int h = hexToBin(s.charAt(i));
 320             int l = hexToBin(s.charAt(i + 1));
 321             if (h == -1 || l == -1) {
 322                 throw new IllegalArgumentException(
 323                         "contains illegal character for hexBinary: " + s);
 324             }
 325 
 326             out[i / 2] = (byte) (h * 16 + l);
 327         }
 328 
 329         return out;
 330     }
 331 
 332     private static int hexToBin(char ch) {
 333         if ('0' <= ch && ch <= '9') {
 334             return ch - '0';
 335         }
 336         if ('A' <= ch && ch <= 'F') {
 337             return ch - 'A' + 10;
 338         }
 339         if ('a' <= ch && ch <= 'f') {
 340             return ch - 'a' + 10;
 341         }
 342         return -1;
 343     }
 344 
 345     @Override public void stopServer() {
 346         isRunning = false;
 347         if (socket != null) {
 348             try {
 349                 socket.close();
 350             } catch (Exception e) {
 351                 // ignore
 352             }
 353         }
 354     }
 355 
 356     @Override public int getPort() {
 357         if (socket != null) {
 358             return socket.getLocalPort();
 359         } else {
 360             return -1;
 361         }
 362     }
 363 }