package gov.va.med.ars.configuration.spring;

import java.util.regex.Pattern;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Request wrapper that removes a fixed set of script-injection fragments from
 * request parameters and headers before they reach the application.
 *
 * <p>Only {@link #getParameter(String)}, {@link #getParameterValues(String)} and
 * {@link #getHeader(String)} are filtered. Values read through any other
 * accessor of the wrapped request (for example the parameter map, the
 * multi-valued header accessor, or the request body) are returned unchanged.
 *
 * <p>Filtering is pattern based and is applied to the value exactly as it was
 * received; encoded payloads are not canonicalized first. Treat it as a
 * best-effort defence, not as a replacement for output encoding.
 */
public class XSSRequestWrapper extends HttpServletRequestWrapper {

    private static final Logger log = LoggerFactory.getLogger(XSSRequestWrapper.class);

    /**
     * Fragments removed from every filtered value. The patterns are applied
     * one after another, in array order, each on the output of the previous one.
     */
    private static final Pattern[] XSS_PATTERNS = new Pattern[] {
        // Script fragments
        Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
        // src='...'
        Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // lonely script tags
        Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
        Pattern.compile("<script(.*?)>", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // eval(...)
        Pattern.compile("eval\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // expression(...)
        Pattern.compile("expression\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // javascript:...
        Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
        // vbscript:...
        Pattern.compile("vbscript:", Pattern.CASE_INSENSITIVE),
        // onload(...)=...
        Pattern.compile("onload(.*?)=", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL)
    };

    /**
     * Wraps the given request.
     *
     * @param servletRequest the request whose parameters and headers are filtered
     */
    public XSSRequestWrapper(HttpServletRequest servletRequest) {
        super(servletRequest);
    }

    /**
     * Returns all values of the named parameter with script fragments removed.
     *
     * @param parameter the parameter name
     * @return a new array holding the filtered values, or {@code null} when the
     *         wrapped request has no such parameter
     */
    @Override
    public String[] getParameterValues(String parameter) {
        String[] values = super.getParameterValues(parameter);
        if (values == null) {
            return null;
        }

        // Filter into a fresh array so the container's own array is never modified.
        String[] sanitized = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            sanitized[i] = stripXSS(values[i]);
        }
        return sanitized;
    }

    /**
     * Returns the named parameter with script fragments removed.
     *
     * @param parameter the parameter name
     * @return the filtered value, or {@code null} when the parameter is absent
     */
    @Override
    public String getParameter(String parameter) {
        return stripXSS(super.getParameter(parameter));
    }

    /**
     * Returns the named header with script fragments removed.
     *
     * @param name the header name
     * @return the filtered value, or {@code null} when the header is absent
     */
    @Override
    public String getHeader(String name) {
        return stripXSS(super.getHeader(name));
    }

    /**
     * Removes NUL characters and every match of {@link #XSS_PATTERNS} from the value.
     *
     * @param value the raw value; may be {@code null}
     * @return the filtered value, or {@code null} when {@code value} is {@code null}
     */
    private static String stripXSS(String value) {
        if (value == null) {
            return null;
        }

        // NOTE: It's highly recommended to use the ESAPI library and enable the following line to
        // avoid encoded attacks.
        // value = ESAPI.encoder().canonicalize(value);

        // Drop NUL characters first: left in place they can split a keyword
        // (e.g. "<scr\0ipt>") so that none of the patterns below match it.
        // A literal replace is used; replaceAll("", "") matches only the empty
        // string and therefore removes nothing.
        String cleaned = value.replace("\0", "");

        // Remove all sections that match a pattern
        for (Pattern pattern : XSS_PATTERNS) {
            cleaned = pattern.matcher(cleaned).replaceAll("");
        }
        return cleaned;
    }
}