/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.record.cipher;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.BulkCipherAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolMessageType;
import de.rub.nds.tlsattacker.core.crypto.cipher.CipherWrapper;
import de.rub.nds.tlsattacker.core.exceptions.CryptoException;
import de.rub.nds.tlsattacker.core.protocol.parser.Parser;
import de.rub.nds.tlsattacker.core.record.BlobRecord;
import de.rub.nds.tlsattacker.core.record.Record;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipher;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.math.BigInteger;
import java.util.Arrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Record cipher for AEAD cipher suites (presumably GCM, CCM, CCM_8 and ChaCha20-Poly1305).
 *
 * <p>Protected record layout: {@code explicitNonce || ciphertext || tag}. In TLS 1.3 the explicit nonce
 * is empty, and the per-record nonce is derived from the IV and the record sequence number via
 * {@link #preprocessIv(long, byte[])`}. The same derivation is applied to ChaCha20-Poly1305 in earlier versions.
 *
 * <p>Intermediate values (key, salt, nonce, AAD, ciphertext, tag, padding) are written into the record's
 * computations and read back before use. Presumably this lets configured modifications on those fields
 * take effect.
 */
public class RecordAEADCipher
extends RecordCipher {
    private static final Logger LOGGER = LogManager.getLogger();
    /** Tag length in bytes for all AEAD suites except CCM_8. */
    private static final int AEAD_TAG_LENGTH = 16;
    /** Tag length in bytes for CCM_8 suites. */
    private static final int AEAD_CCM_8_TAG_LENGTH = 8;
    /** Authentication tag length in bytes for the negotiated suite. */
    private final int aeadTagLength;
    /** Length of the explicit nonce carried in each record; always 0 for TLS 1.3. */
    private final int aeadExplicitLength;

    /**
     * Creates the AEAD record cipher and the underlying encryption/decryption ciphers.
     *
     * @param context the TLS context supplying version, cipher suite and local connection end
     * @param keySet the keys/IVs for both directions
     */
    public RecordAEADCipher(TlsContext context, KeySet keySet) {
        super(context, keySet);
        ConnectionEndType localConEndType = context.getConnection().getLocalConnectionEndType();
        this.encryptCipher = CipherWrapper.getEncryptionCipher(this.cipherSuite, localConEndType, this.getKeySet());
        this.decryptCipher = CipherWrapper.getDecryptionCipher(this.cipherSuite, localConEndType, this.getKeySet());
        this.aeadTagLength = this.cipherSuite.isCCM_8() ? AEAD_CCM_8_TAG_LENGTH : AEAD_TAG_LENGTH;
        // TLS 1.3 records never carry an explicit nonce.
        this.aeadExplicitLength = this.version.isTLS13() ? 0 : AlgorithmResolver.getCipher(this.cipherSuite).getNonceBytesFromRecord();
    }

    /**
     * Returns how many bytes AEAD protection adds to a record on top of the plaintext.
     *
     * @return tag length for TLS 1.3, explicit nonce length plus tag length otherwise
     */
    public int getAeadSizeIncrease() {
        if (this.version.isTLS13()) {
            return this.aeadTagLength;
        }
        return this.aeadExplicitLength + this.aeadTagLength;
    }

    /**
     * Builds the nonce used for encryption (salt || explicit nonce, optionally XORed with the sequence number).
     * It stores the nonce in the record computations and returns the (possibly modified) stored value.
     */
    private byte[] prepareEncryptionGcmNonce(byte[] aeadSalt, byte[] explicitNonce, Record record) {
        byte[] gcmNonce = ArrayConverter.concatenate(new byte[][]{aeadSalt, explicitNonce});
        if (this.version.isTLS13() || this.bulkCipherAlg == BulkCipherAlgorithm.CHACHA20_POLY1305) {
            // Nonce = IV XOR padded sequence number for TLS 1.3 and for ChaCha20-Poly1305.
            gcmNonce = this.preprocessIv(((BigInteger) record.getSequenceNumber().getValue()).longValue(), gcmNonce);
        }
        record.getComputations().setGcmNonce(gcmNonce);
        return (byte[]) record.getComputations().getGcmNonce().getValue();
    }

    /** Stores the local write IV as the AEAD salt and returns the (possibly modified) stored value. */
    private byte[] prepareEncryptionAeadSalt(Record record) {
        byte[] aeadSalt = this.getKeySet().getWriteIv(this.context.getConnection().getLocalConnectionEndType());
        record.getComputations().setAeadSalt(aeadSalt);
        return (byte[]) record.getComputations().getAeadSalt().getValue();
    }

    /** Stores a freshly created explicit nonce and returns the (possibly modified) stored value. */
    private byte[] prepareEncryptionExplicitNonce(Record record) {
        byte[] explicitNonce = this.createExplicitNonce();
        record.getComputations().setExplicitNonce(explicitNonce);
        return (byte[]) record.getComputations().getExplicitNonce().getValue();
    }

    /**
     * Creates the explicit nonce from the context's write sequence number.
     *
     * @return {@code aeadExplicitLength} bytes of the sequence number, or an empty array if no explicit nonce is used
     */
    private byte[] createExplicitNonce() {
        if (this.aeadExplicitLength > 0) {
            return ArrayConverter.longToBytes((long) this.context.getWriteSequenceNumber(), this.aeadExplicitLength);
        }
        return new byte[0];
    }

    /**
     * Encrypts the record's clean protocol message bytes and sets its protocol message bytes to
     * {@code explicitNonce || ciphertext || tag}. For TLS 1.3 the content type and configured padding are
     * appended to the plaintext first, the record length is set, and the outer content type becomes
     * application data.
     *
     * @param record the record to protect
     * @throws CryptoException if encryption fails or the cipher output is shorter than the tag
     */
    @Override
    public void encrypt(Record record) throws CryptoException {
        LOGGER.debug("Encrypting Record");
        record.getComputations().setCipherKey(this.getKeySet().getWriteKey(this.context.getChooser().getConnectionEndType()));
        if (this.version.isTLS13()) {
            int additionalPadding = this.context.getConfig().getDefaultAdditionalPadding();
            if (additionalPadding > 65536) {
                LOGGER.warn("Additional padding is too big. setting it to max possible value");
                additionalPadding = 65536;
            } else if (additionalPadding < 0) {
                LOGGER.warn("Additional padding is negative, setting it to 0");
                additionalPadding = 0;
            }
            record.getComputations().setPadding(new byte[additionalPadding]);
            // TLSInnerPlaintext = content || contentType || zero padding
            record.getComputations().setPlainRecordBytes(ArrayConverter.concatenate(new byte[][]{
                (byte[]) record.getCleanProtocolMessageBytes().getValue(),
                {(Byte) record.getContentType().getValue()},
                (byte[]) record.getComputations().getPadding().getValue()}));
            // Use the suite's actual tag length (8 for CCM_8), not a fixed 16.
            record.setLength(((byte[]) record.getComputations().getPlainRecordBytes().getValue()).length + this.aeadTagLength);
            record.setContentType(ProtocolMessageType.APPLICATION_DATA.getValue());
        } else {
            record.getComputations().setPlainRecordBytes((byte[]) record.getCleanProtocolMessageBytes().getValue());
        }
        byte[] explicitNonce = this.prepareEncryptionExplicitNonce(record);
        byte[] aeadSalt = this.prepareEncryptionAeadSalt(record);
        byte[] gcmNonce = this.prepareEncryptionGcmNonce(aeadSalt, explicitNonce, record);
        LOGGER.debug("Encrypting AEAD with the following IV: {}", ArrayConverter.bytesToHexString(gcmNonce));
        byte[] additionalAuthenticatedData = this.collectAdditionalAuthenticatedData(record, this.context.getChooser().getSelectedProtocolVersion());
        record.getComputations().setAuthenticatedMetaData(additionalAuthenticatedData);
        additionalAuthenticatedData = (byte[]) record.getComputations().getAuthenticatedMetaData().getValue();
        LOGGER.debug("Encrypting AEAD with the following AAD: {}", ArrayConverter.bytesToHexString(additionalAuthenticatedData));
        byte[] plainBytes = (byte[]) record.getComputations().getPlainRecordBytes().getValue();
        byte[] wholeCipherText = this.encryptCipher.encrypt(gcmNonce, this.aeadTagLength * 8, additionalAuthenticatedData, plainBytes);
        // A ciphertext exactly one tag long is valid (empty plaintext); only a shorter output is an error.
        if (this.aeadTagLength > wholeCipherText.length) {
            throw new CryptoException("Could not encrypt data. Supposed Tag is longer than the ciphertext");
        }
        int tagOffset = wholeCipherText.length - this.aeadTagLength;
        byte[] onlyCiphertext = Arrays.copyOfRange(wholeCipherText, 0, tagOffset);
        record.getComputations().setAuthenticatedNonMetaData(onlyCiphertext);
        byte[] authenticationTag = Arrays.copyOfRange(wholeCipherText, tagOffset, wholeCipherText.length);
        record.getComputations().setAuthenticationTag(authenticationTag);
        authenticationTag = (byte[]) record.getComputations().getAuthenticationTag().getValue();
        record.getComputations().setCiphertext(onlyCiphertext);
        onlyCiphertext = (byte[]) record.getComputations().getCiphertext().getValue();
        record.setProtocolMessageBytes(ArrayConverter.concatenate(new byte[][]{explicitNonce, onlyCiphertext, authenticationTag}));
        record.getComputations().setAuthenticationTagValid(true);
    }

    /**
     * Decrypts the record's protocol message bytes ({@code explicitNonce || ciphertext || tag}) and sets the
     * clean protocol message bytes. For TLS 1.3 the trailing zero padding and the inner content type are
     * stripped, and the record's content type is updated.
     *
     * @param record the record to decrypt
     * @throws CryptoException if decryption or tag verification fails; the record is marked tag-invalid
     */
    @Override
    public void decrypt(Record record) throws CryptoException {
        LOGGER.debug("Decrypting Record");
        record.getComputations().setCipherKey(this.getKeySet().getReadKey(this.context.getChooser().getConnectionEndType()));
        byte[] protocolBytes = (byte[]) record.getProtocolMessageBytes().getValue();
        DecryptionParser parser = new DecryptionParser(0, protocolBytes);
        byte[] explicitNonce = parser.parseByteArrayField(this.aeadExplicitLength);
        record.getComputations().setExplicitNonce(explicitNonce);
        explicitNonce = (byte[]) record.getComputations().getExplicitNonce().getValue();
        byte[] salt = this.getKeySet().getReadIv(this.context.getConnection().getLocalConnectionEndType());
        record.getComputations().setAeadSalt(salt);
        salt = (byte[]) record.getComputations().getAeadSalt().getValue();
        // Everything between the explicit nonce and the trailing tag is ciphertext.
        byte[] cipherTextOnly = parser.parseByteArrayField(parser.getBytesLeft() - this.aeadTagLength);
        record.getComputations().setCiphertext(cipherTextOnly);
        record.getComputations().setAuthenticatedNonMetaData((byte[]) record.getComputations().getCiphertext().getValue());
        byte[] additionalAuthenticatedData = this.collectAdditionalAuthenticatedData(record, this.context.getChooser().getSelectedProtocolVersion());
        record.getComputations().setAuthenticatedMetaData(additionalAuthenticatedData);
        additionalAuthenticatedData = (byte[]) record.getComputations().getAuthenticatedMetaData().getValue();
        LOGGER.debug("Decrypting AEAD with the following AAD: {}", ArrayConverter.bytesToHexString(additionalAuthenticatedData));
        byte[] gcmNonce = ArrayConverter.concatenate(new byte[][]{salt, explicitNonce});
        if (this.version.isTLS13() || this.bulkCipherAlg == BulkCipherAlgorithm.CHACHA20_POLY1305) {
            gcmNonce = this.preprocessIv(((BigInteger) record.getSequenceNumber().getValue()).longValue(), gcmNonce);
        }
        record.getComputations().setGcmNonce(gcmNonce);
        gcmNonce = (byte[]) record.getComputations().getGcmNonce().getValue();
        LOGGER.debug("Decrypting AEAD with the following IV: {}", ArrayConverter.bytesToHexString(gcmNonce));
        byte[] authenticationTag = parser.parseByteArrayField(parser.getBytesLeft());
        record.getComputations().setAuthenticationTag(authenticationTag);
        authenticationTag = (byte[]) record.getComputations().getAuthenticationTag().getValue();
        try {
            byte[] plainRecordBytes = this.decryptCipher.decrypt(gcmNonce, this.aeadTagLength * 8, additionalAuthenticatedData, ArrayConverter.concatenate(new byte[][]{cipherTextOnly, authenticationTag}));
            record.getComputations().setAuthenticationTagValid(true);
            record.getComputations().setPlainRecordBytes(plainRecordBytes);
            plainRecordBytes = (byte[]) record.getComputations().getPlainRecordBytes().getValue();
            if (this.version.isTLS13()) {
                int numberOfPaddingBytes = this.countTrailingZeroBytes(plainRecordBytes);
                if (numberOfPaddingBytes == plainRecordBytes.length) {
                    LOGGER.warn("Record contains ONLY padding and no content type. Setting clean bytes == plainbytes");
                    record.setCleanProtocolMessageBytes(plainRecordBytes);
                    return;
                }
                // TLSInnerPlaintext = content || contentType (1 byte) || zero padding
                parser = new DecryptionParser(0, plainRecordBytes);
                byte[] cleanBytes = parser.parseByteArrayField(plainRecordBytes.length - numberOfPaddingBytes - 1);
                byte[] contentType = parser.parseByteArrayField(1);
                byte[] padding = parser.parseByteArrayField(numberOfPaddingBytes);
                record.getComputations().setPadding(padding);
                record.setCleanProtocolMessageBytes(cleanBytes);
                record.setContentType(contentType[0]);
                record.setContentMessageType(ProtocolMessageType.getContentType(contentType[0]));
            } else {
                record.setCleanProtocolMessageBytes(plainRecordBytes);
            }
        }
        catch (CryptoException e) {
            LOGGER.warn("Tag invalid", e);
            record.getComputations().setAuthenticationTagValid(false);
            throw new CryptoException(e);
        }
    }

    /**
     * Encrypts a blob record's clean bytes with the encryption cipher's single-argument encrypt.
     *
     * @throws CryptoException if encryption fails
     */
    @Override
    public void encrypt(BlobRecord br) throws CryptoException {
        LOGGER.debug("Encrypting BlobRecord");
        br.setProtocolMessageBytes(this.encryptCipher.encrypt((byte[]) br.getCleanProtocolMessageBytes().getValue()));
    }

    /**
     * Decrypts a blob record's protocol bytes with the decryption cipher's single-argument decrypt.
     *
     * @throws CryptoException if decryption fails
     */
    @Override
    public void decrypt(BlobRecord br) throws CryptoException {
        LOGGER.debug("Decrypting BlobRecord");
        br.setCleanProtocolMessageBytes(this.decryptCipher.decrypt((byte[]) br.getProtocolMessageBytes().getValue()));
    }

    /**
     * Counts the consecutive zero bytes at the end of the array.
     *
     * @param plainRecordBytes the decrypted TLS 1.3 inner plaintext
     * @return number of trailing zero bytes; equals the array length if the array is all zeros or empty
     */
    private int countTrailingZeroBytes(byte[] plainRecordBytes) {
        int counter = 0;
        // Stop at index 0: an all-zero array must not read index -1.
        for (int i = plainRecordBytes.length - 1; i >= 0 && plainRecordBytes[i] == 0; --i) {
            ++counter;
        }
        return counter;
    }

    /**
     * Derives a per-record nonce by XORing {@code iv} into
     * {@code 0x00000000 || longToUint64Bytes(sequenceNumber)}.
     *
     * <p>NOTE(review): the result is 4 + (presumably 8) = 12 bytes regardless of {@code iv.length}.
     * An {@code iv} longer than that throws {@link ArrayIndexOutOfBoundsException}, and a shorter one is
     * XORed into the leading bytes only.
     *
     * @param sequenceNumber the record sequence number
     * @param iv the static IV (or salt || explicit nonce)
     * @return the derived nonce
     */
    public byte[] preprocessIv(long sequenceNumber, byte[] iv) {
        byte[] padding = new byte[]{0, 0, 0, 0};
        byte[] temp = ArrayConverter.concatenate(new byte[][]{padding, ArrayConverter.longToUint64Bytes(sequenceNumber)});
        for (int i = 0; i < iv.length; ++i) {
            temp[i] = (byte) (temp[i] ^ iv[i]);
        }
        return temp;
    }

    /** Minimal {@link Parser} used only to slice byte arrays field by field. */
    class DecryptionParser
    extends Parser<Object> {
        public DecryptionParser(int startposition, byte[] array) {
            super(startposition, array);
        }

        @Override
        public Object parse() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public byte[] parseByteArrayField(int length) {
            return super.parseByteArrayField(length);
        }

        @Override
        public int getBytesLeft() {
            return super.getBytesLeft();
        }

        @Override
        public int getPointer() {
            return super.getPointer();
        }
    }
}

