/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.californium.scandium.dtls;

import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.eclipse.californium.elements.util.Bytes;
import org.eclipse.californium.elements.util.DatagramWriter;
import org.eclipse.californium.elements.util.NoPublicAPI;
import org.eclipse.californium.scandium.dtls.AlertMessage;
import org.eclipse.californium.scandium.dtls.ConnectionId;
import org.eclipse.californium.scandium.dtls.ContentType;
import org.eclipse.californium.scandium.dtls.DTLSMessage;
import org.eclipse.californium.scandium.dtls.DTLSSession;
import org.eclipse.californium.scandium.dtls.FragmentedHandshakeMessage;
import org.eclipse.californium.scandium.dtls.HandshakeException;
import org.eclipse.californium.scandium.dtls.HandshakeMessage;
import org.eclipse.californium.scandium.dtls.MultiHandshakeMessage;
import org.eclipse.californium.scandium.dtls.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@NoPublicAPI
public class DTLSFlight {
    private static final Logger LOGGER = LoggerFactory.getLogger(DTLSFlight.class);

    /** Upper bound, in milliseconds, of the retransmission timeout. */
    private static final int MAX_TIMEOUT_MILLIS = 60000;

    /** Length of the DTLS record header (RFC 6347, section 4.1). */
    private static final int RECORD_HEADER_LENGTH = 13;

    /** Length of the DTLS handshake message header (RFC 6347, section 4.2.2). */
    private static final int HANDSHAKE_HEADER_LENGTH = 12;

    /** Datagram size limit applied by {@link #getDatagrams} when {@code backOff} is requested. */
    private static final int BACK_OFF_MAX_DATAGRAM_SIZE = 512;

    /** Records of this flight, (re-)built by {@link #getRecords}. */
    private final List<Record> records;
    /** Messages of this flight together with the epoch they are sent in. */
    private final List<EpochMessage> dtlsMessages;
    private final DTLSSession session;
    private final int flightNumber;
    private int tries;
    private int timeout = 0;
    /** Parameters the current {@link #records} were built with. */
    private int maxDatagramSize;
    private int maxFragmentSize;
    private boolean useMultiHandshakeMessageRecords;
    /**
     * Datagram size used to split records into datagrams. Starts at the max. datagram
     * size and may be reduced in {@link #wrapHandshakeMessage}.
     * NOTE(review): each wrapped handshake message overwrites this value, so the last
     * one wins - confirm this is intended for flights mixing epochs.
     */
    private int effectiveDatagramSize;
    /** Pending (not yet flushed) multi-handshake message and its record parameters. */
    private int multiEpoch;
    private boolean multiUseCid;
    private MultiHandshakeMessage multiHandshakeMessage;
    private boolean retransmissionNeeded;
    private volatile boolean responseStarted;
    private volatile boolean responseCompleted;
    private ScheduledFuture<?> timeoutTask;

    /**
     * Creates an empty flight.
     *
     * @param session session the flight belongs to, must have a peer address
     * @param flightNumber number of the flight
     * @throws NullPointerException if the session or its peer address is {@code null}
     */
    public DTLSFlight(DTLSSession session, int flightNumber) {
        if (session == null) {
            throw new NullPointerException("Session must not be null");
        }
        if (session.getPeer() == null) {
            throw new NullPointerException("Peer address must not be null");
        }
        this.session = session;
        this.records = new ArrayList<Record>();
        this.dtlsMessages = new ArrayList<EpochMessage>();
        this.retransmissionNeeded = true;
        this.flightNumber = flightNumber;
    }

    /**
     * Adds a message to this flight.
     *
     * @param epoch epoch the message is sent with
     * @param messageToAdd message to add
     * @throws NullPointerException if the message is {@code null}
     */
    public void addDtlsMessage(int epoch, DTLSMessage messageToAdd) {
        if (messageToAdd == null) {
            throw new NullPointerException("message must not be null!");
        }
        this.dtlsMessages.add(new EpochMessage(epoch, messageToAdd));
    }

