package gov.va.med.cds.ws.saml.interceptor;

import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.cxf.binding.soap.SoapMessage;
import org.apache.cxf.headers.Header;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.ws.security.wss4j.SamlTokenInterceptor;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.dom.WSConstants;
import org.w3c.dom.Element;

import gov.va.med.cds.saml.SamlAssertionThreadLocal;
import gov.va.med.cds.saml.SamlAssertionXml;
import gov.va.med.cds.ws.saml.util.SamlUtility;



/**
 * Interceptor that captures the raw XML of a SAML assertion carried in the
 * {@code wsse:Security} SOAP header and stores it via {@link SamlAssertionThreadLocal}.
 *
 * <p>NOTE(review): {@link #processToken(SoapMessage)} overrides the superclass implementation
 * without calling {@code super.processToken(message)}, so whatever SAML processing
 * {@link SamlTokenInterceptor} would normally perform is not done here. Confirm this is
 * intentional.
 */
public class SAMLTokenInterceptor extends SamlTokenInterceptor {

    private static final Logger LOG = Logger.getLogger(SAMLTokenInterceptor.class.getName());

    /** Creates the interceptor using the superclass defaults. */
    public SAMLTokenInterceptor() {
        super();
    }

    /**
     * Finds the first SAML 1.x or SAML 2.0 {@code Assertion} element directly under the
     * {@code wsse:Security} header and stores its serialized XML in
     * {@link SamlAssertionThreadLocal}.
     *
     * <p>Returns without doing anything if the message has no security header or the header
     * contains no SAML assertion. If serializing the assertion fails, the failure is logged and
     * the thread local is not updated; no fault is raised.
     *
     * @param message the SOAP message being intercepted
     */
    @Override
    protected void processToken(SoapMessage message) {
        // Look up the existing wsse:Security header; 'false' means do NOT create one if absent.
        Header securityHeader = findSecurityHeader(message, false);
        if (securityHeader == null) {
            return;
        }

        Element assertion = findSamlAssertion((Element) securityHeader.getObject());
        if (assertion == null) {
            return;
        }

        try {
            String samlXml = SamlUtility.assertionToString(assertion);
            SamlAssertionThreadLocal.set(new SamlAssertionXml(samlXml));
        } catch (WSSecurityException e) {
            // Best effort: let the request continue, but record the failure instead of
            // printing it to stderr.
            LOG.log(Level.WARNING,
                    "Unable to serialize SAML assertion from wsse:Security header", e);
        }
        // NOTE(review): the thread local is set here but never cleared in this class. If threads
        // are reused, a later request without an assertion may still see this value. Confirm it
        // is reset elsewhere.
    }

    /**
     * Returns the first direct child of {@code securityElement} that is a SAML assertion.
     *
     * <p>Every child element is scanned rather than only the first one, because the assertion
     * is not necessarily the first token in the security header.
     *
     * @param securityElement the {@code wsse:Security} header element
     * @return the first SAML 1.x/2.0 {@code Assertion} child, or {@code null} if there is none
     */
    private static Element findSamlAssertion(Element securityElement) {
        for (Element child = DOMUtils.getFirstElement(securityElement);
                child != null;
                child = DOMUtils.getNextElement(child)) {
            if (isSamlAssertion(child)) {
                return child;
            }
        }
        return null;
    }

    /**
     * Returns whether {@code element} is an {@code Assertion} in the SAML 1.x or SAML 2.0
     * namespace.
     */
    private static boolean isSamlAssertion(Element element) {
        String namespace = element.getNamespaceURI();
        return "Assertion".equals(element.getLocalName())
                && (WSConstants.SAML_NS.equals(namespace)
                        || WSConstants.SAML2_NS.equals(namespace));
    }
}


