001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements.  See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership.  The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License.  You may obtain a copy of the License at
009     *
010     *     http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing, software
013     * distributed under the License is distributed on an "AS IS" BASIS,
014     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015     * See the License for the specific language governing permissions and
016     * limitations under the License.
017     */
018    
019    package org.apache.hadoop.security;
020    
021    import java.io.ByteArrayInputStream;
022    import java.io.DataInput;
023    import java.io.DataInputStream;
024    import java.io.DataOutput;
025    import java.io.IOException;
026    import java.security.PrivilegedExceptionAction;
027    import java.security.Security;
028    import java.util.ArrayList;
029    import java.util.Enumeration;
030    import java.util.HashMap;
031    import java.util.List;
032    import java.util.Map;
033    
034    import javax.security.auth.callback.Callback;
035    import javax.security.auth.callback.CallbackHandler;
036    import javax.security.auth.callback.NameCallback;
037    import javax.security.auth.callback.PasswordCallback;
038    import javax.security.auth.callback.UnsupportedCallbackException;
039    import javax.security.sasl.AuthorizeCallback;
040    import javax.security.sasl.RealmCallback;
041    import javax.security.sasl.Sasl;
042    import javax.security.sasl.SaslException;
043    import javax.security.sasl.SaslServer;
044    import javax.security.sasl.SaslServerFactory;
045    
046    import org.apache.commons.codec.binary.Base64;
047    import org.apache.commons.logging.Log;
048    import org.apache.commons.logging.LogFactory;
049    import org.apache.hadoop.classification.InterfaceAudience;
050    import org.apache.hadoop.classification.InterfaceStability;
051    import org.apache.hadoop.conf.Configuration;
052    import org.apache.hadoop.ipc.RetriableException;
053    import org.apache.hadoop.ipc.Server;
054    import org.apache.hadoop.ipc.Server.Connection;
055    import org.apache.hadoop.ipc.StandbyException;
056    import org.apache.hadoop.security.token.SecretManager;
057    import org.apache.hadoop.security.token.SecretManager.InvalidToken;
058    import org.apache.hadoop.security.token.TokenIdentifier;
059    
060    /**
061     * A utility class for dealing with SASL on RPC server
062     */
063    @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"})
064    @InterfaceStability.Evolving
065    public class SaslRpcServer {
066      public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
067      public static final String SASL_DEFAULT_REALM = "default";
068      private static SaslServerFactory saslFactory;
069    
070      public static enum QualityOfProtection {
071        AUTHENTICATION("auth"),
072        INTEGRITY("auth-int"),
073        PRIVACY("auth-conf");
074        
075        public final String saslQop;
076        
077        private QualityOfProtection(String saslQop) {
078          this.saslQop = saslQop;
079        }
080        
081        public String getSaslQop() {
082          return saslQop;
083        }
084      }
085    
086      @InterfaceAudience.Private
087      @InterfaceStability.Unstable
088      public AuthMethod authMethod;
089      public String mechanism;
090      public String protocol;
091      public String serverId;
092      
093      @InterfaceAudience.Private
094      @InterfaceStability.Unstable
095      public SaslRpcServer(AuthMethod authMethod) throws IOException {
096        this.authMethod = authMethod;
097        mechanism = authMethod.getMechanismName();    
098        switch (authMethod) {
099          case SIMPLE: {
100            return; // no sasl for simple
101          }
102          case TOKEN: {
103            protocol = "";
104            serverId = SaslRpcServer.SASL_DEFAULT_REALM;
105            break;
106          }
107          case KERBEROS: {
108            String fullName = UserGroupInformation.getCurrentUser().getUserName();
109            if (LOG.isDebugEnabled())
110              LOG.debug("Kerberos principal name is " + fullName);
111            // don't use KerberosName because we don't want auth_to_local
112            String[] parts = fullName.split("[/@]", 3);
113            protocol = parts[0];
114            // should verify service host is present here rather than in create()
115            // but lazy tests are using a UGI that isn't a SPN...
116            serverId = (parts.length < 2) ? "" : parts[1];
117            break;
118          }
119          default:
120            // we should never be able to get here
121            throw new AccessControlException(
122                "Server does not support SASL " + authMethod);
123        }
124      }
125      
126      @InterfaceAudience.Private
127      @InterfaceStability.Unstable
128      public SaslServer create(final Connection connection,
129                               final Map<String,?> saslProperties,
130                               SecretManager<TokenIdentifier> secretManager
131          ) throws IOException, InterruptedException {
132        UserGroupInformation ugi = null;
133        final CallbackHandler callback;
134        switch (authMethod) {
135          case TOKEN: {
136            callback = new SaslDigestCallbackHandler(secretManager, connection);
137            break;
138          }
139          case KERBEROS: {
140            ugi = UserGroupInformation.getCurrentUser();
141            if (serverId.isEmpty()) {
142              throw new AccessControlException(
143                  "Kerberos principal name does NOT have the expected "
144                      + "hostname part: " + ugi.getUserName());
145            }
146            callback = new SaslGssCallbackHandler();
147            break;
148          }
149          default:
150            // we should never be able to get here
151            throw new AccessControlException(
152                "Server does not support SASL " + authMethod);
153        }
154        
155        final SaslServer saslServer;
156        if (ugi != null) {
157          saslServer = ugi.doAs(
158            new PrivilegedExceptionAction<SaslServer>() {
159              @Override
160              public SaslServer run() throws SaslException  {
161                return saslFactory.createSaslServer(mechanism, protocol, serverId,
162                    saslProperties, callback);
163              }
164            });
165        } else {
166          saslServer = saslFactory.createSaslServer(mechanism, protocol, serverId,
167              saslProperties, callback);
168        }
169        if (saslServer == null) {
170          throw new AccessControlException(
171              "Unable to find SASL server implementation for " + mechanism);
172        }
173        if (LOG.isDebugEnabled()) {
174          LOG.debug("Created SASL server with mechanism = " + mechanism);
175        }
176        return saslServer;
177      }
178    
179      public static void init(Configuration conf) {
180        Security.addProvider(new SaslPlainServer.SecurityProvider());
181        // passing null so factory is populated with all possibilities.  the
182        // properties passed when instantiating a server are what really matter
183        saslFactory = new FastSaslServerFactory(null);
184      }
185      
186      static String encodeIdentifier(byte[] identifier) {
187        return new String(Base64.encodeBase64(identifier));
188      }
189    
190      static byte[] decodeIdentifier(String identifier) {
191        return Base64.decodeBase64(identifier.getBytes());
192      }
193    
194      public static <T extends TokenIdentifier> T getIdentifier(String id,
195          SecretManager<T> secretManager) throws InvalidToken {
196        byte[] tokenId = decodeIdentifier(id);
197        T tokenIdentifier = secretManager.createIdentifier();
198        try {
199          tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
200              tokenId)));
201        } catch (IOException e) {
202          throw (InvalidToken) new InvalidToken(
203              "Can't de-serialize tokenIdentifier").initCause(e);
204        }
205        return tokenIdentifier;
206      }
207    
208      static char[] encodePassword(byte[] password) {
209        return new String(Base64.encodeBase64(password)).toCharArray();
210      }
211    
212      /** Splitting fully qualified Kerberos name into parts */
213      public static String[] splitKerberosName(String fullName) {
214        return fullName.split("[/@]");
215      }
216    
217      /** Authentication method */
218      @InterfaceStability.Evolving
219      public static enum AuthMethod {
220        SIMPLE((byte) 80, ""),
221        KERBEROS((byte) 81, "GSSAPI"),
222        @Deprecated
223        DIGEST((byte) 82, "DIGEST-MD5"),
224        TOKEN((byte) 82, "DIGEST-MD5"),
225        PLAIN((byte) 83, "PLAIN");
226    
227        /** The code for this method. */
228        public final byte code;
229        public final String mechanismName;
230    
231        private AuthMethod(byte code, String mechanismName) { 
232          this.code = code;
233          this.mechanismName = mechanismName;
234        }
235    
236        private static final int FIRST_CODE = values()[0].code;
237    
238        /** Return the object represented by the code. */
239        private static AuthMethod valueOf(byte code) {
240          final int i = (code & 0xff) - FIRST_CODE;
241          return i < 0 || i >= values().length ? null : values()[i];
242        }
243    
244        /** Return the SASL mechanism name */
245        public String getMechanismName() {
246          return mechanismName;
247        }
248    
249        /** Read from in */
250        public static AuthMethod read(DataInput in) throws IOException {
251          return valueOf(in.readByte());
252        }
253    
254        /** Write to out */
255        public void write(DataOutput out) throws IOException {
256          out.write(code);
257        }
258      };
259    
260      /** CallbackHandler for SASL DIGEST-MD5 mechanism */
261      @InterfaceStability.Evolving
262      public static class SaslDigestCallbackHandler implements CallbackHandler {
263        private SecretManager<TokenIdentifier> secretManager;
264        private Server.Connection connection; 
265        
266        public SaslDigestCallbackHandler(
267            SecretManager<TokenIdentifier> secretManager,
268            Server.Connection connection) {
269          this.secretManager = secretManager;
270          this.connection = connection;
271        }
272    
273        private char[] getPassword(TokenIdentifier tokenid) throws InvalidToken,
274            StandbyException, RetriableException, IOException {
275          return encodePassword(secretManager.retriableRetrievePassword(tokenid));
276        }
277    
278        @Override
279        public void handle(Callback[] callbacks) throws InvalidToken,
280            UnsupportedCallbackException, StandbyException, RetriableException,
281            IOException {
282          NameCallback nc = null;
283          PasswordCallback pc = null;
284          AuthorizeCallback ac = null;
285          for (Callback callback : callbacks) {
286            if (callback instanceof AuthorizeCallback) {
287              ac = (AuthorizeCallback) callback;
288            } else if (callback instanceof NameCallback) {
289              nc = (NameCallback) callback;
290            } else if (callback instanceof PasswordCallback) {
291              pc = (PasswordCallback) callback;
292            } else if (callback instanceof RealmCallback) {
293              continue; // realm is ignored
294            } else {
295              throw new UnsupportedCallbackException(callback,
296                  "Unrecognized SASL DIGEST-MD5 Callback");
297            }
298          }
299          if (pc != null) {
300            TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName(),
301                secretManager);
302            char[] password = getPassword(tokenIdentifier);
303            UserGroupInformation user = null;
304            user = tokenIdentifier.getUser(); // may throw exception
305            connection.attemptingUser = user;
306            
307            if (LOG.isDebugEnabled()) {
308              LOG.debug("SASL server DIGEST-MD5 callback: setting password "
309                  + "for client: " + tokenIdentifier.getUser());
310            }
311            pc.setPassword(password);
312          }
313          if (ac != null) {
314            String authid = ac.getAuthenticationID();
315            String authzid = ac.getAuthorizationID();
316            if (authid.equals(authzid)) {
317              ac.setAuthorized(true);
318            } else {
319              ac.setAuthorized(false);
320            }
321            if (ac.isAuthorized()) {
322              if (LOG.isDebugEnabled()) {
323                String username =
324                  getIdentifier(authzid, secretManager).getUser().getUserName();
325                LOG.debug("SASL server DIGEST-MD5 callback: setting "
326                    + "canonicalized client ID: " + username);
327              }
328              ac.setAuthorizedID(authzid);
329            }
330          }
331        }
332      }
333    
334      /** CallbackHandler for SASL GSSAPI Kerberos mechanism */
335      @InterfaceStability.Evolving
336      public static class SaslGssCallbackHandler implements CallbackHandler {
337    
338        @Override
339        public void handle(Callback[] callbacks) throws
340            UnsupportedCallbackException {
341          AuthorizeCallback ac = null;
342          for (Callback callback : callbacks) {
343            if (callback instanceof AuthorizeCallback) {
344              ac = (AuthorizeCallback) callback;
345            } else {
346              throw new UnsupportedCallbackException(callback,
347                  "Unrecognized SASL GSSAPI Callback");
348            }
349          }
350          if (ac != null) {
351            String authid = ac.getAuthenticationID();
352            String authzid = ac.getAuthorizationID();
353            if (authid.equals(authzid)) {
354              ac.setAuthorized(true);
355            } else {
356              ac.setAuthorized(false);
357            }
358            if (ac.isAuthorized()) {
359              if (LOG.isDebugEnabled())
360                LOG.debug("SASL server GSSAPI callback: setting "
361                    + "canonicalized client ID: " + authzid);
362              ac.setAuthorizedID(authzid);
363            }
364          }
365        }
366      }
367      
368      // Sasl.createSaslServer is 100-200X slower than caching the factories!
369      private static class FastSaslServerFactory implements SaslServerFactory {
370        private final Map<String,List<SaslServerFactory>> factoryCache =
371            new HashMap<String,List<SaslServerFactory>>();
372    
373        FastSaslServerFactory(Map<String,?> props) {
374          final Enumeration<SaslServerFactory> factories =
375              Sasl.getSaslServerFactories();
376          while (factories.hasMoreElements()) {
377            SaslServerFactory factory = factories.nextElement();
378            for (String mech : factory.getMechanismNames(props)) {
379              if (!factoryCache.containsKey(mech)) {
380                factoryCache.put(mech, new ArrayList<SaslServerFactory>());
381              }
382              factoryCache.get(mech).add(factory);
383            }
384          }
385        }
386    
387        @Override
388        public SaslServer createSaslServer(String mechanism, String protocol,
389            String serverName, Map<String,?> props, CallbackHandler cbh)
390            throws SaslException {
391          SaslServer saslServer = null;
392          List<SaslServerFactory> factories = factoryCache.get(mechanism);
393          if (factories != null) {
394            for (SaslServerFactory factory : factories) {
395              saslServer = factory.createSaslServer(
396                  mechanism, protocol, serverName, props, cbh);
397              if (saslServer != null) {
398                break;
399              }
400            }
401          }
402          return saslServer;
403        }
404    
405        @Override
406        public String[] getMechanismNames(Map<String, ?> props) {
407          return factoryCache.keySet().toArray(new String[0]);
408        }
409      }
410    }