/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package gov.va.med.nhin.adapter.logging.messaging.fhir;

import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.parser.IParser;
import gov.hhs.fha.nhinc.common.nhinccommon.AssertionType;
import gov.hhs.fha.nhinc.common.nhinccommon.CeType;
import gov.va.med.nhin.adapter.datamanager.DataQuery;
import gov.va.med.nhin.adapter.logging.EventAuditingFactory;
import gov.va.med.nhin.adapter.logging.LogConstants;
import gov.va.med.nhin.adapter.logging.LogConstants.AuditingEvent;
import gov.va.med.nhin.adapter.logging.MessagingHelper;
import static gov.va.med.nhin.adapter.logging.MessagingHelper.REQUEST;
import gov.va.med.nhin.adapter.utils.AuditUtil;
import gov.va.med.nhin.adapter.utils.NullChecker;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Date;
import javax.servlet.http.HttpServletRequest;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.ws.handler.MessageContext;
import javax.xml.ws.handler.soap.SOAPMessageContext;
import org.hl7.fhir.dstu3.model.AuditEvent;
import org.hl7.fhir.dstu3.model.AuditEvent.AuditEventAgentComponent;
import org.hl7.fhir.dstu3.model.AuditEvent.AuditEventAgentNetworkComponent;
import org.hl7.fhir.dstu3.model.AuditEvent.AuditEventEntityComponent;
import org.hl7.fhir.dstu3.model.AuditEvent.AuditEventEntityDetailComponent;
import org.hl7.fhir.dstu3.model.AuditEvent.AuditEventSourceComponent;
import org.hl7.fhir.dstu3.model.Base64BinaryType;
import org.hl7.fhir.dstu3.model.CodeableConcept;
import org.hl7.fhir.dstu3.model.Coding;
import org.hl7.fhir.dstu3.model.Identifier;
import org.hl7.fhir.dstu3.model.InstantType;
import org.hl7.fhir.dstu3.model.Reference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
 * {@link EventAuditingFactory} that models audit records as FHIR DSTU3
 * {@link AuditEvent} resources and emits them, JSON-encoded, through slf4j.
 *
 * <p>
 * The logger used for an event is named after the category that
 * {@code AuditingEvent.fromCode} reports for the event's first subtype code.
 * Correlation id, remote/local address and endpoint are read from the slf4j
 * {@link MDC} when they are present there.
 *
 * @author ryan
 */
public class FHIREventAuditingFactory implements EventAuditingFactory<AuditEvent> {

  /**
   * Element id given to the agent that carries the remote address, so that it
   * can be found again once an assertion (and thus a system id) is available.
   */
  private static final String SOURCE_AGENT_ID = "sourceagent";

  private final FhirContext ctx = FhirContext.forDstu3();

  /**
   * Adds a "response time" entity holding the milliseconds elapsed since the
   * event's recorded timestamp. Nothing is added when the event has no
   * recorded timestamp or when the elapsed time is not greater than 1 ms.
   *
   * @param event the event to annotate
   */
  private static void setDurationEntity(AuditEvent event) {
    Date start = event.getRecorded();
    if (null == start) {
      // events not built by newEvent() may lack a timestamp; nothing to measure
      return;
    }

    long duration = System.currentTimeMillis() - start.getTime();
    if (duration > 1) {
      AuditEventEntityComponent cmp = event.addEntity();
      cmp.setType(LogConstants.SYSTEM_OBJ);
      cmp.setRole(LogConstants.RESPONSE_TIME);
      cmp.setDescription(String.valueOf(duration));
    }
  }

  /**
   * Resolves the logger for the event's category and stamps the event with its
   * duration entity. The duration entity is added regardless of whether the
   * logger will actually emit the event.
   *
   * @param event the event about to be logged
   * @return the logger named after the event's category
   */
  private Logger prepare(AuditEvent event) {
    Logger log = LoggerFactory.getLogger(getCategory(event));
    setDurationEntity(event);
    return log;
  }

  /**
   * Serializes the event to JSON. A parser is created per call instead of
   * being shared between callers; see the HAPI documentation on parser thread
   * safety.
   *
   * @param event the event to encode
   * @return the JSON representation of the event
   */
  private String encode(AuditEvent event) {
    IParser parser = ctx.newJsonParser();
    return parser.encodeResourceToString(event);
  }

