import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.util.Arrays;
import java.util.Base64;
/**
 * Password-based AES/GCM encryption helpers producing Base64 text.
 *
 * <p>Encrypted payloads are laid out as {@code "Salted__" || salt (8 bytes) || GCM ciphertext+tag}
 * and then Base64-encoded, either with the standard alphabet or URL-safe without padding. The AES
 * key and the GCM IV are both derived from the password and the random salt. The derivation is an
 * {@code EVP_BytesToKey}-style KDF using MD5 with a single iteration.
 *
 * <p>NOTE(review): MD5 with one iteration is a very weak password KDF. Because the IV comes from
 * the same (password, salt) pair as the key, IV uniqueness rests entirely on the 64-bit salt never
 * repeating for a given password. The scheme is kept as-is because changing it would make existing
 * ciphertexts undecryptable. The "Salted__" layout presumably targets OpenSSL/CryptoJS interop;
 * confirm before migrating to PBKDF2/Argon2 with a random 12-byte IV.
 */
public final class AESUtil {
  /** AES key length in bits. */
  private static final int KEY_SIZE = 256;
  /** Length in bits of the IV produced by the KDF. */
  private static final int IV_SIZE = 128;
  /** GCM authentication tag length in bits. */
  private static final int GCM_TAG_BITS = 128;
  /** Length of the random KDF salt in bytes. */
  private static final int SALT_LENGTH = 8;
  /** ASCII magic prefix written before the salt. */
  private static final String SALT_HEADER = "Salted__";
  /** JCE transformation used for both encryption and decryption. */
  private static final String TRANSFORMATION = "AES/GCM/NoPadding";
  private static final String AES = "AES";
  /** Shared CSPRNG for salts; {@link SecureRandom} is safe for concurrent use. */
  private static final SecureRandom SECURE_RANDOM = new SecureRandom();

  private AESUtil() {
  }

