001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018
019package org.apache.hadoop.security.http;
020
021import java.io.IOException;
022import java.util.ArrayList;
023import java.util.Arrays;
024import java.util.List;
025import java.util.regex.Matcher;
026import java.util.regex.Pattern;
027
028import javax.servlet.Filter;
029import javax.servlet.FilterChain;
030import javax.servlet.FilterConfig;
031import javax.servlet.ServletException;
032import javax.servlet.ServletRequest;
033import javax.servlet.ServletResponse;
034import javax.servlet.http.HttpServletRequest;
035import javax.servlet.http.HttpServletResponse;
036
037import org.apache.commons.lang.StringUtils;
038import org.apache.commons.logging.Log;
039import org.apache.commons.logging.LogFactory;
040
041import com.google.common.annotations.VisibleForTesting;
042
043public class CrossOriginFilter implements Filter {
044
045  private static final Log LOG = LogFactory.getLog(CrossOriginFilter.class);
046
047  // HTTP CORS Request Headers
048  static final String ORIGIN = "Origin";
049  static final String ACCESS_CONTROL_REQUEST_METHOD =
050      "Access-Control-Request-Method";
051  static final String ACCESS_CONTROL_REQUEST_HEADERS =
052      "Access-Control-Request-Headers";
053
054  // HTTP CORS Response Headers
055  static final String ACCESS_CONTROL_ALLOW_ORIGIN =
056      "Access-Control-Allow-Origin";
057  static final String ACCESS_CONTROL_ALLOW_CREDENTIALS =
058      "Access-Control-Allow-Credentials";
059  static final String ACCESS_CONTROL_ALLOW_METHODS =
060      "Access-Control-Allow-Methods";
061  static final String ACCESS_CONTROL_ALLOW_HEADERS =
062      "Access-Control-Allow-Headers";
063  static final String ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age";
064
065  // Filter configuration
066  public static final String ALLOWED_ORIGINS = "allowed-origins";
067  public static final String ALLOWED_ORIGINS_DEFAULT = "*";
068  public static final String ALLOWED_METHODS = "allowed-methods";
069  public static final String ALLOWED_METHODS_DEFAULT = "GET,POST,HEAD";
070  public static final String ALLOWED_HEADERS = "allowed-headers";
071  public static final String ALLOWED_HEADERS_DEFAULT =
072      "X-Requested-With,Content-Type,Accept,Origin";
073  public static final String MAX_AGE = "max-age";
074  public static final String MAX_AGE_DEFAULT = "1800";
075
076  private List<String> allowedMethods = new ArrayList<String>();
077  private List<String> allowedHeaders = new ArrayList<String>();
078  private List<String> allowedOrigins = new ArrayList<String>();
079  private boolean allowAllOrigins = true;
080  private String maxAge;
081
082  @Override
083  public void init(FilterConfig filterConfig) throws ServletException {
084    initializeAllowedMethods(filterConfig);
085    initializeAllowedHeaders(filterConfig);
086    initializeAllowedOrigins(filterConfig);
087    initializeMaxAge(filterConfig);
088  }
089
090  @Override
091  public void doFilter(ServletRequest req, ServletResponse res,
092      FilterChain chain)
093      throws IOException, ServletException {
094    doCrossFilter((HttpServletRequest) req, (HttpServletResponse) res);
095    chain.doFilter(req, res);
096  }
097
098  @Override
099  public void destroy() {
100    allowedMethods.clear();
101    allowedHeaders.clear();
102    allowedOrigins.clear();
103  }
104
105  private void doCrossFilter(HttpServletRequest req, HttpServletResponse res) {
106
107    String originsList = encodeHeader(req.getHeader(ORIGIN));
108    if (!isCrossOrigin(originsList)) {
109      if(LOG.isDebugEnabled()) {
110        LOG.debug("Header origin is null. Returning");
111      }
112      return;
113    }
114
115    if (!areOriginsAllowed(originsList)) {
116      if(LOG.isDebugEnabled()) {
117        LOG.debug("Header origins '" + originsList + "' not allowed. Returning");
118      }
119      return;
120    }
121
122    String accessControlRequestMethod =
123        req.getHeader(ACCESS_CONTROL_REQUEST_METHOD);
124    if (!isMethodAllowed(accessControlRequestMethod)) {
125      if(LOG.isDebugEnabled()) {
126        LOG.debug("Access control method '" + accessControlRequestMethod +
127            "' not allowed. Returning");
128      }
129      return;
130    }
131
132    String accessControlRequestHeaders =
133        req.getHeader(ACCESS_CONTROL_REQUEST_HEADERS);
134    if (!areHeadersAllowed(accessControlRequestHeaders)) {
135      if(LOG.isDebugEnabled()) {
136        LOG.debug("Access control headers '" + accessControlRequestHeaders +
137            "' not allowed. Returning");
138      }
139      return;
140    }
141
142    if(LOG.isDebugEnabled()) {
143      LOG.debug("Completed cross origin filter checks. Populating " +
144          "HttpServletResponse");
145    }
146    res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, originsList);
147    res.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, Boolean.TRUE.toString());
148    res.setHeader(ACCESS_CONTROL_ALLOW_METHODS, getAllowedMethodsHeader());
149    res.setHeader(ACCESS_CONTROL_ALLOW_HEADERS, getAllowedHeadersHeader());
150    res.setHeader(ACCESS_CONTROL_MAX_AGE, maxAge);
151  }
152
153  @VisibleForTesting
154  String getAllowedHeadersHeader() {
155    return StringUtils.join(allowedHeaders, ',');
156  }
157
158  @VisibleForTesting
159  String getAllowedMethodsHeader() {
160    return StringUtils.join(allowedMethods, ',');
161  }
162
163  private void initializeAllowedMethods(FilterConfig filterConfig) {
164    String allowedMethodsConfig =
165        filterConfig.getInitParameter(ALLOWED_METHODS);
166    if (allowedMethodsConfig == null) {
167      allowedMethodsConfig = ALLOWED_METHODS_DEFAULT;
168    }
169    allowedMethods.addAll(
170        Arrays.asList(allowedMethodsConfig.trim().split("\\s*,\\s*")));
171    LOG.info("Allowed Methods: " + getAllowedMethodsHeader());
172  }
173
174  private void initializeAllowedHeaders(FilterConfig filterConfig) {
175    String allowedHeadersConfig =
176        filterConfig.getInitParameter(ALLOWED_HEADERS);
177    if (allowedHeadersConfig == null) {
178      allowedHeadersConfig = ALLOWED_HEADERS_DEFAULT;
179    }
180    allowedHeaders.addAll(
181        Arrays.asList(allowedHeadersConfig.trim().split("\\s*,\\s*")));
182    LOG.info("Allowed Headers: " + getAllowedHeadersHeader());
183  }
184
185  private void initializeAllowedOrigins(FilterConfig filterConfig) {
186    String allowedOriginsConfig =
187        filterConfig.getInitParameter(ALLOWED_ORIGINS);
188    if (allowedOriginsConfig == null) {
189      allowedOriginsConfig = ALLOWED_ORIGINS_DEFAULT;
190    }
191    allowedOrigins.addAll(
192        Arrays.asList(allowedOriginsConfig.trim().split("\\s*,\\s*")));
193    allowAllOrigins = allowedOrigins.contains("*");
194    LOG.info("Allowed Origins: " + StringUtils.join(allowedOrigins, ','));
195    LOG.info("Allow All Origins: " + allowAllOrigins);
196  }
197
198  private void initializeMaxAge(FilterConfig filterConfig) {
199    maxAge = filterConfig.getInitParameter(MAX_AGE);
200    if (maxAge == null) {
201      maxAge = MAX_AGE_DEFAULT;
202    }
203    LOG.info("Max Age: " + maxAge);
204  }
205
206  static String encodeHeader(final String header) {
207    if (header == null) {
208      return null;
209    }
210    // Protect against HTTP response splitting vulnerability
211    // since value is written as part of the response header
212    // Ensure this header only has one header by removing
213    // CRs and LFs
214    return header.split("\n|\r")[0].trim();
215  }
216
217  static boolean isCrossOrigin(String originsList) {
218    return originsList != null;
219  }
220
221  @VisibleForTesting
222  boolean areOriginsAllowed(String originsList) {
223    if (allowAllOrigins) {
224      return true;
225    }
226
227    String[] origins = originsList.trim().split("\\s+");
228    for (String origin : origins) {
229      for (String allowedOrigin : allowedOrigins) {
230        if (allowedOrigin.contains("*")) {
231          String regex = allowedOrigin.replace(".", "\\.").replace("*", ".*");
232          Pattern p = Pattern.compile(regex);
233          Matcher m = p.matcher(origin);
234          if (m.matches()) {
235            return true;
236          }
237        } else if (allowedOrigin.equals(origin)) {
238          return true;
239        }
240      }
241    }
242    return false;
243  }
244
245  private boolean areHeadersAllowed(String accessControlRequestHeaders) {
246    if (accessControlRequestHeaders == null) {
247      return true;
248    }
249    String[] headers = accessControlRequestHeaders.trim().split("\\s*,\\s*");
250    return allowedHeaders.containsAll(Arrays.asList(headers));
251  }
252
253  private boolean isMethodAllowed(String accessControlRequestMethod) {
254    if (accessControlRequestMethod == null) {
255      return true;
256    }
257    return allowedMethods.contains(accessControlRequestMethod);
258  }
259}