package com.agilex.healthcare.mobilehealthplatform.security;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

import javax.annotation.Resource;
import javax.sql.DataSource;

import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.support.SqlLobValue;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.JdbcTokenStore;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;

import com.agilex.healthcare.mobilehealthplatform.domain.MhpUser;
import com.agilex.healthcare.mobilehealthplatform.oauth.OauthJdbcClientDetailsService;
import com.agilex.security.oauth.TokenStoreLogout;

/**
 * {@link JdbcTokenStore} variant that, next to the {@code oauth_access_token} row of every access token, keeps a
 * {@code access_token_user} row mapping the token value to the identifier of the user it was issued to.
 *
 * <p>The mapping is what allows {@link #removeAllUSerTokens(Authentication)} (the {@link TokenStoreLogout} entry
 * point) to find and revoke a user's tokens. Tokens issued to one of the client ids listed in the comma-separated
 * {@code longDurableClientIds} resource are handled separately from all other tokens there.</p>
 */
public class MhpJdbcTokenStore extends JdbcTokenStore implements TokenStoreLogout {

    private static final String USER_ACCESS_TOKEN_INSERT_STATEMENT = "insert into access_token_user (user_name, access_token) values (?, ?)";
    private static final String USER_ACCESS_TOKEN_SELECT_STATEMENT = "select access_token from access_token_user where user_name = ?";
    private static final String USER_ACCESS_TOKEN_DELETE_STATEMENT = "delete from access_token_user where user_name = ?";
    private static final String USER_ACCESS_TOKEN_DELETE_SINGLE_STATEMENT = "delete from access_token_user where user_name = ? and access_token = ?";
    private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_access_token where authentication_id = ?";
    private static final String DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT = "insert into oauth_access_token (token_id, token, authentication_id, user_name, client_id, authentication, refresh_token) values (?, ?, ?, ?, ?, ?, ?)";

    private static final org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory.getLog(MhpJdbcTokenStore.class);

    private final JdbcTemplate jdbcTemplate;

    private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

    @Resource
    private ResourceServerTokenServices tokenServices;

    @Resource
    private OauthJdbcClientDetailsService oauthJdbcClientDetailsService;

    /** Comma-separated list of client ids whose tokens are treated as "long durable". */
    @Resource
    private String longDurableClientIds;

    /**
     * Creates the store on top of the given data source.
     *
     * @param dataSource data source used both by the parent store and by this class's own statements
     */
    public MhpJdbcTokenStore(DataSource dataSource) {
        super(dataSource);
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Replaces the generator used by this class to derive the {@code authentication_id} column value.
     *
     * @param authenticationKeyGenerator generator to use from now on
     */
    public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
        this.authenticationKeyGenerator = authenticationKeyGenerator;
    }

    /**
     * Stores the access token in {@code oauth_access_token} and, when the authentication carries a user, records
     * the user/token mapping in {@code access_token_user}.
     *
     * @param token          access token to persist
     * @param authentication authentication the token was issued for
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        // A client-only authentication has no user authentication; there is no user to map the token to.
        Authentication userAuthentication = authentication.getUserAuthentication();
        if (userAuthentication != null) {
            addToken(getCurrentUserId(userAuthentication), token.getValue());
        }

        String refreshToken = null;
        if (token.getRefreshToken() != null) {
            refreshToken = token.getRefreshToken().getValue();
        }

        jdbcTemplate.update(DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT, new Object[] { extractTokenKey(token.getValue()),
                new SqlLobValue(serializeAccessToken(token)), authenticationKeyGenerator.extractKey(authentication),
                authentication.isClientOnly() ? null : getCurrentUserId(authentication),
                authentication.getAuthorizationRequest().getClientId(),
                new SqlLobValue(serializeAuthentication(authentication)), extractTokenKey(refreshToken) }, new int[] {
                Types.VARCHAR, Types.BLOB, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR, Types.BLOB, Types.VARCHAR });
    }

    /**
     * Looks up the access token stored for the given authentication.
     *
     * <p>When a token is found it is removed and stored again with the supplied authentication, so the persisted
     * authentication details match the current ones.</p>
     *
     * @param authentication authentication whose token is wanted
     * @return the stored token, or {@code null} when none is found or the stored value cannot be deserialized
     */
    @Override
    public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
        OAuth2AccessToken accessToken = null;

