using System;
using System.Collections;
using System.IO;
using Org.BouncyCastle.Utilities;

namespace Org.BouncyCastle.Crypto.Tls;

/// <summary>
/// Reliable delivery of DTLS handshake messages on top of a <c>DtlsRecordLayer</c>.
/// Outbound messages are split into fragments that fit the record layer's send limit and are
/// kept as the current "outbound flight" so they can be resent. Inbound fragments are
/// reassembled per message sequence number and handed out strictly in sequence order.
/// Every message sent or received (except HelloRequest) is fed into the running handshake hash.
/// </summary>
internal class DtlsReliableHandshake
{
    /// <summary>A complete handshake message: sequence number, type and unfragmented body.</summary>
    internal class Message
    {
        private readonly int mMessageSeq;
        private readonly byte mMsgType;
        private readonly byte[] mBody;

        internal Message(int message_seq, byte msg_type, byte[] body)
        {
            mMessageSeq = message_seq;
            mMsgType = msg_type;
            mBody = body;
        }

        public int Seq => mMessageSeq;

        public byte Type => mMsgType;

        public byte[] Body => mBody;
    }

    /// <summary>
    /// Growable buffer used to assemble one handshake fragment (header plus payload) so it can
    /// be passed to the record layer in a single <c>Send</c> call.
    /// </summary>
    internal class RecordLayerBuffer : MemoryStream
    {
        internal RecordLayerBuffer(int size)
            : base(size)
        {
        }

        /// <summary>
        /// Sends the bytes written so far as one record-layer payload and then disposes this
        /// buffer. The buffer is disposed even when <c>Send</c> throws.
        /// </summary>
        /// <param name="recordLayer">Record layer that receives the assembled bytes.</param>
        internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
        {
            try
            {
                // GetBuffer() exposes the backing array without copying; only the first
                // Length bytes of it are valid data.
                byte[] buffer = GetBuffer();
                int length = (int)Length;
                recordLayer.Send(buffer, 0, length);
            }
            finally
            {
                Platform.Dispose(this);
            }
        }
    }

    /// <summary>
    /// Created by <c>Finish</c> (when this side sent the last flight) and handed to the record
    /// layer, so that handshake records arriving after completion are still routed into
    /// <c>ProcessRecord</c>.
    /// </summary>
    internal class Retransmit : DtlsHandshakeRetransmit
    {
        private readonly DtlsReliableHandshake mOuter;

        internal Retransmit(DtlsReliableHandshake outer)
        {
            mOuter = outer;
        }

        public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
        {
            // A zero-width receive window accepts nothing as new: only fragments belonging to
            // the previous inbound flight are processed, which may trigger a resend of ours.
            mOuter.ProcessRecord(0, epoch, buf, off, len);
        }
    }

    /// <summary>How many sequence numbers ahead of the next expected message are buffered.</summary>
    private const int MaxReceiveAhead = 16;

    /// <summary>
    /// Handshake header size: msg_type(1) + length(3) + message_seq(2) + fragment_offset(3)
    /// + fragment_length(3), as written by <c>WriteHandshakeFragment</c>.
    /// </summary>
    private const int MessageHeaderLength = 12;

    // Initial receive timeout and the upper bound applied by BackOff (RFC 6347, 4.2.4.1).
    private const int InitialReadTimeoutMillis = 1000;
    private const int MaxReadTimeoutMillis = 60000;

    // Handshake type codes (RFC 5246, 7.4).
    private const byte HandshakeTypeHelloRequest = 0;
    private const byte HandshakeTypeFinished = 20;

    // Alert description codes (RFC 5246, 7.2).
    private const byte AlertUnexpectedMessage = 10;
    private const byte AlertInternalError = 80;

    private readonly DtlsRecordLayer mRecordLayer;

    private TlsHandshakeHash mHandshakeHash;

    // Reassemblers (DtlsReassembler) keyed by boxed int message_seq.
    private IDictionary mCurrentInboundFlight = Platform.CreateHashtable();

    // The inbound flight before the current one; fragments that land here indicate the peer is
    // retransmitting, which prompts us to resend our outbound flight. Null when not tracked.
    private IDictionary mPreviousInboundFlight = null;

    // Messages sent since we last switched from receiving to sending; resent as a unit.
    private IList mOutboundFlight = Platform.CreateArrayList();

