001/** 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 */ 018 019package org.apache.hadoop.security.http; 020 021import java.io.IOException; 022import java.util.ArrayList; 023import java.util.Arrays; 024import java.util.List; 025import java.util.regex.Matcher; 026import java.util.regex.Pattern; 027 028import javax.servlet.Filter; 029import javax.servlet.FilterChain; 030import javax.servlet.FilterConfig; 031import javax.servlet.ServletException; 032import javax.servlet.ServletRequest; 033import javax.servlet.ServletResponse; 034import javax.servlet.http.HttpServletRequest; 035import javax.servlet.http.HttpServletResponse; 036 037import org.apache.commons.lang.StringUtils; 038import org.apache.commons.logging.Log; 039import org.apache.commons.logging.LogFactory; 040 041import com.google.common.annotations.VisibleForTesting; 042 043public class CrossOriginFilter implements Filter { 044 045 private static final Log LOG = LogFactory.getLog(CrossOriginFilter.class); 046 047 // HTTP CORS Request Headers 048 static final String ORIGIN = "Origin"; 049 static final String ACCESS_CONTROL_REQUEST_METHOD = 050 "Access-Control-Request-Method"; 051 static final String ACCESS_CONTROL_REQUEST_HEADERS = 052 "Access-Control-Request-Headers"; 053 054 // HTTP CORS Response Headers 055 static final String ACCESS_CONTROL_ALLOW_ORIGIN = 056 "Access-Control-Allow-Origin"; 057 static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = 058 "Access-Control-Allow-Credentials"; 059 static final String ACCESS_CONTROL_ALLOW_METHODS = 060 "Access-Control-Allow-Methods"; 061 static final String ACCESS_CONTROL_ALLOW_HEADERS = 062 "Access-Control-Allow-Headers"; 063 static final String ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"; 064 065 // Filter configuration 066 public static final String ALLOWED_ORIGINS = "allowed-origins"; 067 public static final String ALLOWED_ORIGINS_DEFAULT = "*"; 068 public static final String ALLOWED_METHODS = "allowed-methods"; 069 public static final String ALLOWED_METHODS_DEFAULT = "GET,POST,HEAD"; 070 public static final String ALLOWED_HEADERS = "allowed-headers"; 071 public static final String ALLOWED_HEADERS_DEFAULT = 072 "X-Requested-With,Content-Type,Accept,Origin"; 073 public static final String MAX_AGE = "max-age"; 074 public static final String MAX_AGE_DEFAULT = "1800"; 075 076 private List<String> allowedMethods = new ArrayList<String>(); 077 private List<String> allowedHeaders = new ArrayList<String>(); 078 private List<String> allowedOrigins = new ArrayList<String>(); 079 private boolean allowAllOrigins = true; 080 private String maxAge; 081 082 @Override 083 public void init(FilterConfig filterConfig) throws ServletException { 084 initializeAllowedMethods(filterConfig); 085 initializeAllowedHeaders(filterConfig); 086 initializeAllowedOrigins(filterConfig); 087 initializeMaxAge(filterConfig); 088 } 089 090 @Override 091 public void doFilter(ServletRequest req, ServletResponse res, 092 FilterChain chain) 093 throws IOException, ServletException { 094 doCrossFilter((HttpServletRequest) req, (HttpServletResponse) res); 095 chain.doFilter(req, res); 096 } 097 098 @Override 099 public void destroy() { 100 allowedMethods.clear(); 101 allowedHeaders.clear(); 102 allowedOrigins.clear(); 103 } 104 105 private void doCrossFilter(HttpServletRequest req, HttpServletResponse res) { 106 107 String originsList = encodeHeader(req.getHeader(ORIGIN)); 108 if (!isCrossOrigin(originsList)) { 109 if(LOG.isDebugEnabled()) { 110 LOG.debug("Header origin is null. Returning"); 111 } 112 return; 113 } 114 115 if (!areOriginsAllowed(originsList)) { 116 if(LOG.isDebugEnabled()) { 117 LOG.debug("Header origins '" + originsList + "' not allowed. Returning"); 118 } 119 return; 120 } 121 122 String accessControlRequestMethod = 123 req.getHeader(ACCESS_CONTROL_REQUEST_METHOD); 124 if (!isMethodAllowed(accessControlRequestMethod)) { 125 if(LOG.isDebugEnabled()) { 126 LOG.debug("Access control method '" + accessControlRequestMethod + 127 "' not allowed. Returning"); 128 } 129 return; 130 } 131 132 String accessControlRequestHeaders = 133 req.getHeader(ACCESS_CONTROL_REQUEST_HEADERS); 134 if (!areHeadersAllowed(accessControlRequestHeaders)) { 135 if(LOG.isDebugEnabled()) { 136 LOG.debug("Access control headers '" + accessControlRequestHeaders + 137 "' not allowed. Returning"); 138 } 139 return; 140 } 141 142 if(LOG.isDebugEnabled()) { 143 LOG.debug("Completed cross origin filter checks. Populating " + 144 "HttpServletResponse"); 145 } 146 res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, originsList); 147 res.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, Boolean.TRUE.toString()); 148 res.setHeader(ACCESS_CONTROL_ALLOW_METHODS, getAllowedMethodsHeader()); 149 res.setHeader(ACCESS_CONTROL_ALLOW_HEADERS, getAllowedHeadersHeader()); 150 res.setHeader(ACCESS_CONTROL_MAX_AGE, maxAge); 151 } 152 153 @VisibleForTesting 154 String getAllowedHeadersHeader() { 155 return StringUtils.join(allowedHeaders, ','); 156 } 157 158 @VisibleForTesting 159 String getAllowedMethodsHeader() { 160 return StringUtils.join(allowedMethods, ','); 161 } 162 163 private void initializeAllowedMethods(FilterConfig filterConfig) { 164 String allowedMethodsConfig = 165 filterConfig.getInitParameter(ALLOWED_METHODS); 166 if (allowedMethodsConfig == null) { 167 allowedMethodsConfig = ALLOWED_METHODS_DEFAULT; 168 } 169 allowedMethods.addAll( 170 Arrays.asList(allowedMethodsConfig.trim().split("\\s*,\\s*"))); 171 LOG.info("Allowed Methods: " + getAllowedMethodsHeader()); 172 } 173 174 private void initializeAllowedHeaders(FilterConfig filterConfig) { 175 String allowedHeadersConfig = 176 filterConfig.getInitParameter(ALLOWED_HEADERS); 177 if (allowedHeadersConfig == null) { 178 allowedHeadersConfig = ALLOWED_HEADERS_DEFAULT; 179 } 180 allowedHeaders.addAll( 181 Arrays.asList(allowedHeadersConfig.trim().split("\\s*,\\s*"))); 182 LOG.info("Allowed Headers: " + getAllowedHeadersHeader()); 183 } 184 185 private void initializeAllowedOrigins(FilterConfig filterConfig) { 186 String allowedOriginsConfig = 187 filterConfig.getInitParameter(ALLOWED_ORIGINS); 188 if (allowedOriginsConfig == null) { 189 allowedOriginsConfig = ALLOWED_ORIGINS_DEFAULT; 190 } 191 allowedOrigins.addAll( 192 Arrays.asList(allowedOriginsConfig.trim().split("\\s*,\\s*"))); 193 allowAllOrigins = allowedOrigins.contains("*"); 194 LOG.info("Allowed Origins: " + StringUtils.join(allowedOrigins, ',')); 195 LOG.info("Allow All Origins: " + allowAllOrigins); 196 } 197 198 private void initializeMaxAge(FilterConfig filterConfig) { 199 maxAge = filterConfig.getInitParameter(MAX_AGE); 200 if (maxAge == null) { 201 maxAge = MAX_AGE_DEFAULT; 202 } 203 LOG.info("Max Age: " + maxAge); 204 } 205 206 static String encodeHeader(final String header) { 207 if (header == null) { 208 return null; 209 } 210 // Protect against HTTP response splitting vulnerability 211 // since value is written as part of the response header 212 // Ensure this header only has one header by removing 213 // CRs and LFs 214 return header.split("\n|\r")[0].trim(); 215 } 216 217 static boolean isCrossOrigin(String originsList) { 218 return originsList != null; 219 } 220 221 @VisibleForTesting 222 boolean areOriginsAllowed(String originsList) { 223 if (allowAllOrigins) { 224 return true; 225 } 226 227 String[] origins = originsList.trim().split("\\s+"); 228 for (String origin : origins) { 229 for (String allowedOrigin : allowedOrigins) { 230 if (allowedOrigin.contains("*")) { 231 String regex = allowedOrigin.replace(".", "\\.").replace("*", ".*"); 232 Pattern p = Pattern.compile(regex); 233 Matcher m = p.matcher(origin); 234 if (m.matches()) { 235 return true; 236 } 237 } else if (allowedOrigin.equals(origin)) { 238 return true; 239 } 240 } 241 } 242 return false; 243 } 244 245 private boolean areHeadersAllowed(String accessControlRequestHeaders) { 246 if (accessControlRequestHeaders == null) { 247 return true; 248 } 249 String[] headers = accessControlRequestHeaders.trim().split("\\s*,\\s*"); 250 return allowedHeaders.containsAll(Arrays.asList(headers)); 251 } 252 253 private boolean isMethodAllowed(String accessControlRequestMethod) { 254 if (accessControlRequestMethod == null) { 255 return true; 256 } 257 return allowedMethods.contains(accessControlRequestMethod); 258 } 259}