/*
 * Decompiled with CFR 0.152.
 */
package oidc.secure;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.crypto.tink.Aead;
import com.google.crypto.tink.CleartextKeysetHandle;
import com.google.crypto.tink.JsonKeysetReader;
import com.google.crypto.tink.JsonKeysetWriter;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.KeysetReader;
import com.google.crypto.tink.aead.AeadConfig;
import com.google.crypto.tink.aead.AeadFactory;
import com.google.crypto.tink.aead.AeadKeyTemplates;
import com.google.crypto.tink.proto.KeyTemplate;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.openid.connect.sdk.Nonce;
import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
import com.nimbusds.openid.connect.sdk.claims.CodeHash;
import com.nimbusds.openid.connect.sdk.claims.StateHash;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.Provider;
import java.security.SecureRandom;
import java.security.Security;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.stream.Collectors;
import oidc.endpoints.MapTypeReference;
import oidc.exceptions.InvalidSignatureException;
import oidc.model.OpenIDClient;
import oidc.model.SigningKey;
import oidc.model.SymmetricKey;
import oidc.model.User;
import oidc.repository.SequenceRepository;
import oidc.repository.SigningKeyRepository;
import oidc.repository.SymmetricKeyRepository;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.event.ApplicationStartedEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.core.env.Environment;
import org.springframework.core.env.Profiles;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