    // True while building an outbound flight; false while receiving an inbound one.
    private bool mSending = true;

    private int mMessageSeq = 0;
    private int mNextReceiveSeq = 0;

    internal TlsHandshakeHash HandshakeHash => mHandshakeHash;

    internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport)
    {
        mRecordLayer = transport;
        mHandshakeHash = new DeferredHash();
        mHandshakeHash.Init(context);
    }

    /// <summary>Replaces the handshake hash with the result of <c>NotifyPrfDetermined</c>.</summary>
    internal void NotifyHelloComplete()
    {
        mHandshakeHash = mHandshakeHash.NotifyPrfDetermined();
    }

    /// <summary>
    /// Returns the current handshake hash and replaces it with the result of
    /// <c>StopTracking</c> on that hash.
    /// </summary>
    internal TlsHandshakeHash PrepareToFinish()
    {
        TlsHandshakeHash result = mHandshakeHash;
        mHandshakeHash = mHandshakeHash.StopTracking();
        return result;
    }

    /// <summary>
    /// Sends a handshake message. The first send after receiving starts a new outbound flight,
    /// discarding the previous one.
    /// </summary>
    /// <param name="msg_type">Handshake message type.</param>
    /// <param name="body">Message body; its length must fit in 24 bits.</param>
    internal void SendMessage(byte msg_type, byte[] body)
    {
        TlsUtilities.CheckUint24(body.Length);

        if (!mSending)
        {
            CheckInboundFlight();
            mSending = true;
            mOutboundFlight.Clear();
        }

        Message message = new Message(mMessageSeq++, msg_type, body);

        mOutboundFlight.Add(message);

        WriteMessage(message);
        UpdateHandshakeMessagesDigest(message);
    }

    /// <summary>
    /// Receives the next handshake message and returns its body.
    /// </summary>
    /// <param name="msg_type">The message type the caller expects.</param>
    /// <exception cref="TlsFatalAlert">unexpected_message if the next message has another type.</exception>
    internal byte[] ReceiveMessageBody(byte msg_type)
    {
        Message message = ReceiveMessage();
        if (message.Type != msg_type)
        {
            throw new TlsFatalAlert(AlertUnexpectedMessage);
        }

        return message.Body;
    }

    /// <summary>
    /// Blocks until the next in-sequence handshake message is fully reassembled and returns it.
    /// Whenever a receive attempt yields nothing (negative count) or throws an
    /// <see cref="IOException"/>, the outbound flight is resent and the receive timeout is
    /// doubled, up to <c>MaxReadTimeoutMillis</c>. The timeout is also doubled when an incoming
    /// record caused a resend inside <c>ProcessRecord</c>.
    /// </summary>
    internal Message ReceiveMessage()
    {
        if (mSending)
        {
            mSending = false;
            PrepareInboundFlight(Platform.CreateHashtable());
        }

        byte[] buf = null;
        int readTimeoutMillis = InitialReadTimeoutMillis;

        for (;;)
        {
            try
            {
                for (;;)
                {
                    Message pending = GetPendingMessage();
                    if (pending != null)
                    {
                        return pending;
                    }

                    int receiveLimit = mRecordLayer.GetReceiveLimit();
                    if (buf == null || buf.Length < receiveLimit)
                    {
                        buf = new byte[receiveLimit];
                    }

                    int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
                    if (received < 0)
                    {
                        break;
                    }

                    bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
                    if (resentOutbound)
                    {
                        readTimeoutMillis = BackOff(readTimeoutMillis);
                    }
                }
            }
            catch (IOException)
            {
                // Treated like "nothing received" (presumably a transport that reports a
                // timeout by throwing): fall through and resend the outbound flight.
                // NOTE(review): if TlsFatalAlert derives from IOException, a fatal alert raised
                // by the record layer is swallowed here too and the loop keeps retrying -
                // confirm this is intended.
            }

            ResendOutboundFlight();
            readTimeoutMillis = BackOff(readTimeoutMillis);
        }
    }

    /// <summary>
    /// Completes the handshake and notifies the record layer. If this side sent the last
    /// flight and a previous inbound flight is still tracked, a <see cref="Retransmit"/>
    /// handler is passed on so that peer retransmissions can still be answered.
    /// </summary>
    internal void Finish()
    {
        DtlsHandshakeRetransmit retransmit = null;
        if (!mSending)
        {
            CheckInboundFlight();
        }
        else
        {
            PrepareInboundFlight(null);

            if (mPreviousInboundFlight != null)
            {
                retransmit = new Retransmit(this);
            }
        }

        mRecordLayer.HandshakeSuccessful(retransmit);
    }

    internal void ResetHandshakeMessagesDigest()
    {
        mHandshakeHash.Reset();
    }

    /// <summary>Doubles the timeout, capped at <c>MaxReadTimeoutMillis</c>.</summary>
    private int BackOff(int timeoutMillis)
    {
        return System.Math.Min(timeoutMillis * 2, MaxReadTimeoutMillis);
    }

    /// <summary>
    /// Inspects the current inbound flight before switching to sending (or finishing).
    /// Currently a no-op: reassemblers at or beyond <c>mNextReceiveSeq</c> are not reported.
    /// </summary>
    private void CheckInboundFlight()
    {
        foreach (int messageSeq in mCurrentInboundFlight.Keys)
        {
            if (messageSeq >= mNextReceiveSeq)
            {
                // NOTE(review): this message was (at least partly) received but never consumed;
                // it is not currently treated as an error.
            }
        }
    }

    /// <summary>
    /// Returns the next in-sequence message if it has been fully reassembled, advancing
    /// <c>mNextReceiveSeq</c> and updating the handshake hash; otherwise returns null.
    /// </summary>
    private Message GetPendingMessage()
    {
        DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
        if (next != null)
        {
            byte[] body = next.GetBodyIfComplete();
            if (body != null)
            {
                // Once a message of the current flight completes, the previous inbound flight
                // is no longer used to detect peer retransmissions.
                mPreviousInboundFlight = null;
                return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
            }
        }

        return null;
    }

    /// <summary>
    /// Resets the reassemblers of the current flight, keeps it as the previous flight, and
    /// installs <paramref name="nextFlight"/> (which may be null) as the current one.
    /// </summary>
    private void PrepareInboundFlight(IDictionary nextFlight)
    {
        ResetAll(mCurrentInboundFlight);
        mPreviousInboundFlight = mCurrentInboundFlight;
        mCurrentInboundFlight = nextFlight;
    }

    /// <summary>
    /// Parses the handshake fragments contained in one record and feeds them to the matching
    /// reassemblers. Fragments within [mNextReceiveSeq, mNextReceiveSeq + windowSize) go to
    /// the current flight. Older ones go to the previous flight, if one is tracked.
    /// </summary>
    /// <returns>
    /// True if the record contributed to the previous flight and that flight is now complete.
    /// In that case the outbound flight has been resent and the previous flight's reassemblers
    /// reset.
    /// </returns>
    private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
    {
        bool checkPreviousFlight = false;

        while (len >= MessageHeaderLength)
        {
            int fragmentLength = TlsUtilities.ReadUint24(buf, off + 9);
            int messageLength = fragmentLength + MessageHeaderLength;
            if (len < messageLength)
            {
                // Truncated fragment: ignore it and the rest of the record.
                break;
            }

            int length = TlsUtilities.ReadUint24(buf, off + 1);
            int fragmentOffset = TlsUtilities.ReadUint24(buf, off + 6);
            if (fragmentOffset + fragmentLength > length)
            {
                // Fragment extends past the declared message length: ignore the rest.
                break;
            }

            // Finished must arrive in epoch 1 and every other handshake message in epoch 0.
            // Any other epoch never matches, so this assumes a single handshake per association.
            byte msgType = TlsUtilities.ReadUint8(buf, off);
            int expectedEpoch = (msgType == HandshakeTypeFinished) ? 1 : 0;
            if (epoch != expectedEpoch)
            {
                break;
            }

            int messageSeq = TlsUtilities.ReadUint16(buf, off + 4);
            if (messageSeq >= mNextReceiveSeq + windowSize)
            {
                // Beyond the receive window: ignored.
            }
            else if (messageSeq >= mNextReceiveSeq)
            {
                DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[messageSeq];
                if (reassembler == null)
                {
                    reassembler = new DtlsReassembler(msgType, length);
                    mCurrentInboundFlight[messageSeq] = reassembler;
                }

                reassembler.ContributeFragment(msgType, length, buf, off + MessageHeaderLength, fragmentOffset, fragmentLength);
            }
            else if (mPreviousInboundFlight != null)
            {
                // An already-consumed sequence number: track it against the previous flight
                // to detect a full retransmission by the peer.
                DtlsReassembler previous = (DtlsReassembler)mPreviousInboundFlight[messageSeq];
                if (previous != null)
                {
                    previous.ContributeFragment(msgType, length, buf, off + MessageHeaderLength, fragmentOffset, fragmentLength);
                    checkPreviousFlight = true;
                }
            }

            off += messageLength;
            len -= messageLength;
        }

        bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
        if (result)
        {
            ResendOutboundFlight();
            ResetAll(mPreviousInboundFlight);
        }

        return result;
    }

    /// <summary>Resets the record layer's write epoch and rewrites every outbound message.</summary>
    private void ResendOutboundFlight()
    {
        mRecordLayer.ResetWriteEpoch();
        for (int i = 0; i < mOutboundFlight.Count; ++i)
        {
            WriteMessage((Message)mOutboundFlight[i]);
        }
    }

    /// <summary>
    /// Feeds a message into the handshake hash, as if it had been sent as a single fragment
    /// (fragment_offset 0, fragment_length equal to the body length). HelloRequest (type 0)
    /// messages are skipped.
    /// </summary>
    /// <returns>The same <paramref name="message"/>, for call chaining.</returns>
    private Message UpdateHandshakeMessagesDigest(Message message)
    {
        if (message.Type != HandshakeTypeHelloRequest)
        {
            byte[] body = message.Body;
            byte[] header = new byte[MessageHeaderLength];
            TlsUtilities.WriteUint8(message.Type, header, 0);
            TlsUtilities.WriteUint24(body.Length, header, 1);
            TlsUtilities.WriteUint16(message.Seq, header, 4);
            TlsUtilities.WriteUint24(0, header, 6);
            TlsUtilities.WriteUint24(body.Length, header, 9);
            mHandshakeHash.BlockUpdate(header, 0, header.Length);
            mHandshakeHash.BlockUpdate(body, 0, body.Length);
        }

        return message;
    }

    /// <summary>
    /// Writes a message as one or more fragments, each sized to fit the record layer's send
    /// limit. At least one fragment is written even for an empty body.
    /// </summary>
    /// <exception cref="TlsFatalAlert">internal_error if the send limit cannot fit a header plus one byte.</exception>
    private void WriteMessage(Message message)
    {
        int sendLimit = mRecordLayer.GetSendLimit();
        int fragmentLimit = sendLimit - MessageHeaderLength;
        if (fragmentLimit < 1)
        {
            throw new TlsFatalAlert(AlertInternalError);
        }

        int length = message.Body.Length;

        // do/while so that an empty body still produces one (empty) fragment.
        int fragmentOffset = 0;
        do
        {
            int fragmentLength = System.Math.Min(length - fragmentOffset, fragmentLimit);
            WriteHandshakeFragment(message, fragmentOffset, fragmentLength);
            fragmentOffset += fragmentLength;
        }
        while (fragmentOffset < length);
    }

    /// <summary>Serializes one fragment (12-byte header plus payload slice) and sends it.</summary>
    private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
    {
        RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
        TlsUtilities.WriteUint8(message.Type, fragment);
        TlsUtilities.WriteUint24(message.Body.Length, fragment);
        TlsUtilities.WriteUint16(message.Seq, fragment);
        TlsUtilities.WriteUint24(fragment_offset, fragment);
        TlsUtilities.WriteUint24(fragment_length, fragment);
        fragment.Write(message.Body, fragment_offset, fragment_length);

        fragment.SendToRecordLayer(mRecordLayer);
    }

    /// <summary>True if every reassembler in the flight holds a complete body.</summary>
    private static bool CheckAll(IDictionary inboundFlight)
    {
        foreach (DtlsReassembler r in inboundFlight.Values)
        {
            if (r.GetBodyIfComplete() == null)
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>Calls <c>Reset</c> on every reassembler in the flight.</summary>
    private static void ResetAll(IDictionary inboundFlight)
    {
        foreach (DtlsReassembler r in inboundFlight.Values)
        {
            r.Reset();
        }
    }
}