001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements.  See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership.  The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License.  You may obtain a copy of the License at
009     *
010     *     http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing, software
013     * distributed under the License is distributed on an "AS IS" BASIS,
014     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015     * See the License for the specific language governing permissions and
016     * limitations under the License.
017     */
018    package org.apache.hadoop.net.unix;
019    
020    import java.io.Closeable;
021    import java.io.EOFException;
022    
023    import org.apache.hadoop.classification.InterfaceAudience;
024    import org.apache.hadoop.io.IOUtils;
025    
026    import java.io.IOException;
027    import java.nio.channels.ClosedChannelException;
028    import java.util.Iterator;
029    import java.util.LinkedList;
030    import java.util.TreeMap;
031    import java.util.Map;
032    import java.util.concurrent.locks.Condition;
033    import java.util.concurrent.locks.ReentrantLock;
034    
035    import org.apache.commons.lang.SystemUtils;
036    import org.apache.commons.logging.Log;
037    import org.apache.commons.logging.LogFactory;
038    import org.apache.hadoop.util.NativeCodeLoader;
039    
040    import com.google.common.annotations.VisibleForTesting;
041    import com.google.common.base.Preconditions;
042    import com.google.common.util.concurrent.Uninterruptibles;
043    
044    /**
045     * The DomainSocketWatcher watches a set of domain sockets to see when they
046     * become readable, or closed.  When one of those events happens, it makes a
047     * callback.
048     *
049     * See {@link DomainSocket} for more information about UNIX domain sockets.
050     */
051    @InterfaceAudience.LimitedPrivate("HDFS")
052    public final class DomainSocketWatcher implements Closeable {
053      static {
054        if (SystemUtils.IS_OS_WINDOWS) {
055          loadingFailureReason = "UNIX Domain sockets are not available on Windows.";
056        } else if (!NativeCodeLoader.isNativeCodeLoaded()) {
057          loadingFailureReason = "libhadoop cannot be loaded.";
058        } else {
059          String problem;
060          try {
061            anchorNative();
062            problem = null;
063          } catch (Throwable t) {
064            problem = "DomainSocketWatcher#anchorNative got error: " +
065              t.getMessage();
066          }
067          loadingFailureReason = problem;
068        }
069      }
070    
  // Logger shared by every DomainSocketWatcher instance.  Note that it is
  // declared package-private and non-final.
  static Log LOG = LogFactory.getLog(DomainSocketWatcher.class);

  /**
   * The reason why DomainSocketWatcher is not available, or null if it is
   * available.  Assigned exactly once, by the static initializer of this
   * class.
   */
  private final static String loadingFailureReason;

  /**
   * Initializes the native library code.  Called from the static initializer
   * of this class, which records any Throwable raised here in
   * loadingFailureReason rather than propagating it.
   */
  private static native void anchorNative();
083    
  /**
   * Reports why DomainSocketWatcher cannot be used in this JVM.
   *
   * @return  A human-readable description of the problem, or null if
   *          DomainSocketWatcher is available.  The constructor throws
   *          UnsupportedOperationException with this text when it is
   *          non-null.
   */
  public static String getLoadingFailureReason() {
    return loadingFailureReason;
  }
087    
  /**
   * Callback which the DomainSocketWatcher invokes for a socket that it was
   * asked to watch.
   */
  public interface Handler {
    /**
     * Handles an event on a socket.  An event may be the socket becoming
     * readable, or the remote end being closed.
     *
     * The watcher also invokes this method when the socket is removed with
     * {@code remove}, when the watcher thread shuts down, and directly from
     * {@code add} if the watcher or the socket is already closed at that
     * point.
     *
     * @param sock    The socket that the event occurred on.
     * @return        Whether we should close the socket.  When the call is
     *                made by the watcher thread, returning true also makes
     *                the watcher stop watching the socket.
     */
    boolean handle(DomainSocket sock);
  }
098    
099      /**
100       * Handler for {DomainSocketWatcher#notificationSockets[1]}
101       */
102      private class NotificationHandler implements Handler {
103        public boolean handle(DomainSocket sock) {
104          try {
105            if (LOG.isTraceEnabled()) {
106              LOG.trace(this + ": NotificationHandler: doing a read on " +
107                sock.fd);
108            }
109            if (sock.getInputStream().read() == -1) {
110              if (LOG.isTraceEnabled()) {
111                LOG.trace(this + ": NotificationHandler: got EOF on " + sock.fd);
112              }
113              throw new EOFException();
114            }
115            if (LOG.isTraceEnabled()) {
116              LOG.trace(this + ": NotificationHandler: read succeeded on " +
117                sock.fd);
118            }
119            return false;
120          } catch (IOException e) {
121            if (LOG.isTraceEnabled()) {
122              LOG.trace(this + ": NotificationHandler: setting closed to " +
123                  "true for " + sock.fd);
124            }
125            closed = true;
126            return true;
127          }
128        }
129      }
130    
131      private static class Entry {
132        final DomainSocket socket;
133        final Handler handler;
134    
135        Entry(DomainSocket socket, Handler handler) {
136          this.socket = socket;
137          this.handler = handler;
138        }
139    
140        DomainSocket getDomainSocket() {
141          return socket;
142        }
143    
144        Handler getHandler() {
145          return handler;
146        }
147      }
148    
  /**
   * The FdSet is a set of file descriptors that gets passed to poll(2).
   * It contains a native memory segment, so that we don't have to copy
   * in the poll0 function.
   *
   * All operations other than construction are implemented natively.
   */
  private static class FdSet {
    // Holds the value returned by alloc0().  No Java code in this class reads
    // it; presumably it is the handle/address of the native memory segment
    // and is accessed by the native methods -- confirm against the native
    // implementation before renaming or changing its type.
    private long data;

    // Native allocator whose result is stored in data by the constructor.
    private native static long alloc0();

    /**
     * Creates an empty set by allocating its native backing storage.
     */
    FdSet() {
      data = alloc0();
    }

    /**
     * Add a file descriptor to the set.
     *
     * @param fd   The file descriptor to add.
     */
    native void add(int fd);

    /**
     * Remove a file descriptor from the set.
     *
     * @param fd   The file descriptor to remove.
     */
    native void remove(int fd);

    /**
     * Get an array containing all the FDs marked as readable.
     * Also clear the state of all FDs.
     *
     * @return     An array containing all of the currently readable file
     *             descriptors.
     */
    native int[] getAndClearReadableFds();

    /**
     * Close the object and de-allocate the memory used.
     */
    native void close();
  }
191    
  /**
   * Lock which protects toAdd, toRemove, and closed.
   */
  private final ReentrantLock lock = new ReentrantLock();

  /**
   * Condition variable which indicates that toAdd and toRemove have been
   * processed.  Signalled by the watcher thread; awaited by add() and
   * remove().
   */
  private final Condition processedCond = lock.newCondition();

  /**
   * Entries to add.  Filled by add() and drained by the watcher thread.
   */
  private final LinkedList<Entry> toAdd =
      new LinkedList<Entry>();

  /**
   * Entries to remove, keyed by file descriptor.  Filled by remove();
   * entries are taken out again by sendCallback() once the handler has
   * asked for the socket to be closed.
   */
  private final TreeMap<Integer, DomainSocket> toRemove =
      new TreeMap<Integer, DomainSocket>();

  /**
   * Maximum length of time to go between checking whether the interrupted
   * bit has been set for this thread.  Passed to doPoll0() as the maximum
   * wait time.
   */
  private final int interruptCheckPeriodMs;

  /**
   * A pair of sockets used to wake up the thread after it has called poll(2).
   * Element 0 is written to by kick() and closed by close(); element 1 is
   * the end watched by the watcher thread.
   */
  private final DomainSocket notificationSockets[];

  /**
   * Whether or not this DomainSocketWatcher is closed.
   */
  private boolean closed = false;
230    
231      public DomainSocketWatcher(int interruptCheckPeriodMs) throws IOException {
232        if (loadingFailureReason != null) {
233          throw new UnsupportedOperationException(loadingFailureReason);
234        }
235        Preconditions.checkArgument(interruptCheckPeriodMs > 0);
236        this.interruptCheckPeriodMs = interruptCheckPeriodMs;
237        notificationSockets = DomainSocket.socketpair();
238        watcherThread.setDaemon(true);
239        watcherThread.start();
240      }
241    
  /**
   * Close the DomainSocketWatcher and wait for its thread to terminate.
   *
   * If there is more than one close, all but the first will be ignored.
   *
   * @throws IOException  If closing notificationSockets[0] fails.
   */
  @Override
  public void close() throws IOException {
    lock.lock();
    try {
      // Only the first caller proceeds past this point.
      if (closed) return;
      if (LOG.isDebugEnabled()) {
        LOG.debug(this + ": closing");
      }
      closed = true;
    } finally {
      lock.unlock();
    }
    // Close notificationSockets[0], so that notificationSockets[1] gets an EOF
    // event.  This will wake up the thread immediately if it is blocked inside
    // doPoll0().
    notificationSockets[0].close();
    // Wait for the watcher thread to terminate.  The lock is not held here,
    // so the thread is free to acquire it while shutting down.
    Uninterruptibles.joinUninterruptibly(watcherThread);
  }
266    
267      @VisibleForTesting
268      public boolean isClosed() {
269        lock.lock();
270        try {
271          return closed;
272        } finally {
273          lock.unlock();
274        }
275      }
276    
277      /**
278       * Add a socket.
279       *
280       * @param sock     The socket to add.  It is an error to re-add a socket that
281       *                   we are already watching.
282       * @param handler  The handler to associate with this socket.  This may be
283       *                   called any time after this function is called.
284       */
285      public void add(DomainSocket sock, Handler handler) {
286        lock.lock();
287        try {
288          if (closed) {
289            handler.handle(sock);
290            IOUtils.cleanup(LOG, sock);
291            return;
292          }
293          Entry entry = new Entry(sock, handler);
294          try {
295            sock.refCount.reference();
296          } catch (ClosedChannelException e1) {
297            // If the socket is already closed before we add it, invoke the
298            // handler immediately.  Then we're done.
299            handler.handle(sock);
300            return;
301          }
302          toAdd.add(entry);
303          kick();
304          while (true) {
305            try {
306              processedCond.await();
307            } catch (InterruptedException e) {
308              Thread.currentThread().interrupt();
309            }
310            if (!toAdd.contains(entry)) {
311              break;
312            }
313          }
314        } finally {
315          lock.unlock();
316        }
317      }
318    
319      /**
320       * Remove a socket.  Its handler will be called.
321       *
322       * @param sock     The socket to remove.
323       */
324      public void remove(DomainSocket sock) {
325        lock.lock();
326        try {
327          if (closed) return;
328          toRemove.put(sock.fd, sock);
329          kick();
330          while (true) {
331            try {
332              processedCond.await();
333            } catch (InterruptedException e) {
334              Thread.currentThread().interrupt();
335            }
336            if (!toRemove.containsKey(sock.fd)) {
337              break;
338            }
339          }
340        } finally {
341          lock.unlock();
342        }
343      }
344    
345      /**
346       * Wake up the DomainSocketWatcher thread.
347       */
348      private void kick() {
349        try {
350          notificationSockets[0].getOutputStream().write(0);
351        } catch (IOException e) {
352          if (!closed) {
353            LOG.error(this + ": error writing to notificationSockets[0]", e);
354          }
355        }
356      }
357    
358      private void sendCallback(String caller, TreeMap<Integer, Entry> entries,
359          FdSet fdSet, int fd) {
360        if (LOG.isTraceEnabled()) {
361          LOG.trace(this + ": " + caller + " starting sendCallback for fd " + fd);
362        }
363        Entry entry = entries.get(fd);
364        Preconditions.checkNotNull(entry,
365            this + ": fdSet contained " + fd + ", which we were " +
366            "not tracking.");
367        DomainSocket sock = entry.getDomainSocket();
368        if (entry.getHandler().handle(sock)) {
369          if (LOG.isTraceEnabled()) {
370            LOG.trace(this + ": " + caller + ": closing fd " + fd +
371                " at the request of the handler.");
372          }
373          if (toRemove.remove(fd) != null) {
374            if (LOG.isTraceEnabled()) {
375              LOG.trace(this + ": " + caller + " : sendCallback processed fd " +
376                fd  + " in toRemove.");
377            }
378          }
379          try {
380            sock.refCount.unreferenceCheckClosed();
381          } catch (IOException e) {
382            Preconditions.checkArgument(false,
383                this + ": file descriptor " + sock.fd + " was closed while " +
384                "still in the poll(2) loop.");
385          }
386          IOUtils.cleanup(LOG, sock);
387          entries.remove(fd);
388          fdSet.remove(fd);
389        } else {
390          if (LOG.isTraceEnabled()) {
391            LOG.trace(this + ": " + caller + ": sendCallback not " +
392                "closing fd " + fd);
393          }
394        }
395      }
396    
397      @VisibleForTesting
398      final Thread watcherThread = new Thread(new Runnable() {
399        @Override
400        public void run() {
401          if (LOG.isDebugEnabled()) {
402            LOG.debug(this + ": starting with interruptCheckPeriodMs = " +
403                interruptCheckPeriodMs);
404          }
405          final TreeMap<Integer, Entry> entries = new TreeMap<Integer, Entry>();
406          FdSet fdSet = new FdSet();
407          addNotificationSocket(entries, fdSet);
408          try {
409            while (true) {
410              lock.lock();
411              try {
412                for (int fd : fdSet.getAndClearReadableFds()) {
413                  sendCallback("getAndClearReadableFds", entries, fdSet, fd);
414                }
415                if (!(toAdd.isEmpty() && toRemove.isEmpty())) {
416                  // Handle pending additions (before pending removes).
417                  for (Iterator<Entry> iter = toAdd.iterator(); iter.hasNext(); ) {
418                    Entry entry = iter.next();
419                    DomainSocket sock = entry.getDomainSocket();
420                    Entry prevEntry = entries.put(sock.fd, entry);
421                    Preconditions.checkState(prevEntry == null,
422                        this + ": tried to watch a file descriptor that we " +
423                        "were already watching: " + sock);
424                    if (LOG.isTraceEnabled()) {
425                      LOG.trace(this + ": adding fd " + sock.fd);
426                    }
427                    fdSet.add(sock.fd);
428                    iter.remove();
429                  }
430                  // Handle pending removals
431                  while (true) {
432                    Map.Entry<Integer, DomainSocket> entry = toRemove.firstEntry();
433                    if (entry == null) break;
434                    sendCallback("handlePendingRemovals",
435                        entries, fdSet, entry.getValue().fd);
436                  }
437                  processedCond.signalAll();
438                }
439                // Check if the thread should terminate.  Doing this check now is
440                // easier than at the beginning of the loop, since we know toAdd and
441                // toRemove are now empty and processedCond has been notified if it
442                // needed to be.
443                if (closed) {
444                  if (LOG.isDebugEnabled()) {
445                    LOG.debug(toString() + " thread terminating.");
446                  }
447                  return;
448                }
449                // Check if someone sent our thread an InterruptedException while we
450                // were waiting in poll().
451                if (Thread.interrupted()) {
452                  throw new InterruptedException();
453                }
454              } finally {
455                lock.unlock();
456              }
457              doPoll0(interruptCheckPeriodMs, fdSet);
458            }
459          } catch (InterruptedException e) {
460            LOG.info(toString() + " terminating on InterruptedException");
461          } catch (IOException e) {
462            LOG.error(toString() + " terminating on IOException", e);
463          } finally {
464            kick(); // allow the handler for notificationSockets[0] to read a byte
465            for (Entry entry : entries.values()) {
466              sendCallback("close", entries, fdSet, entry.getDomainSocket().fd);
467            }
468            entries.clear();
469            fdSet.close();
470          }
471        }
472      });
473    
474      private void addNotificationSocket(final TreeMap<Integer, Entry> entries,
475          FdSet fdSet) {
476        entries.put(notificationSockets[1].fd, 
477            new Entry(notificationSockets[1], new NotificationHandler()));
478        try {
479          notificationSockets[1].refCount.reference();
480        } catch (IOException e) {
481          throw new RuntimeException(e);
482        }
483        fdSet.add(notificationSockets[1].fd);
484        if (LOG.isTraceEnabled()) {
485          LOG.trace(this + ": adding notificationSocket " +
486              notificationSockets[1].fd + ", connected to " +
487              notificationSockets[0].fd);
488        }
489      }
490    
491      public String toString() {
492        return "DomainSocketWatcher(" + System.identityHashCode(this) + ")"; 
493      }
494    
  /**
   * Native poll routine.  The watcher thread calls it once per loop
   * iteration, without holding the lock, and afterwards retrieves the
   * readable descriptors through FdSet#getAndClearReadableFds().
   *
   * NOTE(review): presumably this blocks in poll(2) for at most maxWaitMs
   * milliseconds -- confirm against the native implementation.
   *
   * @param maxWaitMs  Maximum wait time; the watcher thread passes
   *                   interruptCheckPeriodMs.
   * @param readFds    The set of file descriptors to poll.
   * @return           A value defined by the native code; the watcher thread
   *                   ignores it.
   * @throws IOException  On failure; the watcher thread logs it and
   *                   terminates.
   */
  private static native int doPoll0(int maxWaitMs, FdSet readFds)
      throws IOException;
497    }