

package gov.va.med.cds.ars.saml;


import java.util.Collections;
import java.util.Set;

import javax.xml.namespace.QName;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPHeader;
import javax.xml.ws.ProtocolException;
import javax.xml.ws.handler.MessageContext;
import javax.xml.ws.handler.soap.SOAPMessageContext;

import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.wss4j.common.util.DOM2Writer;
import org.w3c.dom.Element;
import org.w3c.dom.Node;

import gov.va.med.cds.saml.SamlAssertionThreadLocal;
import gov.va.med.cds.saml.SamlAssertionXml;


/**
 * SAML SOAP inbound handler used by JAX-WS services (non-CXF services) to
 * extract a SAML assertion from the header of an inbound SOAP envelope and
 * publish its serialized XML through {@link SamlAssertionThreadLocal}.
 * <p>
 * The element children of every SOAP header block (for example the
 * WS-Security {@code Security} header) are examined in document order. The
 * first element whose local name is {@code Assertion} and whose namespace is
 * {@link #SAML_NS} or {@link #SAML2_NS} is used. Non-element nodes
 * (whitespace text, comments) between elements are skipped.
 * 
 * @author DNS   TALBOM
 *
 */
public class SoapSamlHandler
    implements
        javax.xml.ws.handler.soap.SOAPHandler<SOAPMessageContext>
{

    /** SAML 1.0 assertion namespace URI accepted by this handler. */
    public static final String SAML_NS = "urn:oasis:names:tc:SAML:1.0:assertion";
    /** SAML 2.0 assertion namespace URI accepted by this handler. */
    public static final String SAML2_NS = "urn:oasis:names:tc:SAML:2.0:assertion";

    /** Local name of the assertion element for both accepted namespaces. */
    private static final String ASSERTION_LOCAL_NAME = "Assertion";


    /**
     * Returns the header blocks this handler declares it processes.
     *
     * @return an empty set; this handler does not claim any header block
     */
    @Override
    public Set<QName> getHeaders( )
    {
        return Collections.emptySet();
    }


    /**
     * Called at the end of the message exchange; this handler holds no
     * per-exchange resources of its own.
     *
     * @param mc the message context
     */
    @Override
    public void close( MessageContext mc )
    {
        // NOTE(review): this would be a natural place to clear
        // SamlAssertionThreadLocal if its API supports it -- confirm.
    }


    /**
     * Lets faults pass through unchanged.
     *
     * @param mc the SOAP message context
     * @return always {@code true} so fault processing continues
     */
    @Override
    public boolean handleFault( SOAPMessageContext mc )
    {
        return true;
    }


    /**
     * Extracts a SAML assertion from an inbound message and stores its XML in
     * {@link SamlAssertionThreadLocal}. Outbound messages are passed through
     * untouched, as are messages that carry no header or no assertion.
     *
     * @param mc the SOAP message context
     * @return always {@code true} so handler-chain processing continues
     * @throws ProtocolException wrapping any error raised while reading the
     *         envelope or serializing the assertion
     */
    @Override
    public boolean handleMessage( SOAPMessageContext mc )
    {
        // Only inbound messages are inspected; a missing outbound property is
        // not Boolean.FALSE and is therefore also skipped.
        if ( !Boolean.FALSE.equals( mc.get( MessageContext.MESSAGE_OUTBOUND_PROPERTY ) ) )
        {
            return true;
        }

        try
        {
            Element assertion = findAssertion( mc );
            if ( assertion != null )
            {
                SamlAssertionThreadLocal.set( new SamlAssertionXml( assertionToString( assertion ) ) );
            }
            // NOTE(review): when no assertion is found, a value left in
            // SamlAssertionThreadLocal by an earlier request on this (possibly
            // pooled) thread is not cleared here -- confirm it is reset elsewhere.
        }
        catch ( Exception e )
        {
            throw new ProtocolException( e );
        }

        return true;
    }


    /**
     * Locates the first SAML assertion that is a direct child of any SOAP
     * header block, scanning header blocks and their children in document
     * order.
     *
     * @param mc the SOAP message context
     * @return the assertion element, or {@code null} if the message has no
     *         header or no header block contains an assertion
     * @throws SOAPException if the SOAP part or envelope cannot be read
     */
    private static Element findAssertion( SOAPMessageContext mc ) throws SOAPException
    {
        // Read the header via the envelope so that a message without a header
        // yields null ("no assertion") instead of failing the request.
        SOAPHeader header = mc.getMessage().getSOAPPart().getEnvelope().getHeader();
        if ( header == null )
        {
            return null;
        }

        for ( Element block = getFirstElement( header ); block != null; block = getNextElement( block ) )
        {
            for ( Element child = getFirstElement( block ); child != null; child = getNextElement( child ) )
            {
                if ( isSamlAssertion( child ) )
                {
                    return child;
                }
            }
        }

        return null;
    }


    /**
     * Tests whether an element is a SAML 1.0 or SAML 2.0 {@code Assertion}.
     *
     * @param element the element to test; must not be {@code null}
     * @return {@code true} if the local name and namespace match
     */
    private static boolean isSamlAssertion( Element element )
    {
        String ns = element.getNamespaceURI();
        return ASSERTION_LOCAL_NAME.equals( element.getLocalName() )
                && ( SAML_NS.equals( ns ) || SAML2_NS.equals( ns ) );
    }


    /**
     * Returns the first child of {@code parent} that is an element node,
     * skipping text, comment and other non-element children.
     *
     * @param parent the node whose children are scanned
     * @return the first element child, or {@code null} if there is none
     */
    protected static Element getFirstElement( Node parent )
    {
        return skipToElement( parent.getFirstChild() );
    }


    /**
     * Returns the next sibling of {@code node} that is an element node.
     *
     * @param node the node whose following siblings are scanned
     * @return the next element sibling, or {@code null} if there is none
     */
    private static Element getNextElement( Node node )
    {
        return skipToElement( node.getNextSibling() );
    }


    /**
     * Walks forward from {@code start} (inclusive) through its siblings until
     * an element node is found.
     *
     * @param start the first node to examine; may be {@code null}
     * @return the first element node reached, or {@code null}
     */
    private static Element skipToElement( Node start )
    {
        Node n = start;
        while ( ( n != null ) && ( Node.ELEMENT_NODE != n.getNodeType() ) )
        {
            n = n.getNextSibling();
        }

        return ( Element )n;
    }


    /**
     * Serializes an assertion element to an XML string using
     * {@link DOM2Writer#nodeToString(Node)}.
     *
     * @param assertionElement the element to serialize; may be {@code null}
     * @return the serialized XML, or an empty string if the element is
     *         {@code null}
     * @throws WSSecurityException declared for callers; not thrown by the
     *         visible body
     */
    protected static String assertionToString( Element assertionElement ) throws WSSecurityException
    {
        if ( assertionElement == null )
        {
            return "";
        }

        return DOM2Writer.nodeToString( assertionElement );
    }

}
