package gov.va.med.fw.ui.filter;

// Java imports
import gov.va.med.fw.ui.DelegatingActionUtils;
import gov.va.med.fw.ui.security.UIEncryptionService;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.web.context.WebApplicationContext;

/**
 * Request wrapper that filters documented attacks out of request data before it
 * is returned to the caller (typically the Struts 1 servlet).
 * <p>
 * Two kinds of attack are handled:
 * <ol>
 * <li><b>ClassLoader manipulation</b>: parameter and header names that match a
 * blacklist pattern (a {@code class}/{@code Class} property accessor such as
 * {@code class.classLoader} or {@code ['class']}) are hidden. If the raw request
 * body matches the blacklist, the body is suppressed and its length is reported
 * as zero. Affected methods: {@link #getParameterNames()},
 * {@link #getParameterMap()}, {@link #getInputStream()},
 * {@link #getContentLength()}, {@link #getHeader(String)},
 * {@link #getHeaders(String)} and {@link #getHeaderNames()}.</li>
 * <li><b>Cross-site scripting (XSS)</b>: parameter and header values are
 * stripped of NUL characters, script tags, quoted {@code src=} attributes,
 * {@code eval(...)}, {@code expression(...)}, {@code javascript:} and
 * {@code vbscript:} prefixes, and {@code onload...=} handlers. When anything is
 * stripped, a warning is logged and the client is redirected to the configured
 * error page. Affected methods: {@link #getParameter(String)},
 * {@link #getParameterValues(String)}, {@link #getParameterMap()},
 * {@link #getHeader(String)} and {@link #getHeaders(String)}.</li>
 * </ol>
 * <p>
 * The request body is read in full by the constructor so that it can be
 * inspected. It can then be read back once through {@link #getInputStream()}.
 * <p>
 * NOTE(review): this is a blacklist filter (for example, only {@code onload} is
 * covered among the {@code on*} event handlers), so output encoding is still
 * required.
 *
 * @version 1.0
 */
public class XSSRequestWrapper extends HttpServletRequestWrapper {

	protected Log logger = LogFactory.getLog( getClass() );
	private String errorURL = null;
	private HttpServletResponse servletResponse = null;
	protected WebSecurityFilter webSecurityFilter;

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
	/** Size of the char buffer used while reading the request body. */
	private static final int BUFFER_SIZE = 128;
	/** Matches text containing a {@code class}/{@code Class} accessor (e.g. {@code class.}, {@code class[}, {@code ['class']}). */
	private static final String DEFAULT_BLACKLIST_PATTERN = "(.*\\.|^|.*|\\[('|\"))(c|C)lass(\\.|('|\")]|\\[).*";
	/** Case-insensitive name of the Content-Length header. */
	private static final String CONTENT_LENGTH_PATTERN = "(?i)content-length";

	// Pattern instances are immutable and thread-safe, so they are compiled once per class.
	private static final Pattern BLACKLIST = Pattern.compile(DEFAULT_BLACKLIST_PATTERN, Pattern.DOTALL);
	private static final Pattern CONTENT_LENGTH = Pattern.compile(CONTENT_LENGTH_PATTERN, Pattern.DOTALL);

	/** Raw request body read by the constructor; never null once construction finishes. */
	private String body = "";
	/** Charset the body was decoded with, or null when the platform default was used. */
	private String bodyEncoding = null;
	/** Cached result of matching {@link #body} against the blacklist (the body never changes). */
	private boolean bodyBlacklisted = false;
	/** Set once the body has been handed out by {@link #getInputStream()}. */
	private boolean streamConsumed = false;
	/*** ends ***/

	private static final int MULTILINE_CI = Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL;