    /**
     * @return number of messages added to this flight
     */
    public int getNumberOfMessages() {
        return this.dtlsMessages.size();
    }

    /**
     * Checks whether a message with the same serialized form is part of this flight.
     *
     * @param message message to look for
     * @return {@code true} if an added message serializes to the same bytes
     */
    public boolean contains(DTLSMessage message) {
        if (this.dtlsMessages.isEmpty()) {
            return false;
        }
        // serialize the probe only once instead of once per comparison
        byte[] expected = message.toByteArray();
        for (EpochMessage epochMessage : this.dtlsMessages) {
            if (Arrays.equals(expected, epochMessage.message.toByteArray())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Wraps a single message into one or more records appended to {@link #records}.
     *
     * @param epochMessage message with epoch
     * @throws HandshakeException if the content type is not supported in a flight or
     *             the record could not be created
     */
    protected final void wrapMessage(EpochMessage epochMessage) throws HandshakeException {
        try {
            DTLSMessage message = epochMessage.message;
            switch (message.getContentType()) {
                case HANDSHAKE:
                    this.wrapHandshakeMessage(epochMessage);
                    break;
                case CHANGE_CIPHER_SPEC:
                    // a CCS must not be merged into a pending multi-handshake record
                    this.flushMultiHandshakeMessages();
                    this.records.add(new Record(message.getContentType(), epochMessage.epoch,
                            this.session.getSequenceNumber(epochMessage.epoch), message, this.session, false, 0));
                    LOGGER.debug("Add CCS message of {} bytes for [{}]", message.size(), message.getPeer());
                    break;
                default:
                    throw new HandshakeException("Cannot create " + message.getContentType() + " record for flight",
                            new AlertMessage(AlertMessage.AlertLevel.FATAL, AlertMessage.AlertDescription.INTERNAL_ERROR,
                                    this.session.getPeer()));
            }
        } catch (GeneralSecurityException e) {
            throw new HandshakeException("Cannot create record", new AlertMessage(AlertMessage.AlertLevel.FATAL,
                    AlertMessage.AlertDescription.INTERNAL_ERROR, this.session.getPeer()), e);
        }
    }

    /**
     * Wraps a handshake message into records.
     *
     * The message is either added to a multi-handshake message (if enabled and it
     * fits), put into a record of its own, or split into fragments that each fit
     * into one record.
     *
     * @param epochMessage handshake message with epoch
     * @throws GeneralSecurityException if a record could not be created
     * @throws IllegalStateException if the configured sizes leave no room for payload,
     *             or the serialized message length is inconsistent
     */
    private void wrapHandshakeMessage(EpochMessage epochMessage) throws GeneralSecurityException {
        HandshakeMessage handshakeMessage = (HandshakeMessage) epochMessage.message;
        int epoch = epochMessage.epoch;
        int messageLength = handshakeMessage.getMessageLength();
        int maxPayloadLength = this.maxDatagramSize - RECORD_HEADER_LENGTH - HANDSHAKE_HEADER_LENGTH;
        int effectiveMaxFragmentSize = this.maxFragmentSize;
        boolean useCid = false;
        if (epoch > 0) {
            maxPayloadLength -= this.session.getMaxCiphertextExpansion();
            ConnectionId connectionId = this.session.getWriteConnectionId();
            if (connectionId != null && !connectionId.isEmpty()) {
                useCid = true;
                // CID bytes plus one additional byte
                // (presumably the inner content type of the CID record format - confirm)
                maxPayloadLength -= connectionId.length() + 1;
            }
        }
        if (effectiveMaxFragmentSize > maxPayloadLength) {
            effectiveMaxFragmentSize = maxPayloadLength;
        } else {
            // fragment size limit is smaller than the datagram allows: shrink datagrams as well
            this.effectiveDatagramSize = effectiveMaxFragmentSize + (this.maxDatagramSize - maxPayloadLength);
        }

        if (messageLength <= effectiveMaxFragmentSize) {
            if (this.useMultiHandshakeMessageRecords
                    && this.addToMultiHandshakeMessage(handshakeMessage, epoch, useCid, messageLength,
                            effectiveMaxFragmentSize)) {
                return;
            }
            this.records.add(new Record(ContentType.HANDSHAKE, epoch, this.session.getSequenceNumber(epoch),
                    handshakeMessage, this.session, useCid, 0));
            LOGGER.debug("Add {} message of {} bytes for [{}]",
                    new Object[] { handshakeMessage.getMessageType(), messageLength, handshakeMessage.getPeer() });
            return;
        }

        this.flushMultiHandshakeMessages();
        if (effectiveMaxFragmentSize <= 0) {
            // without this check the fragmentation loop below would never terminate
            throw new IllegalStateException("max. datagram size " + this.maxDatagramSize
                    + " leaves no room for handshake fragments!");
        }
        LOGGER.debug("Splitting up {} message of {} bytes for [{}] into multiple fragments of max. {} bytes",
                new Object[] { handshakeMessage.getMessageType(), messageLength, handshakeMessage.getPeer(),
                        effectiveMaxFragmentSize });
        byte[] messageBytes = handshakeMessage.fragmentToByteArray();
        if (messageBytes.length != messageLength) {
            throw new IllegalStateException(
                    "message length " + messageLength + " differs from message " + messageBytes.length + "!");
        }
        int messageSeq = handshakeMessage.getMessageSeq();
        for (int offset = 0; offset < messageLength; offset += effectiveMaxFragmentSize) {
            int fragmentLength = Math.min(effectiveMaxFragmentSize, messageLength - offset);
            byte[] fragmentBytes = Arrays.copyOfRange(messageBytes, offset, offset + fragmentLength);
            FragmentedHandshakeMessage fragmentedMessage = new FragmentedHandshakeMessage(
                    handshakeMessage.getMessageType(), messageLength, messageSeq, offset, fragmentBytes,
                    this.session.getPeer());
            LOGGER.debug("fragment for offset {}, {} bytes", offset, fragmentedMessage.size());
            // use the CID consistently with the payload size computed above
            this.records.add(new Record(ContentType.HANDSHAKE, epoch, this.session.getSequenceNumber(epoch),
                    fragmentedMessage, this.session, useCid, 0));
        }
    }

    /**
     * Tries to put a handshake message into the pending multi-handshake message,
     * flushing or starting one as required.
     *
     * @return {@code true} if the message was taken, {@code false} if it must be
     *         wrapped into a record of its own
     * @throws GeneralSecurityException if flushing the pending message fails
     */
    private boolean addToMultiHandshakeMessage(HandshakeMessage handshakeMessage, int epoch, boolean useCid,
            int messageLength, int effectiveMaxFragmentSize) throws GeneralSecurityException {
        if (this.multiHandshakeMessage != null) {
            if (this.multiEpoch == epoch && this.multiUseCid == useCid
                    && this.multiHandshakeMessage.getMessageLength() + handshakeMessage.size() < effectiveMaxFragmentSize) {
                this.multiHandshakeMessage.add(handshakeMessage);
                LOGGER.debug("Add multi-handshake-message {} message of {} bytes for [{}]",
                        new Object[] { handshakeMessage.getMessageType(), messageLength, handshakeMessage.getPeer() });
                return true;
            }
            // different record parameters or not enough room: close the pending one
            this.flushMultiHandshakeMessages();
        }
        if (messageLength + HANDSHAKE_HEADER_LENGTH < effectiveMaxFragmentSize) {
            this.multiHandshakeMessage = new MultiHandshakeMessage(this.session.getPeer());
            this.multiHandshakeMessage.add(handshakeMessage);
            this.multiEpoch = epoch;
            this.multiUseCid = useCid;
            LOGGER.debug("Start multi-handshake-message with {} message of {} bytes for [{}]",
                    new Object[] { handshakeMessage.getMessageType(), messageLength, handshakeMessage.getPeer() });
            return true;
        }
        return false;
    }

    /**
     * Appends the pending multi-handshake message (if any) as a record and clears
     * the pending state.
     *
     * @throws GeneralSecurityException if the record could not be created
     */
    private void flushMultiHandshakeMessages() throws GeneralSecurityException {
        if (this.multiHandshakeMessage == null) {
            return;
        }
        this.records.add(new Record(ContentType.HANDSHAKE, this.multiEpoch,
                this.session.getSequenceNumber(this.multiEpoch), this.multiHandshakeMessage, this.session,
                this.multiUseCid, 0));
        int count = this.multiHandshakeMessage.getNumberOfHandshakeMessages();
        Object[] arguments = new Object[] { count, this.multiEpoch, this.multiHandshakeMessage.getMessageLength(),
                this.multiHandshakeMessage.getPeer() };
        if (count > 1) {
            LOGGER.info("Add {} multi handshake message, epoch {} of {} bytes for [{}]", arguments);
        } else {
            LOGGER.debug("Add {} multi handshake message, epoch {} of {} bytes for [{}]", arguments);
        }
        this.resetMultiHandshakeMessage();
    }

    /** Drops the pending multi-handshake message state. */
    private void resetMultiHandshakeMessage() {
        this.multiHandshakeMessage = null;
        this.multiEpoch = 0;
        this.multiUseCid = false;
    }

    /**
     * Invalidates the cached records so the next {@link #getRecords} call rebuilds
     * them from the messages.
     */
    private void invalidateRecords() {
        this.records.clear();
        this.resetMultiHandshakeMessage();
    }

    /**
     * Gets the records of this flight.
     *
     * If the records were already built with the same parameters, they are reused
     * with fresh sequence numbers (e.g. for retransmission). Otherwise they are
     * rebuilt from the messages. On failure the cached records are invalidated, so
     * a later call never reuses an empty or partially built list.
     *
     * @param maxDatagramSize max. datagram size
     * @param maxFragmentSize max. handshake fragment size
     * @param useMultiHandshakeMessageRecords {@code true} to pack several handshake
     *            messages into one record
     * @return the internal list of records (not a copy)
     * @throws HandshakeException if a record could not be created
     */
    public List<Record> getRecords(int maxDatagramSize, int maxFragmentSize, boolean useMultiHandshakeMessageRecords)
            throws HandshakeException {
        try {
            boolean sameParameters = this.maxDatagramSize == maxDatagramSize
                    && this.maxFragmentSize == maxFragmentSize
                    && this.useMultiHandshakeMessageRecords == useMultiHandshakeMessageRecords;
            if (sameParameters && !this.records.isEmpty()) {
                this.refreshSequenceNumbers();
            } else {
                this.effectiveDatagramSize = maxDatagramSize;
                this.maxDatagramSize = maxDatagramSize;
                this.maxFragmentSize = maxFragmentSize;
                this.useMultiHandshakeMessageRecords = useMultiHandshakeMessageRecords;
                this.records.clear();
                for (EpochMessage message : this.dtlsMessages) {
                    this.wrapMessage(message);
                }
                this.flushMultiHandshakeMessages();
            }
        } catch (GeneralSecurityException e) {
            this.invalidateRecords();
            throw new HandshakeException("Cannot create record", new AlertMessage(AlertMessage.AlertLevel.FATAL,
                    AlertMessage.AlertDescription.INTERNAL_ERROR, this.session.getPeer()), e);
        } catch (HandshakeException e) {
            this.invalidateRecords();
            throw e;
        }
        return this.records;
    }

    /**
     * Replaces every cached record by an equivalent one with a new sequence number.
     *
     * @throws GeneralSecurityException if a record could not be created
     */
    private void refreshSequenceNumbers() throws GeneralSecurityException {
        for (int index = 0; index < this.records.size(); ++index) {
            Record record = this.records.get(index);
            int epoch = record.getEpoch();
            DTLSMessage fragment = record.getFragment();
            boolean useCid = record.useConnectionId();
            this.records.set(index, new Record(record.getType(), epoch, this.session.getSequenceNumber(epoch),
                    fragment, this.session, useCid, 0));
        }
    }

    /**
     * Gets the datagrams to send this flight.
     *
     * Records larger than the effective datagram size are discarded (and logged). A
     * CCS record is combined with the following record only if both fit into one
     * datagram. Empty datagrams are never produced.
     *
     * @param maxDatagramSize max. datagram size
     * @param maxFragmentSize max. handshake fragment size
     * @param useMultiHandshakeMessageRecords {@code Boolean.TRUE} to pack several
     *            handshake messages into one record, {@code null} or
     *            {@code Boolean.FALSE} otherwise
     * @param useMultiRecordMessages {@code Boolean.FALSE} to send one record per
     *            datagram; {@code null} enables packing except during back-off
     * @param backOff {@code true} to limit the datagram size for retransmissions
     * @return list of datagrams
     * @throws HandshakeException if a record could not be created
     */
    public List<DatagramPacket> getDatagrams(int maxDatagramSize, int maxFragmentSize,
            Boolean useMultiHandshakeMessageRecords, Boolean useMultiRecordMessages, boolean backOff)
            throws HandshakeException {
        DatagramWriter writer = new DatagramWriter(maxDatagramSize);
        List<DatagramPacket> datagrams = new ArrayList<DatagramPacket>();
        boolean multiHandshakeMessages = Boolean.TRUE.equals(useMultiHandshakeMessageRecords);
        boolean multiRecords = !Boolean.FALSE.equals(useMultiRecordMessages);
        if (backOff) {
            maxDatagramSize = Math.min(BACK_OFF_MAX_DATAGRAM_SIZE, maxDatagramSize);
        }
        // during back-off, records are only packed if explicitly configured
        boolean packRecords = multiRecords && (!backOff || useMultiRecordMessages != null);
        LOGGER.info("Prepare flight {}, using max. datagram size {}, max. fragment size {} [mhm={}, mr={}]",
                new Object[] { this.flightNumber, maxDatagramSize, maxFragmentSize, multiHandshakeMessages,
                        multiRecords });
        InetSocketAddress peer = this.session.getPeer();
        List<Record> records = this.getRecords(maxDatagramSize, maxFragmentSize, multiHandshakeMessages);
        LOGGER.info("Effective max. datagram size {}", this.effectiveDatagramSize);
        for (int index = 0; index < records.size(); ++index) {
            Record record = records.get(index);
            byte[] recordBytes = record.toByteArray();
            if (recordBytes.length > this.effectiveDatagramSize) {
                LOGGER.error("{} record of {} bytes for peer [{}] exceeds max. datagram size [{}], discarding...",
                        new Object[] { record.getType(), recordBytes.length, peer, this.effectiveDatagramSize });
                continue;
            }
            LOGGER.trace("Sending record of {} bytes to peer [{}]:\n{}",
                    new Object[] { recordBytes.length, peer, record });
            if (multiRecords && record.getType() == ContentType.CHANGE_CIPHER_SPEC && index + 1 < records.size()) {
                byte[] finishBytes = records.get(index + 1).toByteArray();
                // only combine CCS with the next record, if both fit into one datagram;
                // otherwise the next record is handled (and size checked) on its own
                if (recordBytes.length + finishBytes.length <= this.effectiveDatagramSize) {
                    recordBytes = Bytes.concatenate(recordBytes, finishBytes);
                    ++index;
                }
            }
            int left = packRecords ? this.effectiveDatagramSize - recordBytes.length : 0;
            if (writer.size() > 0 && writer.size() > left) {
                // current datagram has no room left, send it and start a new one
                addDatagram(datagrams, writer, peer);
            }
            writer.writeBytes(recordBytes);
        }
        if (writer.size() > 0) {
            addDatagram(datagrams, writer, peer);
        }
        return datagrams;
    }

    /**
     * Appends the content of the writer as datagram for the peer.
     * Relies on {@link DatagramWriter#toByteArray()} to reset the writer, as the
     * batching loop in {@link #getDatagrams} does.
     */
    private static void addDatagram(List<DatagramPacket> datagrams, DatagramWriter writer, InetSocketAddress peer) {
        byte[] payload = writer.toByteArray();
        datagrams.add(new DatagramPacket(payload, payload.length, peer.getAddress(), peer.getPort()));
        LOGGER.debug("Sending datagram of {} bytes to peer [{}]", payload.length, peer);
    }

    /**
     * @return number of this flight
     */
    public int getFlightNumber() {
        return this.flightNumber;
    }

    /**
     * @return number of (re-)transmission tries
     */
    public int getTries() {
        return this.tries;
    }

    /** Increments the number of tries. */
    public void incrementTries() {
        ++this.tries;
    }

    /**
     * @return current retransmission timeout in milliseconds
     */
    public int getTimeout() {
        return this.timeout;
    }

    /**
     * @param timeout retransmission timeout in milliseconds
     */
    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    /** Doubles the retransmission timeout, limited by {@link #MAX_TIMEOUT_MILLIS}. */
    public void incrementTimeout() {
        this.timeout = DTLSFlight.incrementTimeout(this.timeout);
    }

    /**
     * @return {@code true} if this flight is retransmitted on timeout
     */
    public boolean isRetransmissionNeeded() {
        return this.retransmissionNeeded;
    }

    /**
     * @param needsRetransmission {@code true} to retransmit this flight on timeout
     */
    public void setRetransmissionNeeded(boolean needsRetransmission) {
        this.retransmissionNeeded = needsRetransmission;
    }

    /**
     * @return {@code true} if a response to this flight has started
     */
    public boolean isResponseStarted() {
        return this.responseStarted;
    }

    /** Marks that a response to this flight has started. */
    public void setResponseStarted() {
        this.responseStarted = true;
    }

    /** Cancels a scheduled retransmission, if any. */
    private void cancelTimeout() {
        if (this.timeoutTask != null) {
            if (!this.timeoutTask.isDone()) {
                this.timeoutTask.cancel(true);
            }
            this.timeoutTask = null;
        }
    }

    /** Marks the response as completed and cancels a scheduled retransmission. */
    public void setResponseCompleted() {
        this.responseCompleted = true;
        this.cancelTimeout();
    }

    /**
     * @return {@code true} if the response to this flight is completed
     */
    public boolean isResponseCompleted() {
        return this.responseCompleted;
    }

    /**
     * Schedules the retransmission task after the current timeout, replacing a
     * previously scheduled one. Does nothing if the response is completed.
     *
     * @param timer executor used for scheduling
     * @param task retransmission task
     */
    public void scheduleRetransmission(ScheduledExecutorService timer, Runnable task) {
        if (this.responseCompleted) {
            return;
        }
        if (!this.isRetransmissionNeeded()) {
            LOGGER.trace("handshake flight to peer {}, no retransmission!", this.session.getPeer());
            return;
        }
        this.cancelTimeout();
        try {
            this.timeoutTask = timer.schedule(task, this.timeout, TimeUnit.MILLISECONDS);
            LOGGER.trace("handshake flight to peer {}, retransmission {} ms.", this.session.getPeer(), this.timeout);
        } catch (RejectedExecutionException ex) {
            // executor is shutting down, no retransmission anymore
            LOGGER.trace("handshake flight stopped by shutdown.");
        }
    }

    /**
     * Doubles a timeout, limited by {@link #MAX_TIMEOUT_MILLIS}. Timeouts already at
     * or above the limit are returned unchanged.
     *
     * @param timeoutMillis timeout in milliseconds
     * @return incremented timeout in milliseconds
     */
    public static int incrementTimeout(int timeoutMillis) {
        if (timeoutMillis < MAX_TIMEOUT_MILLIS) {
            timeoutMillis = Math.min(timeoutMillis * 2, MAX_TIMEOUT_MILLIS);
        }
        return timeoutMillis;
    }

    /** A message of the flight together with the epoch it is sent in. */
    private static class EpochMessage {
        private final int epoch;
        private final DTLSMessage message;

        private EpochMessage(int epoch, DTLSMessage message) {
            this.epoch = epoch;
            this.message = message;
        }
    }
}