  /**
   * Logs the event at TRACE level on the logger for its category.
   *
   * @param event the event to log; a duration entity is added to it
   */
  @Override
  public void trace(AuditEvent event) {
    Logger log = prepare(event);
    if (log.isTraceEnabled()) {
      log.trace(encode(event));
    }
  }

  /**
   * Logs the event at DEBUG level on the logger for its category.
   *
   * @param event the event to log; a duration entity is added to it
   */
  @Override
  public void debug(AuditEvent event) {
    Logger log = prepare(event);
    if (log.isDebugEnabled()) {
      log.debug(encode(event));
    }
  }

  /**
   * Logs the event at INFO level on the logger for its category.
   *
   * @param event the event to log; a duration entity is added to it
   */
  @Override
  public void info(AuditEvent event) {
    Logger log = prepare(event);
    if (log.isInfoEnabled()) {
      log.info(encode(event));
    }
  }

  /**
   * Logs the event at WARN level on the logger for its category.
   *
   * @param event the event to log; a duration entity is added to it
   */
  @Override
  public void warn(AuditEvent event) {
    Logger log = prepare(event);
    if (log.isWarnEnabled()) {
      log.warn(encode(event));
    }
  }

  /**
   * Logs the event at ERROR level on the logger for its category.
   *
   * @param event the event to log; a duration entity is added to it
   */
  @Override
  public void error(AuditEvent event) {
    Logger log = prepare(event);
    if (log.isErrorEnabled()) {
      log.error(encode(event));
    }
  }

  /**
   * Determines the logging category of an event from the code of its first
   * subtype.
   *
   * @param ae the event to categorize
   * @return the category of the matching {@link AuditingEvent}
   */
  public String getCategory(AuditEvent ae) {
    // NOTE(review): assumes fromCode() never returns null for an unknown code
    String code = ae.getSubtypeFirstRep().getCode();
    return AuditingEvent.fromCode(code).category;
  }

  /** Creates a new event for {@code c} and logs it at TRACE level. */
  @Override
  public void trace(AuditingEvent c, Class<?> klass) {
    trace(newEvent(c, klass));
  }

  /** Creates a new event for {@code c} and logs it at DEBUG level. */
  @Override
  public void debug(AuditingEvent c, Class<?> klass) {
    debug(newEvent(c, klass));
  }

  /** Creates a new event for {@code c} and logs it at INFO level. */
  @Override
  public void info(AuditingEvent c, Class<?> klass) {
    info(newEvent(c, klass));
  }

  /** Creates a new event for {@code c} and logs it at WARN level. */
  @Override
  public void warn(AuditingEvent c, Class<?> klass) {
    warn(newEvent(c, klass));
  }

  /** Creates a new event for {@code c} and logs it at ERROR level. */
  @Override
  public void error(AuditingEvent c, Class<?> klass) {
    error(newEvent(c, klass));
  }

  /**
   * Creates a new event including network agents when the MDC provides them.
   *
   * @param c the kind of event to create
   * @param source the class reporting the event
   * @return the new event
   */
  @Override
  public AuditEvent newEvent(AuditingEvent c, Class<?> source) {
    return newEvent(c, true, source);
  }

  /**
   * Creates a new event stamped with the current time, the type, subtype and
   * action of {@code c}, and {@code source} as its site. A correlation entity
   * is added when the MDC holds a correlation id.
   *
   * @param c the kind of event to create
   * @param includeNetworkIfAvail whether to add source/destination agents from
   * the addresses and endpoint stored in the MDC
   * @param source the class reporting the event
   * @return the new event
   */
  @Override
  public AuditEvent newEvent(AuditingEvent c, boolean includeNetworkIfAvail,
          Class<?> source) {
    AuditEventSourceComponent srccmp = new AuditEventSourceComponent();
    // anonymous and local classes have no canonical name; use the binary name
    String site = source.getCanonicalName();
    srccmp.setSite(null != site ? site : source.getName());

    AuditEvent ae = new AuditEvent(c.type, new InstantType(new Date()), srccmp);
    ae.addSubtype(c.coding);
    ae.setAction(c.action);

    addCorrelationEntity(ae);
    if (includeNetworkIfAvail) {
      addSourceAgent(ae);
      addDestinationAgent(ae);
    }

    return ae;
  }

