package com.agilex.security.oauth;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;

import javax.sql.DataSource;

import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.JdbcTokenStore;

/**
 * {@link JdbcTokenStore} that additionally records, in table {@code access_token_user}, which IAM
 * user each stored access token was issued to. This allows every access token of a given user to
 * be removed at once through {@link #removeAllUSerTokens(Authentication)}.
 */
public class IamJdbcTokenStore extends JdbcTokenStore implements TokenStoreLogout {

    private static final String USER_ACCESS_TOKEN_INSERT_STATEMENT = "insert into access_token_user (user_name, access_token) values (?, ?)";
    private static final String USER_ACCESS_TOKEN_SELECT_STATEMENT = "select access_token from access_token_user where user_name = ?";
    private static final String USER_ACCESS_TOKEN_DELETE_STATEMENT = "delete from access_token_user where user_name = ?";
    /** Removes the user mapping of one access token; keeps {@link #addToken} idempotent. */
    private static final String USER_ACCESS_TOKEN_DELETE_BY_TOKEN_STATEMENT = "delete from access_token_user where access_token = ?";
    private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_access_token where authentication_id = ?";

    private static final org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory.getLog(IamJdbcTokenStore.class);

    private final JdbcTemplate jdbcTemplate;

    private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

    /**
     * @param dataSource data source holding both the standard OAuth2 token tables and the
     *                   {@code access_token_user} mapping table
     */
    public IamJdbcTokenStore(DataSource dataSource) {
        super(dataSource);
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Sets the generator used to derive the {@code authentication_id} key.
     * <p>
     * The generator is forwarded to the superclass as well. The superclass computes the key when
     * it inserts a token, and {@link #getAccessToken(OAuth2Authentication)} computes it again when
     * looking the token up. Both sides must use the same generator or lookups silently miss.
     */
    @Override
    public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
        super.setAuthenticationKeyGenerator(authenticationKeyGenerator);
        this.authenticationKeyGenerator = authenticationKeyGenerator;
    }

    /**
     * Stores the token via the superclass. When the authentication carries a user, it also
     * records the token against that user's IAM id.
     * <p>
     * Client-only authentications (no user authentication) have no user to associate, so only
     * the standard storage is performed for them. A user principal that is not an
     * {@code IamAppUser} still fails with a {@link ClassCastException}, as before.
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        Authentication userAuthentication = authentication.getUserAuthentication();
        if (userAuthentication != null) {
            addToken(getCurrentUserId(userAuthentication), token.getValue());
        } else if (logger.isDebugEnabled()) {
            logger.debug("No user authentication present; access token not mapped to an IAM user");
        }
        super.storeAccessToken(token, authentication);
    }

    /**
     * Looks up the access token previously stored for {@code authentication}.
     * <p>
     * If one is found, it is removed and stored again under the current authentication. This
     * keeps the stored authentication details up to date.
     *
     * @return the stored token, or {@code null} if none exists or it could not be read
     */
    @Override
    public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
        OAuth2AccessToken accessToken = null;

        String key = authenticationKeyGenerator.extractKey(authentication);
        try {
            accessToken = jdbcTemplate.queryForObject(DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT,
                    new RowMapper<OAuth2AccessToken>() {
                        public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
                            // column 2 is the serialized token
                            return deserializeAccessToken(rs.getBytes(2));
                        }
                    }, key);
        }
        catch (EmptyResultDataAccessException e) {
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to find access token for authentication");
            }
        }
        catch (IllegalArgumentException e) {
            logger.error("Could not extract access token for authentication", e);
        }

        if (accessToken != null) {
            removeAccessToken(accessToken.getValue());
            // Keep the store consistent (maybe the same user is represented by this authentication but the details have
            // changed). addToken is idempotent, so re-storing does not duplicate the user mapping.
            storeAccessToken(accessToken, authentication);
        }
        return accessToken;
    }

    /**
     * Maps {@code accessToken} to {@code userId}, replacing any existing mapping for that token.
     * <p>
     * Deleting first prevents duplicate rows when the same token is stored repeatedly, for
     * example on every {@link #getAccessToken(OAuth2Authentication)} hit.
     */
    private void addToken(String userId, String accessToken) {
        jdbcTemplate.update(USER_ACCESS_TOKEN_DELETE_BY_TOKEN_STATEMENT, accessToken);
        jdbcTemplate.update(
                USER_ACCESS_TOKEN_INSERT_STATEMENT,
                new Object[]{userId, accessToken},
                new int[]{Types.VARCHAR, Types.VARCHAR});
    }

    /**
     * Removes every access token recorded for the IAM user of {@code authentication}.
     * <p>
     * Each token is removed from the standard token store, then the user's mapping rows are
     * deleted.
     *
     * @param authentication authentication whose principal is expected to be an {@code IamAppUser}
     */
    public void removeAllUSerTokens(Authentication authentication) {
        String userId = getCurrentUserId(authentication);
        logger.info("Removing access tokens");

        List<String> accessTokens = fetchUserAccessTokens(userId);

        if (accessTokens != null) {
            for (String accessToken : accessTokens) {
                super.removeAccessToken(accessToken);
            }
            deleteUserTokens(userId);
        }
    }

    /** Returns the IAM user id of the {@code IamAppUser} principal of {@code authentication}. */
    private String getCurrentUserId(Authentication authentication) {
        IamAppUser iamAppUser = (IamAppUser) authentication.getPrincipal();
        IamUser iamUser = iamAppUser.getIamUser();
        return iamUser.getId();
    }

    /** Returns all access token values mapped to {@code userId}. */
    private List<String> fetchUserAccessTokens(String userId) {
        return this.jdbcTemplate.query(USER_ACCESS_TOKEN_SELECT_STATEMENT,
                new RowMapper<String>() {
                    public String mapRow(ResultSet rs, int rowNum) throws SQLException {
                        return rs.getString("access_token");
                    }
                }, userId);
    }

    /** Deletes every token mapping row of {@code userId}. */
    private void deleteUserTokens(String userId) {
        jdbcTemplate.update(USER_ACCESS_TOKEN_DELETE_STATEMENT, userId);
    }
}
