001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.security;
020
021import java.io.ByteArrayInputStream;
022import java.io.DataInput;
023import java.io.DataInputStream;
024import java.io.DataOutput;
025import java.io.IOException;
026import java.nio.charset.StandardCharsets;
027import java.security.PrivilegedExceptionAction;
028import java.security.Security;
029import java.util.ArrayList;
030import java.util.Enumeration;
031import java.util.HashMap;
032import java.util.List;
033import java.util.Map;
034
035import javax.security.auth.callback.Callback;
036import javax.security.auth.callback.CallbackHandler;
037import javax.security.auth.callback.NameCallback;
038import javax.security.auth.callback.PasswordCallback;
039import javax.security.auth.callback.UnsupportedCallbackException;
040import javax.security.sasl.AuthorizeCallback;
041import javax.security.sasl.RealmCallback;
042import javax.security.sasl.Sasl;
043import javax.security.sasl.SaslException;
044import javax.security.sasl.SaslServer;
045import javax.security.sasl.SaslServerFactory;
046
047import org.apache.commons.codec.binary.Base64;
048import org.apache.commons.logging.Log;
049import org.apache.commons.logging.LogFactory;
050import org.apache.hadoop.classification.InterfaceAudience;
051import org.apache.hadoop.classification.InterfaceStability;
052import org.apache.hadoop.conf.Configuration;
053import org.apache.hadoop.ipc.RetriableException;
054import org.apache.hadoop.ipc.Server;
055import org.apache.hadoop.ipc.Server.Connection;
056import org.apache.hadoop.ipc.StandbyException;
057import org.apache.hadoop.security.token.SecretManager;
058import org.apache.hadoop.security.token.SecretManager.InvalidToken;
059import org.apache.hadoop.security.token.TokenIdentifier;
060
061/**
062 * A utility class for dealing with SASL on RPC server
063 */
064@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"})
065@InterfaceStability.Evolving
066public class SaslRpcServer {
067  public static final Log LOG = LogFactory.getLog(SaslRpcServer.class);
068  public static final String SASL_DEFAULT_REALM = "default";
069  private static SaslServerFactory saslFactory;
070
071  public static enum QualityOfProtection {
072    AUTHENTICATION("auth"),
073    INTEGRITY("auth-int"),
074    PRIVACY("auth-conf");
075    
076    public final String saslQop;
077    
078    private QualityOfProtection(String saslQop) {
079      this.saslQop = saslQop;
080    }
081    
082    public String getSaslQop() {
083      return saslQop;
084    }
085  }
086
087  @InterfaceAudience.Private
088  @InterfaceStability.Unstable
089  public AuthMethod authMethod;
090  public String mechanism;
091  public String protocol;
092  public String serverId;
093  
094  @InterfaceAudience.Private
095  @InterfaceStability.Unstable
096  public SaslRpcServer(AuthMethod authMethod) throws IOException {
097    this.authMethod = authMethod;
098    mechanism = authMethod.getMechanismName();    
099    switch (authMethod) {
100      case SIMPLE: {
101        return; // no sasl for simple
102      }
103      case TOKEN: {
104        protocol = "";
105        serverId = SaslRpcServer.SASL_DEFAULT_REALM;
106        break;
107      }
108      case KERBEROS: {
109        String fullName = UserGroupInformation.getCurrentUser().getUserName();
110        if (LOG.isDebugEnabled())
111          LOG.debug("Kerberos principal name is " + fullName);
112        // don't use KerberosName because we don't want auth_to_local
113        String[] parts = fullName.split("[/@]", 3);
114        protocol = parts[0];
115        // should verify service host is present here rather than in create()
116        // but lazy tests are using a UGI that isn't a SPN...
117        serverId = (parts.length < 2) ? "" : parts[1];
118        break;
119      }
120      default:
121        // we should never be able to get here
122        throw new AccessControlException(
123            "Server does not support SASL " + authMethod);
124    }
125  }
126  
127  @InterfaceAudience.Private
128  @InterfaceStability.Unstable
129  public SaslServer create(final Connection connection,
130                           final Map<String,?> saslProperties,
131                           SecretManager<TokenIdentifier> secretManager
132      ) throws IOException, InterruptedException {
133    UserGroupInformation ugi = null;
134    final CallbackHandler callback;
135    switch (authMethod) {
136      case TOKEN: {
137        callback = new SaslDigestCallbackHandler(secretManager, connection);
138        break;
139      }
140      case KERBEROS: {
141        ugi = UserGroupInformation.getCurrentUser();
142        if (serverId.isEmpty()) {
143          throw new AccessControlException(
144              "Kerberos principal name does NOT have the expected "
145                  + "hostname part: " + ugi.getUserName());
146        }
147        callback = new SaslGssCallbackHandler();
148        break;
149      }
150      default:
151        // we should never be able to get here
152        throw new AccessControlException(
153            "Server does not support SASL " + authMethod);
154    }
155    
156    final SaslServer saslServer;
157    if (ugi != null) {
158      saslServer = ugi.doAs(
159        new PrivilegedExceptionAction<SaslServer>() {
160          @Override
161          public SaslServer run() throws SaslException  {
162            return saslFactory.createSaslServer(mechanism, protocol, serverId,
163                saslProperties, callback);
164          }
165        });
166    } else {
167      saslServer = saslFactory.createSaslServer(mechanism, protocol, serverId,
168          saslProperties, callback);
169    }
170    if (saslServer == null) {
171      throw new AccessControlException(
172          "Unable to find SASL server implementation for " + mechanism);
173    }
174    if (LOG.isDebugEnabled()) {
175      LOG.debug("Created SASL server with mechanism = " + mechanism);
176    }
177    return saslServer;
178  }
179
180  public static void init(Configuration conf) {
181    Security.addProvider(new SaslPlainServer.SecurityProvider());
182    // passing null so factory is populated with all possibilities.  the
183    // properties passed when instantiating a server are what really matter
184    saslFactory = new FastSaslServerFactory(null);
185  }
186  
187  static String encodeIdentifier(byte[] identifier) {
188    return new String(Base64.encodeBase64(identifier), StandardCharsets.UTF_8);
189  }
190
191  static byte[] decodeIdentifier(String identifier) {
192    return Base64.decodeBase64(identifier.getBytes(StandardCharsets.UTF_8));
193  }
194
195  public static <T extends TokenIdentifier> T getIdentifier(String id,
196      SecretManager<T> secretManager) throws InvalidToken {
197    byte[] tokenId = decodeIdentifier(id);
198    T tokenIdentifier = secretManager.createIdentifier();
199    try {
200      tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(
201          tokenId)));
202    } catch (IOException e) {
203      throw (InvalidToken) new InvalidToken(
204          "Can't de-serialize tokenIdentifier").initCause(e);
205    }
206    return tokenIdentifier;
207  }
208
209  static char[] encodePassword(byte[] password) {
210    return new String(Base64.encodeBase64(password),
211                      StandardCharsets.UTF_8).toCharArray();
212  }
213
214  /** Splitting fully qualified Kerberos name into parts */
215  public static String[] splitKerberosName(String fullName) {
216    return fullName.split("[/@]");
217  }
218
219  /** Authentication method */
220  @InterfaceStability.Evolving
221  public static enum AuthMethod {
222    SIMPLE((byte) 80, ""),
223    KERBEROS((byte) 81, "GSSAPI"),
224    @Deprecated
225    DIGEST((byte) 82, "DIGEST-MD5"),
226    TOKEN((byte) 82, "DIGEST-MD5"),
227    PLAIN((byte) 83, "PLAIN");
228
229    /** The code for this method. */
230    public final byte code;
231    public final String mechanismName;
232
233    private AuthMethod(byte code, String mechanismName) { 
234      this.code = code;
235      this.mechanismName = mechanismName;
236    }
237
238    private static final int FIRST_CODE = values()[0].code;
239
240    /** Return the object represented by the code. */
241    private static AuthMethod valueOf(byte code) {
242      final int i = (code & 0xff) - FIRST_CODE;
243      return i < 0 || i >= values().length ? null : values()[i];
244    }
245
246    /** Return the SASL mechanism name */
247    public String getMechanismName() {
248      return mechanismName;
249    }
250
251    /** Read from in */
252    public static AuthMethod read(DataInput in) throws IOException {
253      return valueOf(in.readByte());
254    }
255
256    /** Write to out */
257    public void write(DataOutput out) throws IOException {
258      out.write(code);
259    }
260  };
261
262  /** CallbackHandler for SASL DIGEST-MD5 mechanism */
263  @InterfaceStability.Evolving
264  public static class SaslDigestCallbackHandler implements CallbackHandler {
265    private SecretManager<TokenIdentifier> secretManager;
266    private Server.Connection connection; 
267    
268    public SaslDigestCallbackHandler(
269        SecretManager<TokenIdentifier> secretManager,
270        Server.Connection connection) {
271      this.secretManager = secretManager;
272      this.connection = connection;
273    }
274
275    private char[] getPassword(TokenIdentifier tokenid) throws InvalidToken,
276        StandbyException, RetriableException, IOException {
277      return encodePassword(secretManager.retriableRetrievePassword(tokenid));
278    }
279
280    @Override
281    public void handle(Callback[] callbacks) throws InvalidToken,
282        UnsupportedCallbackException, StandbyException, RetriableException,
283        IOException {
284      NameCallback nc = null;
285      PasswordCallback pc = null;
286      AuthorizeCallback ac = null;
287      for (Callback callback : callbacks) {
288        if (callback instanceof AuthorizeCallback) {
289          ac = (AuthorizeCallback) callback;
290        } else if (callback instanceof NameCallback) {
291          nc = (NameCallback) callback;
292        } else if (callback instanceof PasswordCallback) {
293          pc = (PasswordCallback) callback;
294        } else if (callback instanceof RealmCallback) {
295          continue; // realm is ignored
296        } else {
297          throw new UnsupportedCallbackException(callback,
298              "Unrecognized SASL DIGEST-MD5 Callback");
299        }
300      }
301      if (pc != null) {
302        TokenIdentifier tokenIdentifier = getIdentifier(nc.getDefaultName(),
303            secretManager);
304        char[] password = getPassword(tokenIdentifier);
305        UserGroupInformation user = null;
306        user = tokenIdentifier.getUser(); // may throw exception
307        connection.attemptingUser = user;
308        
309        if (LOG.isDebugEnabled()) {
310          LOG.debug("SASL server DIGEST-MD5 callback: setting password "
311              + "for client: " + tokenIdentifier.getUser());
312        }
313        pc.setPassword(password);
314      }
315      if (ac != null) {
316        String authid = ac.getAuthenticationID();
317        String authzid = ac.getAuthorizationID();
318        if (authid.equals(authzid)) {
319          ac.setAuthorized(true);
320        } else {
321          ac.setAuthorized(false);
322        }
323        if (ac.isAuthorized()) {
324          if (LOG.isDebugEnabled()) {
325            UserGroupInformation logUser =
326              getIdentifier(authzid, secretManager).getUser();
327            String username = logUser == null ? null : logUser.getUserName();
328            LOG.debug("SASL server DIGEST-MD5 callback: setting "
329                + "canonicalized client ID: " + username);
330          }
331          ac.setAuthorizedID(authzid);
332        }
333      }
334    }
335  }
336
337  /** CallbackHandler for SASL GSSAPI Kerberos mechanism */
338  @InterfaceStability.Evolving
339  public static class SaslGssCallbackHandler implements CallbackHandler {
340
341    @Override
342    public void handle(Callback[] callbacks) throws
343        UnsupportedCallbackException {
344      AuthorizeCallback ac = null;
345      for (Callback callback : callbacks) {
346        if (callback instanceof AuthorizeCallback) {
347          ac = (AuthorizeCallback) callback;
348        } else {
349          throw new UnsupportedCallbackException(callback,
350              "Unrecognized SASL GSSAPI Callback");
351        }
352      }
353      if (ac != null) {
354        String authid = ac.getAuthenticationID();
355        String authzid = ac.getAuthorizationID();
356        if (authid.equals(authzid)) {
357          ac.setAuthorized(true);
358        } else {
359          ac.setAuthorized(false);
360        }
361        if (ac.isAuthorized()) {
362          if (LOG.isDebugEnabled())
363            LOG.debug("SASL server GSSAPI callback: setting "
364                + "canonicalized client ID: " + authzid);
365          ac.setAuthorizedID(authzid);
366        }
367      }
368    }
369  }
370  
371  // Sasl.createSaslServer is 100-200X slower than caching the factories!
372  private static class FastSaslServerFactory implements SaslServerFactory {
373    private final Map<String,List<SaslServerFactory>> factoryCache =
374        new HashMap<String,List<SaslServerFactory>>();
375
376    FastSaslServerFactory(Map<String,?> props) {
377      final Enumeration<SaslServerFactory> factories =
378          Sasl.getSaslServerFactories();
379      while (factories.hasMoreElements()) {
380        SaslServerFactory factory = factories.nextElement();
381        for (String mech : factory.getMechanismNames(props)) {
382          if (!factoryCache.containsKey(mech)) {
383            factoryCache.put(mech, new ArrayList<SaslServerFactory>());
384          }
385          factoryCache.get(mech).add(factory);
386        }
387      }
388    }
389
390    @Override
391    public SaslServer createSaslServer(String mechanism, String protocol,
392        String serverName, Map<String,?> props, CallbackHandler cbh)
393        throws SaslException {
394      SaslServer saslServer = null;
395      List<SaslServerFactory> factories = factoryCache.get(mechanism);
396      if (factories != null) {
397        for (SaslServerFactory factory : factories) {
398          saslServer = factory.createSaslServer(
399              mechanism, protocol, serverName, props, cbh);
400          if (saslServer != null) {
401            break;
402          }
403        }
404      }
405      return saslServer;
406    }
407
408    @Override
409    public String[] getMechanismNames(Map<String, ?> props) {
410      return factoryCache.keySet().toArray(new String[0]);
411    }
412  }
413}