  /** Adds an entity identifying the MDC correlation id, if one is set. */
  private static void addCorrelationEntity(AuditEvent ae) {
    String txn = MDC.get(LogConstants.CORRELATION_MDCKEY);
    if (null == txn) {
      return;
    }

    AuditEventEntityComponent aeec = ae.addEntity();
    aeec.setType(LogConstants.SYSTEM_OBJ);
    aeec.setRole(LogConstants.CORRELATION_CODING);

    Identifier id = new Identifier();
    id.setValue(txn);
    id.setType(LogConstants.CORRELATION_CONCEPT);
    aeec.setIdentifier(id);
  }

  /** Adds the source agent carrying the MDC remote address, if one is set. */
  private static void addSourceAgent(AuditEvent ae) {
    String remaddr = MDC.get(LogConstants.REMOTE_ADDR);
    if (null == remaddr) {
      return;
    }

    AuditEventAgentComponent aeac = ae.addAgent();
    aeac.setId(SOURCE_AGENT_ID);
    aeac.addRole(LogConstants.SOURCE_CONCEPT);
    aeac.setRequestor(false);

    AuditEventAgentNetworkComponent net = new AuditEventAgentNetworkComponent();
    net.setAddress(remaddr);
    aeac.setNetwork(net);
  }

  /**
   * Adds the destination agent when the MDC holds a local address, an
   * endpoint, or both.
   */
  private static void addDestinationAgent(AuditEvent ae) {
    String laddr = MDC.get(LogConstants.LOCAL_ADDR);
    String endpoint = MDC.get(LogConstants.ENDPOINT);
    if (null == endpoint && null == laddr) {
      return;
    }

    AuditEventAgentComponent aeac = ae.addAgent();
    aeac.setRequestor(false);
    aeac.addRole(LogConstants.DEST_CONCEPT);

    if (null != endpoint) {
      Identifier id = new Identifier();
      id.setValue(endpoint);
      aeac.setUserId(id);
    }

    if (null != laddr) {
      AuditEventAgentNetworkComponent net = new AuditEventAgentNetworkComponent();
      net.setAddress(laddr);
      aeac.setNetwork(net);
    }
  }

  /**
   * Creates an INFO event for an incoming SOAP message. When the message
   * context exposes the servlet request, the remote address, local address
   * and endpoint address are first stored in the MDC.
   *
   * @param msgctx the SOAP message context of the current exchange
   * @param source the class reporting the event
   * @return the new event
   */
  @Override
  public AuditEvent newEvent(SOAPMessageContext msgctx, Class<?> source) {
    Object obj = msgctx.get(MessageContext.SERVLET_REQUEST);
    HttpServletRequest req = HttpServletRequest.class.cast(obj);

    if (null != req) {
      putOrRemove(LogConstants.REMOTE_ADDR, req.getRemoteAddr());
      putOrRemove(LogConstants.ENDPOINT, findEndpointAddress(msgctx));
      putOrRemove(LogConstants.LOCAL_ADDR, req.getLocalAddr());
    }

    return newEvent(AuditingEvent.INFO, true, source);
  }

  /**
   * Looks up the endpoint address in the message context, trying the JAX-WS
   * endpoint address property first and the servlet request URL second.
   *
   * @return the first non-empty value found, or {@code null} if there is none
   */
  private static String findEndpointAddress(SOAPMessageContext msgctx) {
    for (String prop : Arrays.asList("javax.xml.ws.service.endpoint.address",
            "com.sun.xml.ws.transport.http.servlet.requestURL")) {
      Object o = msgctx.get(prop);
      if (NullChecker.isNotNullOrEmpty(o)) {
        return o.toString();
      }
    }
    return null;
  }

  /**
   * Stores {@code value} in the MDC, or clears the key when the value is
   * {@code null}. Not every MDC backend accepts null values, and clearing also
   * prevents a value left over from an earlier request from being reused.
   */
  private static void putOrRemove(String key, String value) {
    if (null == value) {
      MDC.remove(key);
    } else {
      MDC.put(key, value);
    }
  }