	/** XSS patterns, applied in this order by {@link #stripXSS(String)}; every match is removed. */
	private static final Pattern[] XSS_PATTERNS = {
		// Anything between script tags
		Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
		// Anything in a src='...' type of expression
		Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", MULTILINE_CI),
		// Anything in a src="..." type of expression
		Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", MULTILINE_CI),
		// Any lonesome </script> tag
		Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
		// Any lonesome <script ...> tag
		Pattern.compile("<script(.*?)>", MULTILINE_CI),
		// eval(...) expressions
		Pattern.compile("eval\\((.*?)\\)", MULTILINE_CI),
		// expression(...) expressions
		Pattern.compile("expression\\((.*?)\\)", MULTILINE_CI),
		// javascript:... expressions
		Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
		// vbscript:... expressions
		Pattern.compile("vbscript:", Pattern.CASE_INSENSITIVE),
		// onload= expressions
		Pattern.compile("onload(.*?)=", MULTILINE_CI)
	};

	public WebSecurityFilter getWebSecurityFilter() {
		return webSecurityFilter;
	}

	public void setWebSecurityFilter(WebSecurityFilter webSecurityFilter) {
		this.webSecurityFilter = webSecurityFilter;
	}

	/**
	 * Wraps {@code servletRequest} and reads its body so it can be inspected.
	 * <p>
	 * Looks up the {@code webSecurityFilter} bean from the web application
	 * context. When found, its SSO error URL replaces {@code errorURL}. A
	 * lookup failure is logged, and {@code errorURL} is then used unchanged.
	 *
	 * @param servletRequest  the request to wrap
	 * @param servletResponse the response used to redirect when XSS content is stripped
	 * @param errorURL        fallback error page URL (may be null)
	 */
	public XSSRequestWrapper(HttpServletRequest servletRequest, HttpServletResponse servletResponse, String errorURL) {
		super(servletRequest);
		this.servletResponse = servletResponse;
		this.errorURL = errorURL;

		try {
			WebApplicationContext ac =
				DelegatingActionUtils.findRequiredWebApplicationContext( servletRequest.getSession().getServletContext() );
			webSecurityFilter = (WebSecurityFilter) ac.getBean( "webSecurityFilter" );
			setErrorPage(webSecurityFilter.getSsoErrorUrl());
		}
		catch (Exception e) {
			// Best effort: the wrapper still filters requests without the bean.
			logger.error( "Failed to obtain an webSecurityFilter", e );
		}

		if (logger.isDebugEnabled()) {
			logger.debug("XSSRequestWrapper : errorURL is " + this.errorURL);
		}

		/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
		body = readBody(servletRequest);
		bodyBlacklisted = BLACKLIST.matcher(body).matches();
		/*** ends ***/
	}

