Summary Table

Categories Total Count
PII 0
URL 0
DNS 0
EKL 0
IP 0
PORT 0
VsID 0
CF 0
AI 0
VPD 0
PL 0
Other 0

File Content

package gov.va.med.ars.configuration.spring;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.lang.StringUtils;
import org.apache.http.entity.ContentType;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Request wrapper that buffers the request body so it can be read again downstream and
 * screens request data for a fixed set of cross-site-scripting (XSS) patterns.
 *
 * <ul>
 *   <li>On construction the whole body is read into memory. When the request is
 *       {@code application/json}, every top-level value of the JSON object is checked and
 *       the constructor fails with a {@link ServletException} if one contains a suspicious
 *       fragment.</li>
 *   <li>{@link #getParameter(String)}, {@link #getParameterValues(String)},
 *       {@link #getParameterMap()} and {@link #getHeader(String)} return values with the
 *       suspicious fragments removed.</li>
 *   <li>{@link #getInputStream()} and {@link #getReader()} replay the buffered body.</li>
 * </ul>
 *
 * <p>Limitations visible in this class: JSON keys are not checked, nested JSON values are
 * only checked through their {@code toString()} form, and only {@code getHeader} (not the
 * other header accessors) is sanitized.
 *
 * <p>NOTE(review): the body is consumed through {@code getReader()} in the constructor;
 * whether the container can still derive form parameters from the body afterwards is
 * container-dependent and should be confirmed. Instances are not synchronized.
 */
public class XSSRequestWrapper extends HttpServletRequestWrapper {

    private static final Logger log = LoggerFactory.getLogger(XSSRequestWrapper.class);

    /** Media type (without parameters) whose bodies are validated as JSON. */
    private static final String JSON_MEDIA_TYPE = ContentType.APPLICATION_JSON.getMimeType();

    /** Size of the chunks used while copying the request body into memory. */
    private static final int READ_CHUNK_SIZE = 4096;

    /** Fragments treated as XSS payloads; compiled once and shared by all instances. */
    private static final Pattern[] PATTERNS = new Pattern[]{
        // Script fragments
        Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
        // Script fragments
        Pattern.compile(".*<script>(.*?)script>.*", Pattern.CASE_INSENSITIVE),

        // src='...'
        Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // lonely script tags
        Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
        Pattern.compile("<script(.*?)>", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // eval(...)
        Pattern.compile("eval\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // expression(...)
        Pattern.compile("expression\\((.*?)\\)", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL),
        // javascript:...
        Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
        // vbscript:...
        Pattern.compile("vbscript:", Pattern.CASE_INSENSITIVE),
        // onload(...)=...
        Pattern.compile("onload(.*?)=", Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL)
    };

    /** Request body as decoded by the container's reader; empty when it could not be read. */
    private final String body;

    /** Charset used to turn {@link #body} back into bytes/characters for downstream readers. */
    private final Charset bodyCharset;

    /** Lazily built, sanitized copy of the wrapped request's parameter map. */
    private Map<String, String[]> sanitizedQueryString;

    /**
     * Buffers the request body and validates it when it is JSON.
     *
     * @param servletRequest the request to wrap; the superclass constructor rejects {@code null}
     * @throws IOException declared for interface compatibility
     * @throws ServletException if a top-level JSON value contains a suspicious fragment
     */
    public XSSRequestWrapper(HttpServletRequest servletRequest) throws IOException, ServletException {
        super(servletRequest);

        this.bodyCharset = resolveCharset(servletRequest.getCharacterEncoding());
        this.body = readBody(servletRequest);

        if (StringUtils.isNotBlank(body) && isJsonContentType(servletRequest.getContentType())) {
            validateJsonBody(body);
        }
    }

    /**
     * Reads the complete request body, keeping line terminators intact. A failure while
     * reading is logged and whatever was read up to that point is returned (best effort).
     */
    private static String readBody(HttpServletRequest request) {
        StringBuilder buffer = new StringBuilder();
        try (BufferedReader reader = request.getReader()) {
            char[] chunk = new char[READ_CHUNK_SIZE];
            int count;
            while ((count = reader.read(chunk)) != -1) {
                buffer.append(chunk, 0, count);
            }
        } catch (Exception e) {
            // Deliberately broad: the request is still served, just with an empty/partial body.
            log.error("Exception in reading input", e);
        }
        return buffer.toString();
    }

    /**
     * Maps the request's declared encoding to a charset. ISO-8859-1 is the fallback because
     * it is the servlet default for bodies without a declared encoding and maps every byte
     * to exactly one character, so re-encoding restores the original bytes.
     */
    private static Charset resolveCharset(String encodingName) {
        if (encodingName != null) {
            try {
                return Charset.forName(encodingName);
            } catch (IllegalArgumentException e) {
                // Covers both illegal and unsupported charset names.
                log.warn("Unsupported request character encoding, falling back to ISO-8859-1", e);
            }
        }
        return StandardCharsets.ISO_8859_1;
    }

    /**
     * Tells whether a Content-Type header denotes JSON. Parameters such as
     * {@code charset} are ignored and the comparison is case-insensitive. The header value is
     * compared as plain text and is never interpreted as a regular expression.
     */
    private static boolean isJsonContentType(String requestContentType) {
        if (requestContentType == null) {
            return false;
        }
        int parametersStart = requestContentType.indexOf(';');
        String mediaType = parametersStart >= 0
                ? requestContentType.substring(0, parametersStart)
                : requestContentType;
        return JSON_MEDIA_TYPE.equalsIgnoreCase(mediaType.trim());
    }

    /**
     * Parses the body as a JSON object and rejects it if any top-level value looks malicious.
     *
     * <p>NOTE(review): a body that is not a JSON object (for example a top-level array or
     * malformed JSON) makes the {@code JSONObject} constructor throw; that exception is not
     * translated here, matching the previous behavior.
     *
     * @throws ServletException if a value fails {@link #isValidParam(String)}
     */
    private static void validateJsonBody(String json) throws ServletException {
        JSONObject jObj = new JSONObject(json);
        Iterator<String> keys = jObj.keys();
        while (keys.hasNext()) {
            Object value = jObj.get(keys.next());
            if (!isValidParam(value.toString())) {
                // The offending key/value is intentionally not echoed into the message.
                throw new ServletException("Unallowed parameter detected: ");
            }
        }
    }

    /**
     * Returns a fresh stream over the buffered body, encoded with the request's charset.
     * Each call starts again from the beginning of the body.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body.getBytes(bodyCharset));
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] target, int offset, int length) throws IOException {
                return source.read(target, offset, length);
            }

            @Override
            public int available() throws IOException {
                return source.available();
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking reads are not supported by this in-memory stream; the listener is ignored.
            }
        };
    }

    /** Returns a fresh reader over the buffered body, decoded with the request's charset. */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), bodyCharset));
    }

    /** Returns the wrapped request's parameter value with suspicious fragments removed. */
    @Override
    public String getParameter(String parameter) {
        String value = super.getParameter(parameter);
        return stripXSS(value);
    }

    /** Returns the wrapped request's header value with suspicious fragments removed. */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(name);
        return stripXSS(value);
    }

    /** Returns the sanitized values for {@code name}, or {@code null} if there are none. */
    @Override
    public String[] getParameterValues(String name) {
        return getParameterMap().get(name);
    }

    /** Returns the (sanitized) parameter names. */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(getParameterMap().keySet());
    }

    /**
     * Returns a sanitized copy of the wrapped request's parameter map. Both names and values
     * are passed through {@link #stripXSS(String)}; the copy is built on first use and cached.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        if (sanitizedQueryString == null) {
            Map<String, String[]> sanitized = new HashMap<String, String[]>();
            Map<String, String[]> original = super.getParameterMap();
            if (original != null) {
                for (Map.Entry<String, String[]> entry : original.entrySet()) {
                    String[] rawVals = entry.getValue();
                    String[] snzVals = new String[rawVals.length];
                    for (int i = 0; i < rawVals.length; i++) {
                        snzVals[i] = stripXSS(rawVals[i]);
                        log.debug("Sanitized: {} to {}", rawVals[i], snzVals[i]);
                    }
                    sanitized.put(stripXSS(entry.getKey()), snzVals);
                }
            }
            sanitizedQueryString = sanitized;
        }
        return sanitizedQueryString;
    }

    /**
     * Removes NUL characters and every fragment matching one of {@link #PATTERNS}.
     *
     * @param value the text to clean; may be {@code null}
     * @return the cleaned text, or {@code null} when {@code value} is {@code null}
     */
    private static String stripXSS(String value) {
        if (value == null) {
            return null;
        }
        // NOTE: It's highly recommended to use the ESAPI library and canonicalize the value
        // first (ESAPI.encoder().canonicalize(value)) to avoid encoded attacks.

        // Avoid null characters
        String cleaned = value.replace("\0", "");

        // Remove all sections that match a pattern
        for (Pattern scriptPattern : PATTERNS) {
            cleaned = scriptPattern.matcher(cleaned).replaceAll("");
        }
        return cleaned;
    }

    /**
     * Checks a value for suspicious fragments without modifying it.
     *
     * @param value the text to inspect; {@code null} is considered valid
     * @return {@code false} if any of {@link #PATTERNS} occurs anywhere in the value
     */
    private static boolean isValidParam(String value) {
        if (value == null) {
            return true;
        }
        // NOTE: It's highly recommended to use the ESAPI library and canonicalize the value
        // first (ESAPI.encoder().canonicalize(value)) to avoid encoded attacks.

        // Avoid null characters
        String candidate = value.replace("\0", "");

        for (Pattern scriptPattern : PATTERNS) {
            // find(), not matches(): a payload embedded in longer text must be detected too.
            if (scriptPattern.matcher(candidate).find()) {
                return false;
            }
        }
        return true;
    }
}