  /**
   * Converts a coded element into a {@link CodeableConcept} with a single
   * coding, copying only the parts that are set.
   */
  private static CodeableConcept toCc(CeType type) {
    CodeableConcept cc = new CodeableConcept();
    Coding coding = cc.addCoding();
    if (null != type.getCode()) {
      coding.setCode(type.getCode());
    }
    if (null != type.getCodeSystem()) {
      coding.setSystem(type.getCodeSystem());
    }
    if (null != type.getDisplayName()) {
      coding.setDisplay(type.getDisplayName());
    }
    if (null != type.getCodeSystemVersion()) {
      coding.setVersion(type.getCodeSystemVersion());
    }

    return cc;
  }

  /**
   * Creates a new event and enriches it with what the assertion provides: the
   * system id of the source agent, a requesting-user agent (home community,
   * user name, role, purpose of use), one patient entity per unique patient
   * id, an SSN entity, and a query entity naming the home community id.
   *
   * @param assertion the assertion of the request; when {@code null} the plain
   * event is returned
   * @param code the kind of event to create
   * @param source the class reporting the event
   * @return the new event
   */
  @Override
  public AuditEvent newEvent(AssertionType assertion, AuditingEvent code, Class<?> source) {
    AuditEvent ae = newEvent(code, source);
    if (null == assertion) {
      return ae;
    }

    for (AuditEventAgentComponent aeac : ae.getAgent()) {
      // set the system id for our sourceagent element (we can't get the
      // system id from assertions at the time we create the agent)
      if (SOURCE_AGENT_ID.equals(aeac.getId())) {
        aeac.setAltId(AuditUtil.checkSystemId(assertion));
      }
    }

    addRequestingUserAgent(ae, assertion);
    addPatientEntities(ae, assertion);
    addSsnEntity(ae, assertion);
    addHomeCommunityEntity(ae, assertion);
    return ae;
  }

  /**
   * Adds the agent describing the requesting user. Location, requestor flag,
   * user name and role are filled in only when the assertion has a home
   * community; the purpose of use is added whenever the assertion has one.
   */
  private static void addRequestingUserAgent(AuditEvent ae, AssertionType assertion) {
    AuditEventAgentComponent agent = ae.addAgent();

    if (NullChecker.isNotNullOrEmpty(assertion.getHomeCommunity())) {
      Identifier hcid = new Identifier();
      hcid.setValue(assertion.getHomeCommunity().getHomeCommunityId());
      Reference userhcid = new Reference();
      userhcid.setIdentifier(hcid);
      agent.setLocation(userhcid);

      agent.setRequestor(true);
      if (null != assertion.getUserInfo()) {
        if (null != assertion.getUserInfo().getUserName()) {
          Identifier id = new Identifier();
          id.setValue(assertion.getUserInfo().getUserName());
          agent.setUserId(id);
        }

        if (null != assertion.getUserInfo().getRoleCoded()) {
          agent.addRole(toCc(assertion.getUserInfo().getRoleCoded()));
        }
      }
    }

    if (null != assertion.getPurposeOfDisclosureCoded()) {
      agent.setPurposeOfUse(Arrays.asList(toCc(assertion.getPurposeOfDisclosureCoded())));
    }
  }

  /**
   * Adds one patient entity per unique patient id in the assertion. The
   * identifier system is the home community id of the user's organization,
   * when the assertion carries one.
   */
  private static void addPatientEntities(AuditEvent ae, AssertionType assertion) {
    for (String patientid : assertion.getUniquePatientId()) {
      AuditEventEntityComponent entity = ae.addEntity();
      entity.setType(LogConstants.PERSON);
      entity.setRole(LogConstants.PATIENT);

      Identifier id = new Identifier();
      id.setUse(Identifier.IdentifierUse.USUAL);
      id.setValue(patientid);
      if (null != assertion.getUserInfo()
              && null != assertion.getUserInfo().getOrg()) {
        id.setSystem(assertion.getUserInfo().getOrg().getHomeCommunityId());
      }
      entity.setIdentifier(id);
    }
  }