	/**
	 * Reads the whole body of {@code request} into a string. When the request
	 * declares a character encoding, that encoding is recorded in
	 * {@link #bodyEncoding}. On an I/O error, the error is logged and whatever
	 * was read up to that point is returned.
	 */
	private String readBody(HttpServletRequest request) {
		StringBuilder text = new StringBuilder();
		BufferedReader reader = null;
		try {
			InputStream in = request.getInputStream();
			if (in != null) {
				String encoding = getCharacterEncoding();
				if (encoding == null) {
					reader = new BufferedReader(new InputStreamReader(in));
				} else {
					reader = new BufferedReader(new InputStreamReader(in, encoding));
					// Recorded only after the reader accepted it, so the charset is known to be
					// supported when getInputStream() re-encodes the body.
					bodyEncoding = encoding;
				}
				char[] buffer = new char[BUFFER_SIZE];
				int count;
				while ((count = reader.read(buffer)) != -1) {
					text.append(buffer, 0, count);
				}
			}
		} catch ( IOException ex ) {
			logger.error( "Error occured in XSSRequestWrapper ", ex );
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch ( IOException ex ) {
					logger.error( "Error occured in XSSRequestWrapper ", ex );
				}
			}
		}
		return text.toString();
	}

	/** Returns true when {@code name} does not match the ClassLoader blacklist. */
	private boolean isNameSafe(String name) {
		return !BLACKLIST.matcher(name).matches();
	}

	/** Copies {@code names}, dropping every entry that matches the ClassLoader blacklist. */
	private Enumeration<String> filterBlacklistedNames(Enumeration<?> names) {
		List<String> safeNames = new ArrayList<String>();
		while (names.hasMoreElements()) {
			String name = (String) names.nextElement();
			if (isNameSafe(name)) {
				safeNames.add(name);
			}
		}
		return Collections.enumeration(safeNames);
	}

	/**
	 * Returns true when a Content-Length header lookup must be answered with "0"
	 * because the body was suppressed.
	 */
	private boolean isSuppressedContentLength(String headerName) {
		return bodyBlacklisted && headerName != null && CONTENT_LENGTH.matcher(headerName).matches();
	}

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
	/**
	 * Returns the request's parameter names, omitting names that match the
	 * ClassLoader blacklist.
	 */
	@Override
	public Enumeration getParameterNames() {
		return filterBlacklistedNames(super.getParameterNames());
	}
	/*** ends ***/

	/**
	 * Returns the values of {@code parameter} with XSS content stripped, or null
	 * if the parameter is absent.
	 */
	@Override
	public String[] getParameterValues(String parameter) {
		return stripXSS(super.getParameterValues(parameter));
	}

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks and strip XSS values ***/
	/**
	 * Returns an unmodifiable copy of the parameter map. Blacklisted names and
	 * null value arrays are dropped, and every value has XSS content stripped.
	 */
	@Override
	public Map getParameterMap() {
		Map<String, String[]> filtered = new Hashtable<String, String[]>();
		Map<?, ?> original = super.getParameterMap();
		if (original != null) {
			for (Map.Entry<?, ?> entry : original.entrySet()) {
				String name = (String) entry.getKey();
				String[] values = (String[]) entry.getValue();
				if (isNameSafe(name) && values != null) {
					filtered.put(name, stripXSS(values));
				}
			}
		}
		return Collections.unmodifiableMap(filtered);
	}
	/*** ends ***/

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
	/**
	 * Replays the body read by the constructor. The result is empty when the
	 * body matches the blacklist, or when the body has already been handed out
	 * by a previous call.
	 */
	@Override
	public ServletInputStream getInputStream() throws IOException {
		byte[] content;
		if (bodyBlacklisted) {
			if (logger.isWarnEnabled()) {
				logger.warn("[getInputStream]: found body to match blacklisted parameter pattern");
			}
			content = new byte[0];
		} else if (streamConsumed) {
			content = new byte[0];
		} else {
			if (logger.isDebugEnabled()) {
				logger.debug("[getInputStream]: OK - body does not match blacklisted parameter pattern");
			}
			// Re-encode with the charset the body was decoded with. Otherwise non-ASCII
			// content is corrupted whenever it differs from the platform default.
			content = (bodyEncoding == null) ? body.getBytes() : body.getBytes(bodyEncoding);
			streamConsumed = true;
		}

		final ByteArrayInputStream source = new ByteArrayInputStream(content);
		return new ServletInputStream() {
			@Override
			public int read() throws IOException {
				return source.read();
			}
		};
	}
	/*** ends ***/

	/**
	 * Returns the value of {@code parameter} with XSS content stripped. Unlike
	 * {@link #getParameterNames()}, the name is not checked against the
	 * ClassLoader blacklist.
	 */
	@Override
	public String getParameter(String parameter) {
		return stripXSS(super.getParameter(parameter));
	}

	/**
	 * Returns the header value with XSS content stripped. When the body was
	 * suppressed, Content-Length is reported as "0".
	 */
	@Override
	public String getHeader(String name) {
		if (isSuppressedContentLength(name)) {
			return "0";
		}
		return stripXSS(super.getHeader(name));
	}

	/**
	 * Returns all values of header {@code name} with XSS content stripped. When
	 * the body was suppressed, Content-Length is reported as a single "0". When
	 * the container returns null (header access not allowed), null is returned.
	 */
	@Override
	public Enumeration getHeaders(String name) {
		if (isSuppressedContentLength(name)) {
			return Collections.enumeration(Collections.singletonList("0"));
		}
		Enumeration<?> values = super.getHeaders(name);
		if (values == null) {
			return null;
		}
		List<String> stripped = new ArrayList<String>();
		while (values.hasMoreElements()) {
			stripped.add(stripXSS((String) values.nextElement()));
		}
		return Collections.enumeration(stripped);
	}

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
	/**
	 * Returns the header names, omitting names that match the ClassLoader
	 * blacklist. Returns null when the container returns null.
	 */
	@Override
	public Enumeration getHeaderNames() {
		Enumeration<?> names = super.getHeaderNames();
		return (names == null) ? null : filterBlacklistedNames(names);
	}
	/*** ends ***/

	/*** Protect the Struts 1 Servlet from ClassLoader Manipulation attacks ***/
	/**
	 * Returns 0 when the body matches the blacklist, otherwise the wrapped
	 * request's content length.
	 * <p>
	 * NOTE(review): this wrapper does not override {@code getContentLengthLong()}
	 * (Servlet 3.1+).
	 */
	@Override
	public int getContentLength() {
		return bodyBlacklisted ? 0 : super.getContentLength();
	}
	/*** ends ***/

	/** Applies {@link #stripXSS(String)} to each element; returns null for a null array. */
	private String[] stripXSS(String[] values) {
		if (values == null) {
			return null;
		}
		String[] stripped = new String[values.length];
		for (int i = 0; i < values.length; i++) {
			stripped[i] = stripXSS(values[i]);
		}
		return stripped;
	}

	/**
	 * Removes NUL characters and every match of {@link #XSS_PATTERNS} from
	 * {@code value}. If anything was removed, logs a warning and redirects to the
	 * error page.
	 *
	 * @param value raw value (may be null)
	 * @return value after filtering out XSS, or null if {@code value} was null
	 */
	private String stripXSS(String value) {
		if (value == null) {
			return null;
		}
		int originalLength = value.length();

		// Remove NUL characters first so they cannot split the keywords matched
		// below (e.g. "<scr\0ipt>").
		String cleaned = value.replace("\0", "");
		for (Pattern xss : XSS_PATTERNS) {
			cleaned = xss.matcher(cleaned).replaceAll("");
		}

		// Every pattern matches at least one character, so a shorter result means something was stripped.
		if (cleaned.length() != originalLength) {
			logger.warn("Possible XSS Attack received");
			redirectToErrorPage();
		}
		return cleaned;
	}

	/**
	 * Sends a redirect to the configured error page. Does nothing when no error
	 * page or response is available, or when the response is already committed.
	 */
	private void redirectToErrorPage() {
		String errorPage = getErrorPage();
		if (errorPage == null || servletResponse == null) {
			return;
		}
		// Once the response is committed (e.g. by a redirect for an earlier stripped
		// value), sendRedirect would throw IllegalStateException out of the getter.
		if (servletResponse.isCommitted()) {
			return;
		}
		try {
			servletResponse.sendRedirect(errorPage);
		} catch(IOException ioe) {
			logger.warn("Error occured when redirecting to Security Error Page is "+ioe.getMessage());
		}
	}

	/** Returns the URL redirected to when XSS content is stripped; may be null. */
	public String getErrorPage() {
		if (logger.isDebugEnabled()) {
			logger.debug("getErrorPage errorURL is " + errorURL);
		}
		return errorURL;
	}

	/** Sets the URL redirected to when XSS content is stripped. */
	public void setErrorPage(String errorURL) {
		if (logger.isDebugEnabled()) {
			logger.debug("setErrorPage errorURL is " + errorURL);
		}
		this.errorURL = errorURL;
	}
}