        String key = authenticationKeyGenerator.extractKey(authentication);
        try {
            accessToken = jdbcTemplate.queryForObject(DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT,
                    new RowMapper<OAuth2AccessToken>() {
                        public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
                            return deserializeAccessToken(rs.getBytes(2));
                        }
                    }, key);
        }
        catch (EmptyResultDataAccessException e) {
            // Not finding a token is an expected outcome; the guard matches the level actually logged at.
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to find access token for authentication");
            }
        }
        catch (IllegalArgumentException e) {
            logger.error("Could not extract access token for authentication", e);
        }

        if (accessToken != null) {
            removeAccessToken(accessToken.getValue());
            // Keep the store consistent (maybe the same user is represented by this authentication but the details
            // have changed)
            storeAccessToken(accessToken, authentication);
        }
        return accessToken;
    }

    /**
     * Records that {@code accessToken} belongs to {@code userName}.
     *
     * <p>{@link #getAccessToken(OAuth2Authentication)} re-stores tokens that are already mapped, so any existing
     * row for the same pair is deleted first to keep exactly one mapping per user and token.</p>
     */
    private void addToken(String userName, String accessToken) {
        deleteUserToken(userName, accessToken);
        jdbcTemplate.update(
                USER_ACCESS_TOKEN_INSERT_STATEMENT,
                new Object[] { userName, accessToken },
                new int[] { Types.VARCHAR, Types.VARCHAR });
    }

    /**
     * Revokes access tokens of the user behind the given authentication.
     *
     * <ul>
     * <li>If the authentication's client id is one of the configured long-durable client ids, only the user's
     * tokens issued to that client are removed.</li>
     * <li>Otherwise every token of the user is removed except those issued to a long-durable client. When no
     * long-durable client ids are configured this removes all of the user's tokens.</li>
     * </ul>
     *
     * @param authentication must be an {@link OAuth2Authentication} that carries a user authentication
     */
    public void removeAllUSerTokens(Authentication authentication) {
        OAuth2Authentication oauth2Authentication = (OAuth2Authentication) authentication;
        String userName = getCurrentUserId(oauth2Authentication.getUserAuthentication());
        logger.info("Removing access tokens");

        List<String> accessTokens = fetchUserAccessTokens(userName);
        if (accessTokens == null) {
            return;
        }

        // Decide once, based on membership in the configured list, which of the two clean-ups applies.
        String currentClientId = oauth2Authentication.getAuthorizationRequest().getClientId();
        if (isLongDurableClientId(currentClientId)) {
            deleteOnlyLongDurableToken(authentication, accessTokens, userName);
        } else {
            doNotDeleteLongDurableToken(authentication, accessTokens, userName);
        }
    }

    /**
     * Tells whether {@code clientId} equals (ignoring case) one of the entries of the comma-separated
     * {@code longDurableClientIds}; entries are trimmed before comparison.
     */
    private boolean isLongDurableClientId(String clientId) {
        if (clientId == null || longDurableClientIds == null) {
            return false;
        }
        StringTokenizer configuredIds = new StringTokenizer(longDurableClientIds, ",");
        while (configuredIds.hasMoreTokens()) {
            if (configuredIds.nextToken().trim().equalsIgnoreCase(clientId)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes every listed token except those whose authentication can be loaded and belongs to a long-durable
     * client. Tokens whose authentication cannot be loaded are removed as well.
     */
    private void doNotDeleteLongDurableToken(Authentication authentication, List<String> accessTokens, String userName) {
        List<String> removedTokens = new ArrayList<String>();
        for (String accessToken : accessTokens) {
            OAuth2Authentication tokenAuthentication = loadAuthenticationForGivenToken(accessToken);
            boolean keep = tokenAuthentication != null
                    && isLongDurableClientId(tokenAuthentication.getAuthorizationRequest().getClientId());
            if (!keep) {
                super.removeAccessToken(accessToken);
                removedTokens.add(accessToken);
            }
        }

        if (removedTokens.size() == accessTokens.size()) {
            // Nothing was preserved: a single statement clears the user's mappings.
            deleteUserTokens(userName);
        } else {
            // Some tokens stay valid, so only the mappings of the removed ones may go.
            for (String removedToken : removedTokens) {
                deleteUserToken(userName, removedToken);
            }
        }
    }

    /**
     * Removes the listed tokens whose authentication can be loaded and whose client id equals (ignoring case) the
     * client id of {@code authentication}; all other tokens are left untouched.
     */
    private void deleteOnlyLongDurableToken(Authentication authentication, List<String> accessTokens, String userName) {
        String currentClientId = ((OAuth2Authentication) authentication).getAuthorizationRequest().getClientId();
        for (String accessToken : accessTokens) {
            OAuth2Authentication tokenAuthentication = loadAuthenticationForGivenToken(accessToken);
            if (tokenAuthentication != null && currentClientId != null
                    && currentClientId.equalsIgnoreCase(tokenAuthentication.getAuthorizationRequest().getClientId())) {
                super.removeAccessToken(accessToken);
                deleteUserToken(userName, accessToken);
            }
        }
    }

    /**
     * Loads the authentication stored for a token value through the injected token services.
     *
     * @return the authentication, or {@code null} when loading fails with an {@link AuthenticationException} or
     *         {@link InvalidTokenException}
     */
    private OAuth2Authentication loadAuthenticationForGivenToken(String accessToken) {
        OAuth2Authentication authentication = null;

        try {
            authentication = tokenServices.loadAuthentication(accessToken);
        } catch (AuthenticationException authenticationException) {
            logger.info("Either token was expired or not available", authenticationException);
        } catch (InvalidTokenException invalidTokenException) {
            logger.info("Either token was expired or not available", invalidTokenException);
        }
        return authentication;
    }

    /**
     * Returns the string form of the user identifier of the {@link AppUser} principal of the authentication.
     */
    private String getCurrentUserId(Authentication authentication) {
        AppUser appUser = (AppUser) authentication.getPrincipal();
        MhpUser mhpUser = appUser.getMhpUser();
        return mhpUser.getUserIdentifier().toString();
    }

    /** Returns the access token values mapped to the user in {@code access_token_user}. */
    private List<String> fetchUserAccessTokens(String userName) {
        return this.jdbcTemplate.query(USER_ACCESS_TOKEN_SELECT_STATEMENT, new Object[] { userName },
                new RowMapper<String>() {
                    public String mapRow(ResultSet rs, int rowNum) throws SQLException {
                        return rs.getString("access_token");
                    }
                });
    }

    /** Deletes every {@code access_token_user} row of the user. */
    private void deleteUserTokens(String userName) {
        jdbcTemplate.update(USER_ACCESS_TOKEN_DELETE_STATEMENT, userName);
    }

    /** Deletes the {@code access_token_user} row(s) mapping exactly this token to this user. */
    private void deleteUserToken(String userName, String accessToken) {
        jdbcTemplate.update(USER_ACCESS_TOKEN_DELETE_SINGLE_STATEMENT,
                new Object[] { userName, accessToken }, new int[] { Types.VARCHAR, Types.VARCHAR });
    }

}