  /** Adds a patient entity identified by SSN when the assertion has one. */
  private static void addSsnEntity(AuditEvent ae, AssertionType assertion) {
    if (!NullChecker.isNotNullOrEmpty(assertion.getSSN())) {
      return;
    }

    AuditEventEntityComponent entity = ae.addEntity();
    entity.setType(LogConstants.PERSON);
    entity.setRole(LogConstants.PATIENT);

    Identifier id = new Identifier();
    id.setValue(assertion.getSSN());
    id.setSystem("http://hl7.org/fhir/sid/us-ssn");
    id.setUse(Identifier.IdentifierUse.SECONDARY);
    entity.setIdentifier(id);
  }

  /**
   * Adds a query entity whose detail names the assertion's home community id,
   * when the assertion has one.
   */
  private static void addHomeCommunityEntity(AuditEvent ae, AssertionType assertion) {
    if (null == assertion.getHomeCommunity()
            || null == assertion.getHomeCommunity().getHomeCommunityId()) {
      return;
    }

    AuditEventEntityComponent entity = ae.addEntity();
    entity.setType(LogConstants.SYSTEM_OBJ);
    entity.setRole(LogConstants.QUERY);
    AuditEventEntityDetailComponent cmp = entity.addDetail();
    cmp.setType("ihe:homeCommunityId");
    // NOTE(review): Base64BinaryType(String) may interpret its argument as
    // already base64-encoded text rather than raw content - confirm against
    // the HAPI version in use.
    cmp.setValueElement(new Base64BinaryType(assertion.getHomeCommunity().getHomeCommunityId()));
  }

  /**
   * Adds an agent named {@code data} whose role is the coding
   * {@code systemuri}/{@code codename}.
   *
   * @return the same event, for chaining
   */
  @Override
  public AuditEvent addAgent(AuditEvent event, String systemuri, String codename,
          String data) {
    AuditEventAgentComponent aeac = event.addAgent();
    aeac.setName(data);
    CodeableConcept cc = aeac.addRole();
    Coding c = cc.addCoding();
    c.setSystem(systemuri);
    c.setCode(codename);
    return event;
  }

  /**
   * Adds an entity named {@code data} whose type is the coding
   * {@code systemuri}/{@code codename}.
   *
   * @return the same event, for chaining
   */
  @Override
  public AuditEvent addEntity(AuditEvent event, String systemuri, String codename,
          String data) {
    AuditEventEntityComponent aeec = event.addEntity();
    aeec.setType(new Coding(systemuri, codename, null));
    aeec.setName(data);
    return event;
  }