@Component
public class TokenGenerator
implements MapTypeReference,
ApplicationListener<ApplicationStartedEvent> {
    /** JWS algorithm used for every token this component signs. */
    public static final JWSAlgorithm signingAlg = JWSAlgorithm.RS256;
    /** Fixed instant the clock is pinned to when the "dev" profile is active. */
    public static final Instant instant = Instant.parse("2100-01-01T00:00:00.00Z");

    /** Alphabet used for authorization codes. */
    private static final char[] DEFAULT_CODEC = "1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz".toCharArray();
    /** Number of characters in a generated authorization code. */
    private static final int AUTHORIZATION_CODE_LENGTH = 12;

    private final SecureRandom random = new SecureRandom();
    private final String issuer;
    private final byte[] associatedData;
    private final KeysetHandle primaryKeysetHandle;
    private final ObjectMapper objectMapper;
    private final Clock clock;
    private final SigningKeyRepository signingKeyRepository;
    private final SymmetricKeyRepository symmetricKeyRepository;
    private final SequenceRepository sequenceRepository;
    private final List<String> acrValuesSupported;
    private final String defaultAcrValue;

    // Key material; (re)built by initializeSigningKeys / initializeSymmetricKeys.
    private Map<String, JWSSigner> signers;
    private Map<String, JWSVerifier> verifiers;
    private String currentSigningKeyId;
    private List<RSAKey> publicKeys;
    private Map<String, KeysetHandle> keysetHandleMap;
    private String currentSymmetricKeyId;

    /**
     * Creates the generator, loading the primary (cleartext) Tink keyset and the well-known OpenID configuration.
     * Signing and symmetric keys are loaded later, in {@link #onApplicationEvent(ApplicationStartedEvent)}.
     *
     * @throws IOException              if either resource cannot be read
     * @throws GeneralSecurityException if Tink cannot be registered or the keyset cannot be parsed
     */
    @Autowired
    public TokenGenerator(@Value(value="${spring.security.saml2.service-provider.entity-id}") String issuer,
                          @Value(value="${secret_key_set_path}") Resource secretKeySetPath,
                          @Value(value="${associated_data}") String associatedData,
                          @Value(value="${openid_configuration_path}") Resource configurationPath,
                          @Value(value="${default_acr_value}") String defaultAcrValue,
                          ObjectMapper objectMapper,
                          SigningKeyRepository signingKeyRepository,
                          SequenceRepository sequenceRepository,
                          SymmetricKeyRepository symmetricKeyRepository,
                          Environment environment) throws IOException, GeneralSecurityException {
        Security.addProvider(new BouncyCastleProvider());
        AeadConfig.register();
        this.signingKeyRepository = signingKeyRepository;
        this.sequenceRepository = sequenceRepository;
        this.symmetricKeyRepository = symmetricKeyRepository;
        this.issuer = issuer;
        this.objectMapper = objectMapper;
        this.clock = environment.acceptsProfiles(Profiles.of("dev"))
                ? Clock.fixed(instant, ZoneId.systemDefault())
                : Clock.systemDefaultZone();

        KeysetHandle primaryHandle;
        try (InputStream keySetStream = secretKeySetPath.getInputStream()) {
            primaryHandle = CleartextKeysetHandle.read(JsonKeysetReader.withInputStream(keySetStream));
        }
        this.primaryKeysetHandle = primaryHandle;
        this.associatedData = associatedData.getBytes(Charset.defaultCharset());

        Map<?, ?> wellKnownConfiguration;
        try (InputStream configurationStream = configurationPath.getInputStream()) {
            wellKnownConfiguration = (Map<?, ?>) objectMapper.readValue(configurationStream, mapTypeReference);
        }
        @SuppressWarnings("unchecked")
        List<String> supportedAcrValues = (List<String>) wellKnownConfiguration.get("acr_values_supported");
        // A configuration without acr_values_supported means no user acr claim is accepted; every token gets the default.
        this.acrValuesSupported = supportedAcrValues != null ? supportedAcrValues : Collections.<String>emptyList();
        this.defaultAcrValue = defaultAcrValue;
    }

    /**
     * Loads key material once the application has started. Symmetric keys are loaded first because
     * signing keys are stored encrypted with them.
     */
    @Override
    public void onApplicationEvent(ApplicationStartedEvent event) {
        this.initializeSymmetricKeys();
        this.initializeSigningKeys();
    }

    /**
     * (Re)loads all signing keys (newest first), generating and persisting one if none exist, and rebuilds
     * the public key list, the signer/verifier maps and the current signing key id.
     */
    private void initializeSigningKeys() {
        List<RSAKey> rsaKeys = this.signingKeyRepository.findAllByOrderByCreatedDesc().stream()
                .filter(signingKey -> StringUtils.hasText(signingKey.getSymmetricKeyId()))
                .map(this::parseEncryptedRsaKey)
                .collect(Collectors.toList());
        if (rsaKeys.isEmpty()) {
            SigningKey freshKey = this.generateEncryptedRsaKey();
            this.signingKeyRepository.save(freshKey);
            rsaKeys = Collections.singletonList(this.parseEncryptedRsaKey(freshKey));
        }
        this.publicKeys = rsaKeys.stream().map(RSAKey::toPublicJWK).collect(Collectors.toList());
        // The repository returns keys ordered by creation date descending, so index 0 is the newest.
        this.currentSigningKeyId = rsaKeys.get(0).getKeyID();
        this.signers = rsaKeys.stream().collect(Collectors.toMap(JWK::getKeyID, this::createRSASigner));
        this.verifiers = rsaKeys.stream().collect(Collectors.toMap(JWK::getKeyID, this::createRSAVerifier));
    }

    /**
     * Generates and persists a new signing key, then reloads the signing key material.
     *
     * @return the newly persisted signing key
     */
    public SigningKey rolloverSigningKeys() throws NoSuchProviderException, NoSuchAlgorithmException {
        SigningKey signingKey = this.generateEncryptedRsaKey();
        this.signingKeyRepository.save(signingKey);
        this.initializeSigningKeys();
        return signingKey;
    }

    /**
     * (Re)loads all symmetric keysets (newest first), generating one if none exist, and rebuilds the
     * key-id → keyset map and the current symmetric key id.
     */
    private void initializeSymmetricKeys() {
        List<KeysetHandle> keysetHandles = this.symmetricKeyRepository.findAllByOrderByCreatedDesc().stream()
                .map(this::parseKeysetHandle)
                .collect(Collectors.toList());
        if (keysetHandles.isEmpty()) {
            // Signing keys are stored encrypted under a symmetric key id; with no symmetric keys left they
            // could not be decrypted, so they are discarded and regenerated later by initializeSigningKeys.
            this.signingKeyRepository.deleteAll();
            keysetHandles = Collections.singletonList(this.parseKeysetHandle(this.generateSymmetricKey()));
        }
        this.currentSymmetricKeyId = String.valueOf(keysetHandles.get(0).getKeysetInfo().getPrimaryKeyId());
        this.keysetHandleMap = keysetHandles.stream().collect(Collectors.toMap(
                handle -> String.valueOf(handle.getKeysetInfo().getPrimaryKeyId()),
                handle -> handle));
    }

    /**
     * Generates a new AES256-CTR-HMAC-SHA256 keyset and stores it encrypted with the primary keyset.
     * It also publishes its id through the sequence repository.
     *
     * @return the persisted symmetric key
     */
    private SymmetricKey generateSymmetricKey() {
        try {
            KeysetHandle keysetHandle = KeysetHandle.generateNew(AeadKeyTemplates.AES256_CTR_HMAC_SHA256);
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            keysetHandle.write(JsonKeysetWriter.withOutputStream(outputStream),
                    AeadFactory.getPrimitive(this.primaryKeysetHandle));
            String keyId = String.valueOf(keysetHandle.getKeysetInfo().getPrimaryKeyId());
            this.sequenceRepository.updateSymmetricKeyId(keyId);
            String aead = Base64.getEncoder().encodeToString(outputStream.toByteArray());
            SymmetricKey symmetricKey = new SymmetricKey(keyId, aead, new Date());
            this.symmetricKeyRepository.save(symmetricKey);
            return symmetricKey;
        }
        catch (IOException | GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates and persists a new symmetric key, then reloads the symmetric key material.
     *
     * @return the newly persisted symmetric key
     */
    public SymmetricKey rolloverSymmetricKeys() {
        SymmetricKey symmetricKey = this.generateSymmetricKey();
        this.initializeSymmetricKeys();
        return symmetricKey;
    }

    /** Decrypts a stored signing key and parses it as an RSA JWK. */
    private RSAKey parseEncryptedRsaKey(SigningKey signingKey) {
        try {
            return RSAKey.parse(this.decryptAead(signingKey.getJwk(), signingKey.getSymmetricKeyId()));
        }
        catch (ParseException e) {
            throw new RuntimeException(e);
        }
    }

    /** Decodes a stored symmetric key and decrypts it into a keyset using the primary keyset. */
    private KeysetHandle parseKeysetHandle(SymmetricKey symmetricKey) {
        byte[] decoded = Base64.getDecoder().decode(symmetricKey.getAead());
        try {
            return KeysetHandle.read(JsonKeysetReader.withBytes(decoded),
                    AeadFactory.getPrimitive(this.primaryKeysetHandle));
        }
        catch (IOException | GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /** @return a new opaque access token (random UUID) */
    public String generateAccessToken() {
        return UUID.randomUUID().toString();
    }

    /** @return a new opaque refresh token (random UUID) */
    public String generateRefreshToken() {
        return UUID.randomUUID().toString();
    }

    /**
     * Generates an authorization code of {@value #AUTHORIZATION_CODE_LENGTH} characters drawn uniformly
     * from {@code [0-9A-Za-z]} using a {@link SecureRandom}.
     */
    public String generateAuthorizationCode() {
        char[] chars = new char[AUTHORIZATION_CODE_LENGTH];
        for (int i = 0; i < chars.length; ++i) {
            chars[i] = DEFAULT_CODEC[this.random.nextInt(DEFAULT_CODEC.length)];
        }
        return new String(chars);
    }

    /**
     * Creates a signed JWT access token that carries the user serialized as JSON and AEAD-encrypted in
     * the {@code claims} claim. The id of the encrypting key is carried in {@code claim_key_id}.
     */
    public String generateAccessTokenWithEmbeddedUserInfo(User user, OpenIDClient client) {
        try {
            this.ensureLatestSigningKey();
            return this.doGenerateAccessTokenWithEmbeddedUser(user, client, this.currentSigningKeyId);
        }
        catch (Exception e) {
            throw e instanceof RuntimeException ? (RuntimeException)e : new RuntimeException(e);
        }
    }

    private String doGenerateAccessTokenWithEmbeddedUser(User user, OpenIDClient client, String signingKey) throws JsonProcessingException, GeneralSecurityException, JOSEException {
        String json = this.objectMapper.writeValueAsString(user);
        String encryptedClaims = this.encryptAead(json);
        Map<String, Object> additionalClaims = new HashMap<>();
        additionalClaims.put("claims", encryptedClaims);
        // encryptAead has just ensured the latest symmetric key, so this is the key it encrypted with.
        additionalClaims.put("claim_key_id", this.currentSymmetricKeyId);
        return this.idToken(client, Optional.empty(), additionalClaims, Collections.emptyList(), true, signingKey);
    }

    /** Encrypts {@code s} with the current symmetric key and returns the ciphertext Base64-encoded. */
    private String encryptAead(String s) {
        try {
            this.ensureLatestSymmetricKey();
            KeysetHandle keysetHandle = this.safeGet(this.currentSymmetricKeyId, this.keysetHandleMap, this::initializeSymmetricKeys);
            Aead aead = AeadFactory.getPrimitive(keysetHandle);
            byte[] src = aead.encrypt(s.getBytes(Charset.defaultCharset()), this.associatedData);
            return Base64.getEncoder().encodeToString(src);
        }
        catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Verifies the signature of an access token produced by
     * {@link #generateAccessTokenWithEmbeddedUserInfo(User, OpenIDClient)} and decrypts the embedded user.
     *
     * @throws InvalidSignatureException if the signature does not verify
     */
    public User decryptAccessTokenWithEmbeddedUserInfo(String accessToken) {
        try {
            return this.doDecryptAccessTokenWithEmbeddedUserInfo(accessToken);
        }
        catch (Exception e) {
            throw e instanceof RuntimeException ? (RuntimeException)e : new RuntimeException(e);
        }
    }

    private User doDecryptAccessTokenWithEmbeddedUserInfo(String accessToken) throws ParseException, JOSEException, IOException {
        SignedJWT signedJWT = SignedJWT.parse(accessToken);
        Map<String, Object> claims = this.verifyClaims(signedJWT);
        String encryptedClaims = (String) claims.get("claims");
        String keyId = (String) claims.get("claim_key_id");
        String s = this.decryptAead(encryptedClaims, keyId);
        return this.objectMapper.readValue(s, User.class);
    }

    /** Decrypts Base64-encoded ciphertext {@code s} with the symmetric keyset identified by {@code symmetricKeyId}. */
    private String decryptAead(String s, String symmetricKeyId) {
        try {
            KeysetHandle keysetHandle = this.safeGet(symmetricKeyId, this.keysetHandleMap, this::ensureLatestSymmetricKey);
            Aead aead = AeadFactory.getPrimitive(keysetHandle);
            byte[] decoded = Base64.getDecoder().decode(s);
            // Same charset as encryptAead uses for the plaintext.
            return new String(aead.decrypt(decoded, this.associatedData), Charset.defaultCharset());
        }
        catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an ID token for the token endpoint.
     *
     * @param authorizationTime optional {@code auth_time} claim value
     * @param nonce             added as the {@code nonce} claim when it has text
     */
    public String generateIDTokenForTokenEndpoint(Optional<User> user, OpenIDClient client, String nonce, List<String> idTokenClaims, Optional<Long> authorizationTime) throws JOSEException, NoSuchProviderException, NoSuchAlgorithmException {
        Map<String, Object> additionalClaims = new HashMap<>();
        authorizationTime.ifPresent(time -> additionalClaims.put("auth_time", time));
        if (StringUtils.hasText(nonce)) {
            additionalClaims.put("nonce", nonce);
        }
        this.ensureLatestSigningKey();
        return this.idToken(client, user, additionalClaims, idTokenClaims, false, this.currentSigningKeyId);
    }

    /**
     * Creates an ID token for the authorization endpoint. {@code at_hash} and {@code c_hash} are added
     * when the response type requires them. {@code s_hash} is added when a non-empty state is given.
     */
    public String generateIDTokenForAuthorizationEndpoint(User user, OpenIDClient client, Nonce nonce, ResponseType responseType, String accessToken, List<String> claims, Optional<String> authorizationCode, State state) throws JOSEException, NoSuchProviderException, NoSuchAlgorithmException {
        Map<String, Object> additionalClaims = new HashMap<>();
        additionalClaims.put("auth_time", System.currentTimeMillis() / 1000L);
        if (nonce != null) {
            additionalClaims.put("nonce", nonce.getValue());
        }
        // All hash claims are stored as plain strings, so their JSON form does not depend on how the JSON
        // library serializes the SDK's Identifier subclasses.
        if (AccessTokenHash.isRequiredInIDTokenClaims(responseType)) {
            additionalClaims.put("at_hash", AccessTokenHash.compute(new BearerAccessToken(accessToken), signingAlg).getValue());
        }
        if (CodeHash.isRequiredInIDTokenClaims(responseType) && authorizationCode.isPresent()) {
            additionalClaims.put("c_hash", CodeHash.compute(new AuthorizationCode(authorizationCode.get()), signingAlg).getValue());
        }
        if (state != null && StringUtils.hasText(state.getValue())) {
            additionalClaims.put("s_hash", StateHash.compute(state, signingAlg).getValue());
        }
        this.ensureLatestSigningKey();
        return this.idToken(client, Optional.of(user), additionalClaims, claims, false, this.currentSigningKeyId);
    }

    /** @return a copy of the public halves of all loaded signing keys */
    public List<JWK> getAllPublicKeys() {
        this.ensureLatestSigningKey();
        return new ArrayList<>(this.publicKeys);
    }

    /** Generates a 2048-bit RSA key pair with the BouncyCastle provider and wraps it as a JWK with the given kid. */
    private RSAKey generateRsaKey(String keyID) {
        KeyPairGenerator kpg;
        try {
            kpg = KeyPairGenerator.getInstance("RSA", "BC");
        }
        catch (NoSuchAlgorithmException | NoSuchProviderException e) {
            throw new RuntimeException(e);
        }
        kpg.initialize(2048);
        KeyPair keyPair = kpg.generateKeyPair();
        RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
        RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
        return new RSAKey.Builder(publicKey).privateKey(privateKey).algorithm(signingAlg).keyID(keyID).build();
    }

    /**
     * Generates a new RSA signing key encrypted with the current symmetric key, and publishes its kid
     * through the sequence repository. The returned key is not persisted; callers save it.
     */
    private SigningKey generateEncryptedRsaKey() {
        String keyId = String.format("key_%s", new SimpleDateFormat("yyyy_MM_dd_HH_mm_ss_SSS").format(new Date()));
        // ensureLatestSigningKey compares the sequence value with currentSigningKeyId, which holds the JWK kid,
        // so the sequence must store exactly the kid (including the "key_" prefix) or the two never match.
        this.sequenceRepository.updateSigningKeyId(keyId);
        RSAKey rsaKey = this.generateRsaKey(keyId);
        String encryptedKey = this.encryptAead(rsaKey.toJSONString());
        return new SigningKey(rsaKey.getKeyID(), this.currentSymmetricKeyId, encryptedKey, new Date());
    }

    /**
     * Verifies the JWT signature with the verifier matching its {@code kid} header and returns its claims.
     *
     * @throws InvalidSignatureException if the signature does not verify
     */
    private Map<String, Object> verifyClaims(SignedJWT signedJWT) throws ParseException, JOSEException {
        String keyID = signedJWT.getHeader().getKeyID();
        // NOTE(review): the kid comes from the unverified header, so an unknown kid triggers a full reload
        // of the signing keys before failing — consider whether that cost is acceptable per request.
        JWSVerifier verifier = this.safeGet(keyID, this.verifiers, this::initializeSigningKeys);
        if (!signedJWT.verify(verifier)) {
            throw new InvalidSignatureException("Tampered JWT");
        }
        return signedJWT.getJWTClaimsSet().getClaims();
    }

    /**
     * Builds and signs a JWT for {@code client}.
     *
     * @param optionalUser                 subject user; when absent the client id is the subject and no acr is added
     * @param additionalClaims             extra claims added verbatim (they override earlier claims of the same name)
     * @param idTokenClaims                user attribute names to copy into the token when present
     * @param includeAllowedResourceServers whether the client's allowed resource servers join the audience
     * @param signingKey                   kid of the signing key to use
     */
    private String idToken(OpenIDClient client, Optional<User> optionalUser, Map<String, Object> additionalClaims, List<String> idTokenClaims, boolean includeAllowedResourceServers, String signingKey) throws JOSEException {
        List<String> audiences = new ArrayList<>();
        audiences.add(client.getClientId());
        if (includeAllowedResourceServers) {
            audiences.addAll(client.getAllowedResourceServers().stream()
                    .filter(rsEntityId -> !client.getClientId().equals(rsEntityId))
                    .collect(Collectors.toList()));
        }
        // Sample the clock once so iat and exp are derived from the same instant.
        Instant now = this.clock.instant();
        JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder()
                .audience(audiences)
                .expirationTime(Date.from(now.plus((long) client.getAccessTokenValidity(), ChronoUnit.SECONDS)))
                .jwtID(UUID.randomUUID().toString())
                .issuer(this.issuer)
                .issueTime(Date.from(now))
                .subject(optionalUser.map(User::getSub).orElse(client.getClientId()))
                // nbf is taken from the system clock rather than this.clock: under the "dev" profile this.clock is
                // fixed at 2100, and using it here would make dev tokens not yet valid.
                .notBeforeTime(new Date(System.currentTimeMillis()));
        if (!CollectionUtils.isEmpty(idTokenClaims) && optionalUser.isPresent()) {
            Map<String, ?> attributes = optionalUser.get().getAttributes();
            idTokenClaims.stream()
                    .filter(attributes::containsKey)
                    .forEach(claim -> builder.claim(claim, attributes.get(claim)));
        }
        optionalUser.ifPresent(user -> builder.claim("acr", this.acrValue(user)));
        additionalClaims.forEach(builder::claim);
        JWTClaimsSet claimsSet = builder.build();
        JWSHeader header = new JWSHeader.Builder(signingAlg).type(JOSEObjectType.JWT).keyID(signingKey).build();
        SignedJWT signedJWT = new SignedJWT(header, claimsSet);
        JWSSigner jwsSigner = this.safeGet(signingKey, this.signers, this::initializeSigningKeys);
        signedJWT.sign(jwsSigner);
        return signedJWT.serialize();
    }

    /**
     * Returns the user's acr claims that are listed in {@code acr_values_supported}, space-separated.
     * Falls back to the configured default when none of them is supported.
     */
    private String acrValue(User user) {
        List<String> validAcrValues = user.getAcrClaims().stream()
                .filter(this.acrValuesSupported::contains)
                .collect(Collectors.toList());
        return CollectionUtils.isEmpty(validAcrValues) ? this.defaultAcrValue : String.join(" ", validAcrValues);
    }

    /** Reloads signing keys when the sequence repository reports a different current key id than the local one. */
    private void ensureLatestSigningKey() {
        String latestSigningKeyId = this.sequenceRepository.currentSigningKeyId();
        if (latestSigningKeyId == null || !latestSigningKeyId.equals(this.currentSigningKeyId)) {
            this.initializeSigningKeys();
        }
    }

    /** Reloads symmetric keys when the sequence repository reports a different current key id than the local one. */
    private void ensureLatestSymmetricKey() {
        String latestSymmetricKeyId = this.sequenceRepository.currentSymmetricKeyId();
        if (latestSymmetricKeyId == null || !latestSymmetricKeyId.equals(this.currentSymmetricKeyId)) {
            this.initializeSymmetricKeys();
        }
    }

    private RSASSASigner createRSASigner(RSAKey k) {
        try {
            return new RSASSASigner(k);
        }
        catch (JOSEException e) {
            throw new RuntimeException(e);
        }
    }

    private RSASSAVerifier createRSAVerifier(RSAKey k) {
        try {
            return new RSASSAVerifier(k);
        }
        catch (JOSEException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Looks up {@code k} in {@code map}. On a miss it runs {@code runnable} once (expected to reload the map
     * field) and looks again.
     *
     * @throws IllegalArgumentException if the key is still missing after the reload
     */
    private <T> T safeGet(String k, Map<String, T> map, Runnable runnable) {
        T t = map.get(k);
        if (t == null) {
            runnable.run();
            // NOTE(review): the reload replaces the instance field, but this lookup still uses the map
            // passed in (the pre-reload instance), so a key added by the reload is not seen here.
            t = map.get(k);
            if (t == null) {
                throw new IllegalArgumentException(String.format("Map with keys %s does not contain key %s", map.keySet(), k));
            }
        }
        return t;
    }

    public String getCurrentSigningKeyId() {
        return this.currentSigningKeyId;
    }

    public String getCurrentSymmetricKeyId() {
        return this.currentSymmetricKeyId;
    }
}