  /**
   * Encrypts {@code plainText} (as UTF-8) under a key derived from {@code password}.
   *
   * @param password   password the key and IV are derived from (encoded as UTF-8)
   * @param plainText  text to encrypt
   * @param urlEncoder {@code true} for URL-safe Base64 without padding, {@code false} for standard Base64
   * @return Base64 text of {@code "Salted__" || salt || ciphertext+tag}
   */
  public static String encrypt(String password, String plainText, boolean urlEncoder) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException {
    byte[] salt = new byte[SALT_LENGTH];
    SECURE_RANDOM.nextBytes(salt);

    // Both key and IV are KDF outputs; the IV is deterministic for a given (password, salt), so it
    // is not pre-filled with random bytes (any such bytes would be overwritten anyway).
    byte[] key = new byte[KEY_SIZE / 8];
    byte[] iv = new byte[IV_SIZE / 8];
    evpKDF(password.getBytes(StandardCharsets.UTF_8), KEY_SIZE, IV_SIZE, salt, key, iv);

    Cipher cipher = Cipher.getInstance(TRANSFORMATION);
    cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, AES), new GCMParameterSpec(GCM_TAG_BITS, iv));
    byte[] cipherText = cipher.doFinal(plainText.getBytes(StandardCharsets.UTF_8));

    byte[] header = SALT_HEADER.getBytes(StandardCharsets.US_ASCII);
    byte[] buffer = new byte[header.length + salt.length + cipherText.length];
    System.arraycopy(header, 0, buffer, 0, header.length);
    System.arraycopy(salt, 0, buffer, header.length, salt.length);
    System.arraycopy(cipherText, 0, buffer, header.length + salt.length, cipherText.length);

    return urlEncoder
        ? Base64.getUrlEncoder().withoutPadding().encodeToString(buffer)
        : Base64.getEncoder().encodeToString(buffer);
  }

  /**
   * Decrypts text produced by {@link #encrypt(String, String, boolean)}.
   *
   * @param password   password used at encryption time
   * @param cipherText Base64 payload
   * @param urlEncoder {@code true} if the payload uses the URL-safe Base64 alphabet
   * @return the decrypted text, decoded as UTF-8 (the charset {@code encrypt} uses)
   * @throws IllegalArgumentException if the payload is not valid Base64 or is too short to contain
   *                                  the header and salt
   * @throws BadPaddingException      (in practice {@code AEADBadTagException}) if the password is
   *                                  wrong or the data was tampered with
   */
  public static String decrypt(String password, String cipherText, boolean urlEncoder) throws NoSuchAlgorithmException, NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException, InvalidAlgorithmParameterException, InvalidKeyException {
    byte[] ctBytes = urlEncoder ? Base64.getUrlDecoder().decode(cipherText) : Base64.getDecoder().decode(cipherText);

    int saltOffset = SALT_HEADER.length();
    int payloadOffset = saltOffset + SALT_LENGTH;
    if (ctBytes.length < payloadOffset) {
      throw new IllegalArgumentException("Ciphertext too short: " + ctBytes.length
          + " bytes, expected at least " + payloadOffset);
    }
    // The header bytes are skipped without being verified, as before.
    byte[] salt = Arrays.copyOfRange(ctBytes, saltOffset, payloadOffset);
    byte[] ciphertextBytes = Arrays.copyOfRange(ctBytes, payloadOffset, ctBytes.length);

    byte[] key = new byte[KEY_SIZE / 8];
    byte[] iv = new byte[IV_SIZE / 8];
    evpKDF(password.getBytes(StandardCharsets.UTF_8), KEY_SIZE, IV_SIZE, salt, key, iv);

    Cipher cipher = Cipher.getInstance(TRANSFORMATION);
    cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, AES), new GCMParameterSpec(GCM_TAG_BITS, iv));
    byte[] plainText = cipher.doFinal(ciphertextBytes);
    // Decode explicitly as UTF-8; the platform default charset would break non-ASCII round trips.
    return new String(plainText, StandardCharsets.UTF_8);
  }

  /** Derives key and IV with MD5 and a single iteration; see the full overload. */
  private static void evpKDF(byte[] password, int keySize, int ivSize, byte[] salt, byte[] resultKey, byte[] resultIv) throws NoSuchAlgorithmException {
    evpKDF(password, keySize, ivSize, salt, "MD5", 1, resultKey, resultIv);
  }

  /**
   * {@code EVP_BytesToKey}-style derivation: D_i = H^iterations(D_{i-1} || password || salt), with
   * D_0 empty. The D_i are concatenated until {@code keySize + ivSize} bits are available.
   *
   * @param keySize   key length in bits (assumed to be a multiple of 32)
   * @param ivSize    IV length in bits (assumed to be a multiple of 32)
   * @param resultKey receives the first {@code keySize / 8} derived bytes
   * @param resultIv  receives the following {@code ivSize / 8} derived bytes
   */
  private static void evpKDF(byte[] password, int keySize, int ivSize, byte[] salt, String hash, int iterations, byte[] resultKey, byte[] resultIv) throws NoSuchAlgorithmException {
    int keyWords = keySize / 32;
    int ivWords = ivSize / 32;
    int targetWords = keyWords + ivWords;
    byte[] derivedBytes = new byte[targetWords * 4];
    int derivedWords = 0;
    byte[] block = null;
    MessageDigest hasher = MessageDigest.getInstance(hash);
    while (derivedWords < targetWords) {
      if (block != null) {
        hasher.update(block);
      }
      hasher.update(password);
      // digest(...) completes the hash and resets the MessageDigest, so no explicit reset() is needed.
      block = hasher.digest(salt);
      for (int i = 1; i < iterations; i++) {
        block = hasher.digest(block);
      }
      System.arraycopy(block, 0, derivedBytes, derivedWords * 4,
          Math.min(block.length, (targetWords - derivedWords) * 4));
      derivedWords += block.length / 4;
    }
    System.arraycopy(derivedBytes, 0, resultKey, 0, keyWords * 4);
    System.arraycopy(derivedBytes, keyWords * 4, resultIv, 0, ivWords * 4);
  }
}