  /**
   * Returns a helper that builds messaging-related events on top of this
   * factory. A new helper instance is created on every call.
   */
  @Override
  public MessagingHelper<AuditEvent> messaging() {
    return new MessagingHelper<AuditEvent>() {
      /** Creates an event with an entity naming the authenticated partner. */
      @Override
      public AuditEvent partnerauth(AuditingEvent evt, Class source, String partner) {
        return addEntity(newEvent(evt, source), LogConstants.MESSAGING_URI,
                PARTNER, partner);
      }

      /** Creates an event with an agent naming the request/response partner. */
      @Override
      public AuditEvent reqres(AuditingEvent evt, Class source, String partner) {
        return addAgent(newEvent(evt, source), LogConstants.MESSAGING_URI,
                REQUEST, partner);
      }

      /** Adds an entity naming the patient id. */
      @Override
      public AuditEvent addPatientId(AuditEvent evt, String pid) {
        return addEntity(evt, LogConstants.MESSAGING_URI, PATIENT, pid);
      }

      /** Adds one entity for the local patient id and one for the remote. */
      @Override
      public AuditEvent addLocalRemotePatientIds(AuditEvent event, String local, String remote) {
        AuditEvent e = addEntity(event, LogConstants.MESSAGING_URI, LOCALPATIENT,
                local);
        return addEntity(e, LogConstants.MESSAGING_URI, REMOTEPATIENT, remote);
      }

      /** Adds an agent naming the facility. */
      @Override
      public AuditEvent addAgentFacility(AuditEvent event, String facility) {
        return addAgent(event, LogConstants.MESSAGING_URI, FACILITY, facility);
      }

      /**
       * Adds a detail to every entity of the event whose element id equals
       * {@code id}. Nothing happens when {@code id} is {@code null} or no
       * entity matches.
       */
      @Override
      public AuditEvent addDetail(AuditEvent ae, String id, String type, String value) {
        if (null != id) {
          for (AuditEventEntityComponent aeec : ae.getEntity()) {
            if (id.equals(aeec.getId())) {
              AuditEventEntityDetailComponent ad = aeec.addDetail();
              ad.setType(type);
              // NOTE(review): Base64BinaryType(String) may treat the value as
              // base64 text rather than raw content - confirm.
              ad.setValueElement(new Base64BinaryType(value));
            }
          }
        }

        return ae;
      }

      /**
       * Adds a query entity with one detail per parameter of the data query.
       */
      @Override
      public AuditEvent addQuery(AuditEvent ae, DataQuery obj, QueryType t) {
        AuditEventEntityComponent aeec = addQueryComponent(ae, t);

        for (String param : obj.getParameterNames()) {
          AuditEventEntityDetailComponent ad = aeec.addDetail();
          ad.setType(param);
          // NOTE(review): same Base64BinaryType(String) question as addDetail
          ad.setValueElement(new Base64BinaryType(String.valueOf(obj.getParameter(param))));
        }
        return ae;
      }

      /**
       * Adds a system-object entity whose role marks it as a query or as a
       * query response, depending on {@code t}.
       */
      private AuditEventEntityComponent addQueryComponent(AuditEvent ae, QueryType t) {
        AuditEventEntityComponent aeec = ae.addEntity();
        aeec.setType(LogConstants.SYSTEM_OBJ);
        aeec.setRole(QueryType.REQUEST == t
                ? LogConstants.QUERY
                : LogConstants.QUERY_RESPONSE);
        return aeec;
      }

      /**
       * Adds a query entity with element id {@code id} whose query content is
       * {@code obj}: the string itself, or otherwise its JAXB XML form. The
       * content is stored as UTF-8 bytes. When {@code obj} is {@code null} or
       * cannot be marshalled, the entity is added without query content.
       *
       * @param klass declared type used to wrap {@code obj} when it has no
       * {@code @XmlRootElement}; the runtime class is used if this is null
       */
      @Override
      public AuditEvent addQuery(AuditEvent ae, Object obj, Class<?> klass,
              QueryType t, String id) {
        AuditEventEntityComponent aeec = addQueryComponent(ae, t);
        aeec.setId(id);

        String content = obj instanceof String
                ? (String) obj
                : (null == obj ? null : marshalToXml(obj, klass));
        if (null != content) {
          aeec.setQueryElement(
                  new Base64BinaryType(content.getBytes(StandardCharsets.UTF_8)));
        }

        return ae;
      }

      /**
       * Marshals {@code obj} to XML with JAXB, retrying inside a
       * {@link JAXBElement} wrapper when the direct attempt fails.
       *
       * @return the XML, or {@code null} when marshalling is not possible
       * (the failure is logged)
       */
      private String marshalToXml(Object obj, Class<?> klass) {
        Logger log = LoggerFactory.getLogger(getClass());

        Marshaller marshaller;
        try {
          JAXBContext jaxbContext = JAXBContext.newInstance(obj.getClass());
          marshaller = jaxbContext.createMarshaller();
        } catch (JAXBException e) {
          log.error("unable to create context/marshaller", e);
          return null;
        }

        try {
          // If obj doesn't have the @XmlRootElement annotation, this
          // call will throw an exception
          StringWriter direct = new StringWriter();
          marshaller.marshal(obj, direct);
          return direct.toString();
        } catch (JAXBException rootFailure) {
          // Could be any exception, but check the @XmlRootElement common case
          try {
            Class<?> declared = null != klass ? klass : obj.getClass();
            @SuppressWarnings({"unchecked", "rawtypes"})
            JAXBElement ele = new JAXBElement(LogConstants.ROOTLESS_JAXB, declared, obj);
            // fresh writer: the failed attempt may have left partial output
            StringWriter wrapped = new StringWriter();
            marshaller.marshal(ele, wrapped);
            return wrapped.toString();
          } catch (JAXBException wrappedFailure) {
            log.error("unable to marshal query", wrappedFailure);
            log.debug("initial marshal attempt failed", rootFailure);
            return null;
          }
        }
      }
    };
  }
}
