diff --git a/go.mod b/go.mod index 52eb56f4..483c532b 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/miekg/dns v1.1.63 github.com/opencoff/go-sieve v0.2.1 github.com/powerman/check v1.8.0 - github.com/quic-go/quic-go v0.49.0 + github.com/quic-go/quic-go v0.50.0 golang.org/x/crypto v0.33.0 golang.org/x/net v0.35.0 golang.org/x/sys v0.30.0 diff --git a/go.sum b/go.sum index 27ba21e3..627bc5e0 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,8 @@ github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+ github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.49.0 h1:w5iJHXwHxs1QxyBv1EHKuC50GX5to8mJAxvtnttJp94= -github.com/quic-go/quic-go v0.49.0/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= +github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo= +github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= diff --git a/vendor/github.com/quic-go/quic-go/conn_id_manager.go b/vendor/github.com/quic-go/quic-go/conn_id_manager.go index 4030913d..a4fbd93c 100644 --- a/vendor/github.com/quic-go/quic-go/conn_id_manager.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_manager.go @@ -19,6 +19,9 @@ type newConnID struct { type connIDManager struct { queue list.List[newConnID] + highestProbingID uint64 + pathProbing map[pathID]newConnID // initialized lazily + handshakeComplete bool activeSequenceNumber uint64 highestRetired uint64 @@ -76,13 +79,23 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { } // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. - if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { + if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: f.SequenceNumber, }) return nil } + if f.RetirePriorTo != 0 && h.pathProbing != nil { + for id, entry := range h.pathProbing { + if entry.SequenceNumber < f.RetirePriorTo { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: entry.SequenceNumber, + }) + delete(h.pathProbing, id) + } + } + } // Retire elements in the queue. // Doesn't retire the active connection ID. if f.RetirePriorTo > h.highestRetired { @@ -225,6 +238,50 @@ func (h *connIDManager) SetHandshakeComplete() { h.handshakeComplete = true } +// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one). +// Once a connection ID is allocated for a path, it cannot be used for a different path. +// When called with the same pathID, it will return the same connection ID, +// unless the peer requested that this connection ID be retired. +func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) { + h.assertNotClosed() + // if we're using zero-length connection IDs, we don't need to change the connection ID + if h.activeConnectionID.Len() == 0 { + return protocol.ConnectionID{}, true + } + + if h.pathProbing == nil { + h.pathProbing = make(map[pathID]newConnID) + } + entry, ok := h.pathProbing[id] + if ok { + return entry.ConnectionID, true + } + if h.queue.Len() == 0 { + return protocol.ConnectionID{}, false + } + front := h.queue.Remove(h.queue.Front()) + h.pathProbing[id] = front + h.highestProbingID = front.SequenceNumber + return front.ConnectionID, true +} + +func (h *connIDManager) RetireConnIDForPath(pathID pathID) { + h.assertNotClosed() + // if we're using zero-length connection IDs, we don't need to change the connection ID + if h.activeConnectionID.Len() == 0 { + return + } + + entry, ok := h.pathProbing[pathID] + if !ok { + return + } + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: entry.SequenceNumber, + }) + delete(h.pathProbing, pathID) +} + // Using the connIDManager after it has been closed can have disastrous effects: // If the connection ID is rotated, a new entry would be inserted into the packet handler map, // leading to a memory leak of the connection struct. diff --git a/vendor/github.com/quic-go/quic-go/connection.go b/vendor/github.com/quic-go/quic-go/connection.go index 879faec0..9415584d 100644 --- a/vendor/github.com/quic-go/quic-go/connection.go +++ b/vendor/github.com/quic-go/quic-go/connection.go @@ -19,6 +19,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" ) @@ -94,7 +95,6 @@ type connRunner interface { type closeError struct { err error - remote bool immediate bool } @@ -128,6 +128,10 @@ type connection struct { conn sendConn sendQueue sender + // lazily initialzed: most connections never migrate + pathManager *pathManager + largestRcvdAppData protocol.PacketNumber + streamsMap streamManager connIDManager *connIDManager connIDGenerator *connIDGenerator @@ -148,19 +152,21 @@ type connection struct { packer packer mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received - maxPayloadSizeEstimate atomic.Uint32 + currentMTUEstimate atomic.Uint32 initialStream *cryptoStream handshakeStream *cryptoStream oneRTTStream *cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler - receivedPackets chan receivedPacket - sendingScheduled chan struct{} + notifyReceivedPacket chan struct{} + sendingScheduled chan struct{} + receivedPacketMx sync.Mutex + receivedPackets ringbuffer.RingBuffer[receivedPacket] - closeOnce sync.Once // closeChan is used to notify the run loop that it should terminate - closeChan chan closeError + closeChan chan struct{} + closeErr atomic.Pointer[closeError] ctx context.Context ctxCancel context.CancelCauseFunc @@ -280,7 +286,7 @@ var newConnection = func( s.tracer, s.logger, ) - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID) params := &wire.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -392,7 +398,7 @@ var newClientConnection = func( s.tracer, s.logger, ) - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) + s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize)))) oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -450,6 +456,7 @@ var newClientConnection = func( } func (s *connection) preSetup() { + s.largestRcvdAppData = protocol.InvalidPacketNumber s.initialStream = newCryptoStream() s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) @@ -479,8 +486,9 @@ func (s *connection) preSetup() { s.perspective, ) s.framer = newFramer(s.connFlowController) - s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets) - s.closeChan = make(chan closeError, 1) + s.receivedPackets.Init(8) + s.notifyReceivedPacket = make(chan struct{}, 1) + s.closeChan = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1) s.handshakeCompleteChan = make(chan struct{}) @@ -493,23 +501,18 @@ func (s *connection) preSetup() { } // run the connection main loop -func (s *connection) run() error { - var closeErr closeError - defer func() { s.ctxCancel(closeErr.err) }() +func (s *connection) run() (err error) { + defer func() { s.ctxCancel(err) }() defer func() { - // Drain queued packets that will never be processed. - for { - select { - case p, ok := <-s.receivedPackets: - if !ok { - return - } - p.buffer.Decrement() - p.buffer.MaybeRelease() - default: - return - } + // drain queued packets that will never be processed + s.receivedPacketMx.Lock() + defer s.receivedPacketMx.Unlock() + + for !s.receivedPackets.Empty() { + p := s.receivedPackets.PopFront() + p.buffer.Decrement() + p.buffer.MaybeRelease() } }() @@ -536,91 +539,88 @@ func (s *connection) run() error { runLoop: for { if s.framer.QueuedTooManyControlFrames() { - s.closeLocal(&qerr.TransportError{ErrorCode: InternalError}) + s.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}}) + break runLoop } // Close immediately if requested select { - case closeErr = <-s.closeChan: + case <-s.closeChan: break runLoop default: } - s.maybeResetTimer() + // no need to set a timer if we can send packets immediately + if s.pacingDeadline != deadlineSendImmediately { + s.maybeResetTimer() + } - var processedUndecryptablePacket bool + // 1st: handle undecryptable packets, if any. + // This can only occur before completion of the handshake. if len(s.undecryptablePacketsToProcess) > 0 { + var processedUndecryptablePacket bool queue := s.undecryptablePacketsToProcess s.undecryptablePacketsToProcess = nil for _, p := range queue { - if processed := s.handlePacketImpl(p); processed { + processed, err := s.handleOnePacket(p) + if err != nil { + s.setCloseError(&closeError{err: err}) + break runLoop + } + if processed { processedUndecryptablePacket = true } - // Don't set timers and send packets if the packet made us close the connection. - select { - case closeErr = <-s.closeChan: - break runLoop - default: - } + } + if processedUndecryptablePacket { + // if we processed any undecryptable packets, jump to the resetting of the timers directly + continue } } - // If we processed any undecryptable packets, jump to the resetting of the timers directly. - if !processedUndecryptablePacket { + + // 2nd: receive packets. + processed, err := s.handlePackets() // don't check receivedPackets.Len() in the run loop to avoid locking the mutex + if err != nil { + s.setCloseError(&closeError{err: err}) + break runLoop + } + + // We don't need to wait for new events if: + // * we processed packets: we probably need to send an ACK, and potentially more data + // * the pacer allows us to send more packets immediately + shouldProceedImmediately := sendQueueAvailable == nil && (processed || s.pacingDeadline == deadlineSendImmediately) + if !shouldProceedImmediately { + // 3rd: wait for something to happen: + // * closing of the connection + // * timer firing + // * sending scheduled + // * send queue available + // * received packets select { - case closeErr = <-s.closeChan: + case <-s.closeChan: break runLoop case <-s.timer.Chan(): s.timer.SetRead() - // We do all the interesting stuff after the switch statement, so - // nothing to see here. case <-s.sendingScheduled: - // We do all the interesting stuff after the switch statement, so - // nothing to see here. case <-sendQueueAvailable: - case firstPacket := <-s.receivedPackets: - wasProcessed := s.handlePacketImpl(firstPacket) - // Don't set timers and send packets if the packet made us close the connection. - select { - case closeErr = <-s.closeChan: + case <-s.notifyReceivedPacket: + wasProcessed, err := s.handlePackets() + if err != nil { + s.setCloseError(&closeError{err: err}) break runLoop - default: } - if s.handshakeComplete { - // Now process all packets in the receivedPackets channel. - // Limit the number of packets to the length of the receivedPackets channel, - // so we eventually get a chance to send out an ACK when receiving a lot of packets. - numPackets := len(s.receivedPackets) - receiveLoop: - for i := 0; i < numPackets; i++ { - select { - case p := <-s.receivedPackets: - if processed := s.handlePacketImpl(p); processed { - wasProcessed = true - } - select { - case closeErr = <-s.closeChan: - break runLoop - default: - } - default: - break receiveLoop - } - } - } - // Only reset the timers if this packet was actually processed. - // This avoids modifying any state when handling undecryptable packets, - // which could be injected by an attacker. + // if we processed any undecryptable packets, jump to the resetting of the timers directly if !wasProcessed { continue } } } + // Check for loss detection timeout. + // This could cause packets to be declared lost, and retransmissions to be enqueued. now := time.Now() if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { - // This could cause packets to be retransmitted. - // Check it before trying to send packets. if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil { - s.closeLocal(err) + s.setCloseError(&closeError{err: err}) + break runLoop } } @@ -631,35 +631,46 @@ runLoop: s.keepAlivePingSent = true } else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() { s.destroyImpl(qerr.ErrHandshakeTimeout) - continue + break runLoop } else { idleTimeoutStartTime := s.idleTimeoutStartTime() if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || (s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) { s.destroyImpl(qerr.ErrIdleTimeout) - continue + break runLoop } } if s.sendQueue.WouldBlock() { - // The send queue is still busy sending out packets. - // Wait until there's space to enqueue new packets. + // The send queue is still busy sending out packets. Wait until there's space to enqueue new packets. sendQueueAvailable = s.sendQueue.Available() + // Cancel the pacing timer, as we can't send any more packets until the send queue is available again. + s.pacingDeadline = time.Time{} continue } + + if s.closeErr.Load() != nil { + break runLoop + } + if err := s.triggerSending(now); err != nil { - s.closeLocal(err) + s.setCloseError(&closeError{err: err}) + break runLoop } if s.sendQueue.WouldBlock() { + // The send queue is still busy sending out packets. Wait until there's space to enqueue new packets. sendQueueAvailable = s.sendQueue.Available() + // Cancel the pacing timer, as we can't send any more packets until the send queue is available again. + s.pacingDeadline = time.Time{} } else { sendQueueAvailable = nil } } + closeErr := s.closeErr.Load() s.cryptoStreamHandler.Close() s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE - s.handleCloseError(&closeErr) + s.handleCloseError(closeErr) if s.tracer != nil && s.tracer.Close != nil { if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) { s.tracer.Close() @@ -802,17 +813,60 @@ func (s *connection) handleHandshakeConfirmed(now time.Time) error { return nil } -func (s *connection) handlePacketImpl(rp receivedPacket) bool { +func (s *connection) handlePackets() (wasProcessed bool, _ error) { + // Now process all packets in the receivedPackets channel. + // Limit the number of packets to the length of the receivedPackets channel, + // so we eventually get a chance to send out an ACK when receiving a lot of packets. + s.receivedPacketMx.Lock() + numPackets := s.receivedPackets.Len() + if numPackets == 0 { + s.receivedPacketMx.Unlock() + return false, nil + } + + var hasMorePackets bool + for i := 0; i < numPackets; i++ { + if i > 0 { + s.receivedPacketMx.Lock() + } + p := s.receivedPackets.PopFront() + hasMorePackets = !s.receivedPackets.Empty() + s.receivedPacketMx.Unlock() + + processed, err := s.handleOnePacket(p) + if err != nil { + return false, err + } + if processed { + wasProcessed = true + } + if !hasMorePackets { + break + } + // only process a single packet at a time before handshake completion + if !s.handshakeComplete { + break + } + } + if hasMorePackets { + select { + case s.notifyReceivedPacket <- struct{}{}: + default: + } + } + return wasProcessed, nil +} + +func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ error) { s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime) if wire.IsVersionNegotiationPacket(rp.data) { s.handleVersionNegotiationPacket(rp) - return false + return false, nil } var counter uint8 var lastConnID protocol.ConnectionID - var processed bool data := rp.data p := rp for len(data) > 0 { @@ -872,26 +926,34 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { p.data = packetData - if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed { - processed = true + processed, err := s.handleLongHeaderPacket(p, hdr) + if err != nil { + return false, err + } + if processed { + wasProcessed = true } data = rest } else { if counter > 0 { p.buffer.Split() } - if wasProcessed := s.handleShortHeaderPacket(p); wasProcessed { - processed = true + processed, err := s.handleShortHeaderPacket(p) + if err != nil { + return false, err + } + if processed { + wasProcessed = true } break } } p.buffer.MaybeRelease() - return processed + return wasProcessed, nil } -func (s *connection) handleShortHeaderPacket(p receivedPacket) bool { +func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed bool, _ error) { var wasQueued bool defer func() { @@ -904,13 +966,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool { destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError) - return false + return false, nil } pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) if err != nil { - wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT) - return false + wasQueued, err = s.handleUnpackError(err, p, logging.PacketType1RTT) + return false, err } + s.largestRcvdAppData = max(s.largestRcvdAppData, pn) if s.logger.Debug() { s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID) @@ -922,7 +985,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool { if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate) } - return false + return false, nil } var log func([]logging.Frame) @@ -941,14 +1004,58 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket) bool { ) } } - if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil { - s.closeLocal(err) - return false + isNonProbing, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log) + if err != nil { + return false, err } - return true + + // In RFC 9000, only the client can migrate between paths. + if s.perspective == protocol.PerspectiveClient { + return true, nil + } + + var shouldSwitchPath bool + if pn == s.largestRcvdAppData && !addrsEqual(p.remoteAddr, s.RemoteAddr()) { + if s.pathManager == nil { + s.pathManager = newPathManager( + s.connIDManager.GetConnIDForPath, + s.connIDManager.RetireConnIDForPath, + s.logger, + ) + } + var destConnID protocol.ConnectionID + var pathChallenge ackhandler.Frame + destConnID, pathChallenge, shouldSwitchPath = s.pathManager.HandlePacket(p, isNonProbing) + if pathChallenge.Frame != nil { + probe, buf, err := s.packer.PackPathProbePacket(destConnID, pathChallenge, s.version) + if err != nil { + return false, err + } + s.logger.Debugf("sending path probe packet to %s", p.remoteAddr) + s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false) + s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime) + s.sendQueue.SendProbe(buf, p.remoteAddr) + } + } + + if shouldSwitchPath { + s.pathManager.SwitchToPath(p.remoteAddr) + s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize)) + maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize) + if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize { + maxPacketSize = s.peerParams.MaxUDPPayloadSize + } + s.mtuDiscoverer.Reset( + p.rcvTime, + protocol.ByteCount(s.config.InitialPacketSize), + maxPacketSize, + ) + s.conn.ChangeRemoteAddr(p.remoteAddr, p.info) + } + return true, nil } -func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) (wasProcessed bool, _ error) { var wasQueued bool defer func() { @@ -959,7 +1066,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) }() if hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(hdr, p.data, p.rcvTime) + return s.handleRetryPacket(hdr, p.data, p.rcvTime), nil } // The server can change the source connection ID with the first Handshake packet. @@ -969,20 +1076,20 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) - return false + return false, nil } // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) } - return false + return false, nil } packet, err := s.unpacker.UnpackLongHeader(hdr, p.data) if err != nil { - wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr)) - return false + wasQueued, err = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr)) + return false, err } if s.logger.Debug() { @@ -995,39 +1102,40 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate) } - return false + return false, nil } if err := s.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { - s.closeLocal(err) - return false + return false, err } - return true + return true, nil } -func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { +func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool, _ error) { switch err { case handshake.ErrKeysDropped: if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) + return false, nil case handshake.ErrKeysNotYetAvailable: // Sealer for this encryption level not yet available. // Try again later. s.tryQueueingUndecryptablePacket(p, pt) - return true + return true, nil case wire.ErrInvalidReservedBits: - s.closeLocal(&qerr.TransportError{ + return false, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: err.Error(), - }) + } case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) + return false, nil default: var headerErr *headerParseError if errors.As(err, &headerErr) { @@ -1036,13 +1144,12 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) - } else { - // This is an error returned by the AEAD (other than ErrDecryptionFailed). - // For example, a PROTOCOL_VIOLATION due to key updates. - s.closeLocal(err) + return false, nil } + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. + return false, err } - return false } func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { @@ -1219,13 +1326,17 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false + if packet.hdr.Type == protocol.PacketType0RTT { + s.largestRcvdAppData = max(s.largestRcvdAppData, packet.hdr.PacketNumber) + } + var log func([]logging.Frame) if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } } - isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime) + isAckEliciting, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime) if err != nil { return err } @@ -1239,16 +1350,19 @@ func (s *connection) handleUnpackedShortHeaderPacket( ecn protocol.ECN, rcvTime time.Time, log func([]logging.Frame), -) error { +) (isNonProbing bool, _ error) { s.lastPacketReceivedTime = rcvTime s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false - isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime) + isAckEliciting, isNonProbing, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime) if err != nil { - return err + return false, err } - return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) + if err := s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting); err != nil { + return false, err + } + return isNonProbing, nil } func (s *connection) handleFrames( @@ -1257,7 +1371,7 @@ func (s *connection) handleFrames( encLevel protocol.EncryptionLevel, log func([]logging.Frame), rcvTime time.Time, -) (isAckEliciting bool, _ error) { +) (isAckEliciting, isNonProbing bool, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. var frames []logging.Frame @@ -1269,7 +1383,7 @@ func (s *connection) handleFrames( for len(data) > 0 { l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) if err != nil { - return false, err + return false, false, err } data = data[l:] if frame == nil { @@ -1278,6 +1392,9 @@ func (s *connection) handleFrames( if ackhandler.IsFrameAckEliciting(frame) { isAckEliciting = true } + if !wire.IsProbingFrame(frame) { + isNonProbing = true + } if log != nil { frames = append(frames, toLoggingFrame(frame)) } @@ -1288,7 +1405,7 @@ func (s *connection) handleFrames( } if err := s.handleFrame(frame, encLevel, destConnID, rcvTime); err != nil { if log == nil { - return false, err + return false, false, err } // If we're logging, we need to keep parsing (but not handling) all frames. handleErr = err @@ -1298,7 +1415,7 @@ func (s *connection) handleFrames( if log != nil { log(frames) if handleErr != nil { - return false, handleErr + return false, false, handleErr } } @@ -1308,10 +1425,9 @@ func (s *connection) handleFrames( // and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame. if !handshakeWasComplete && s.handshakeComplete { if err := s.handleHandshakeComplete(rcvTime); err != nil { - return false, err + return false, false, err } } - return } @@ -1331,7 +1447,7 @@ func (s *connection) handleFrame( case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel, rcvTime) case *wire.ConnectionCloseFrame: - s.handleConnectionCloseFrame(frame) + err = s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: err = s.handleResetStreamFrame(frame, rcvTime) case *wire.MaxDataFrame: @@ -1350,11 +1466,7 @@ func (s *connection) handleFrame( case *wire.PathChallengeFrame: s.handlePathChallengeFrame(frame) case *wire.PathResponseFrame: - // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs - err = &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "unexpected PATH_RESPONSE frame", - } + err = s.handlePathResponseFrame(frame) case *wire.NewTokenFrame: err = s.handleNewTokenFrame(frame) case *wire.NewConnectionIDFrame: @@ -1373,32 +1485,39 @@ func (s *connection) handleFrame( // handlePacket is called by the server with a new packet func (s *connection) handlePacket(p receivedPacket) { + s.receivedPacketMx.Lock() // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxConnUnprocessedPackets - select { - case s.receivedPackets <- p: - default: + if s.receivedPackets.Len() >= protocol.MaxConnUnprocessedPackets { if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } + s.receivedPacketMx.Unlock() + return + } + s.receivedPackets.PushBack(p) + s.receivedPacketMx.Unlock() + + select { + case s.notifyReceivedPacket <- struct{}{}: + default: } } -func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { +func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error { if frame.IsApplicationError { - s.closeRemote(&qerr.ApplicationError{ + return &qerr.ApplicationError{ Remote: true, ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode), ErrorMessage: frame.ReasonPhrase, - }) - return + } } - s.closeRemote(&qerr.TransportError{ + return &qerr.TransportError{ Remote: true, ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), FrameType: frame.FrameType, ErrorMessage: frame.ReasonPhrase, - }) + } } func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error { @@ -1434,8 +1553,8 @@ func (s *connection) handleHandshakeEvents(now time.Time) error { s.restoreTransportParameters(ev.TransportParameters) close(s.earlyConnReadyChan) case handshake.EventReceivedReadKeys: - // Queue all packets for decryption that have been undecryptable so far. - s.undecryptablePacketsToProcess = s.undecryptablePackets + // queue all previously undecryptable packets + s.undecryptablePacketsToProcess = append(s.undecryptablePacketsToProcess, s.undecryptablePackets...) s.undecryptablePackets = nil case handshake.EventDiscard0RTTKeys: err = s.dropEncryptionLevel(protocol.Encryption0RTT, now) @@ -1514,8 +1633,21 @@ func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error return nil } -func (s *connection) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { - s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) +func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) { + s.queueControlFrame(&wire.PathResponseFrame{Data: f.Data}) +} + +func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error { + s.logger.Debugf("received PATH_RESPONSE frame: %v", f.Data) + if s.pathManager == nil { + // since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "unexpected PATH_RESPONSE frame", + } + } + s.pathManager.HandlePathResponseFrame(f) + return nil } func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { @@ -1568,6 +1700,13 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr return err } } + // If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate. + if s.mtuDiscoverer != nil { + if mtu := s.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(s.currentMTUEstimate.Load()) { + s.currentMTUEstimate.Store(uint32(mtu)) + s.sentPacketHandler.SetMaxDatagramSize(mtu) + } + } return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } @@ -1582,16 +1721,17 @@ func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { return nil } +func (s *connection) setCloseError(e *closeError) { + s.closeErr.CompareAndSwap(nil, e) + select { + case s.closeChan <- struct{}{}: + default: + } +} + // closeLocal closes the connection and send a CONNECTION_CLOSE containing the error func (s *connection) closeLocal(e error) { - s.closeOnce.Do(func() { - if e == nil { - s.logger.Infof("Closing connection.") - } else { - s.logger.Errorf("Closing connection with error: %s", e) - } - s.closeChan <- closeError{err: e, immediate: false, remote: false} - }) + s.setCloseError(&closeError{err: e, immediate: false}) } // destroy closes the connection without sending the error on the wire @@ -1601,21 +1741,7 @@ func (s *connection) destroy(e error) { } func (s *connection) destroyImpl(e error) { - s.closeOnce.Do(func() { - if nerr, ok := e.(net.Error); ok && nerr.Timeout() { - s.logger.Errorf("Destroying connection: %s", e) - } else { - s.logger.Errorf("Destroying connection with error: %s", e) - } - s.closeChan <- closeError{err: e, immediate: true, remote: false} - }) -} - -func (s *connection) closeRemote(e error) { - s.closeOnce.Do(func() { - s.logger.Errorf("Peer closed connection with error: %s", e) - s.closeChan <- closeError{err: e, immediate: true, remote: true} - }) + s.setCloseError(&closeError{err: e, immediate: true}) } func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { @@ -1633,13 +1759,25 @@ func (s *connection) closeWithTransportError(code TransportErrorCode) { } func (s *connection) handleCloseError(closeErr *closeError) { + if closeErr.immediate { + if nerr, ok := closeErr.err.(net.Error); ok && nerr.Timeout() { + s.logger.Errorf("Destroying connection: %s", closeErr.err) + } else { + s.logger.Errorf("Destroying connection with error: %s", closeErr.err) + } + } else { + if closeErr.err == nil { + s.logger.Infof("Closing connection.") + } else { + s.logger.Errorf("Closing connection with error: %s", closeErr.err) + } + } + e := closeErr.err if e == nil { e = &qerr.ApplicationError{} } else { - defer func() { - closeErr.err = e - }() + defer func() { closeErr.err = e }() } var ( @@ -1649,14 +1787,17 @@ func (s *connection) handleCloseError(closeErr *closeError) { applicationErr *ApplicationError transportErr *TransportError ) + var isRemoteClose bool switch { case errors.Is(e, qerr.ErrIdleTimeout), errors.Is(e, qerr.ErrHandshakeTimeout), errors.As(e, &statelessResetErr), errors.As(e, &versionNegotiationErr), - errors.As(e, &recreateErr), - errors.As(e, &applicationErr), - errors.As(e, &transportErr): + errors.As(e, &recreateErr): + case errors.As(e, &applicationErr): + isRemoteClose = applicationErr.Remote + case errors.As(e, &transportErr): + isRemoteClose = transportErr.Remote case closeErr.immediate: e = closeErr.err default: @@ -1682,7 +1823,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { } // If this is a remote close we're done here - if closeErr.remote { + if isRemoteClose { s.connIDGenerator.ReplaceWithClosed(nil) return } @@ -1831,7 +1972,6 @@ func (s *connection) applyTransportParameters() { s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), maxPacketSize, - s.onMTUIncreased, s.tracer, ) } @@ -1952,7 +2092,10 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return nil } // Prioritize receiving of packets over sending out more packets. - if len(s.receivedPackets) > 0 { + s.receivedPacketMx.Lock() + hasPackets := !s.receivedPackets.Empty() + s.receivedPacketMx.Unlock() + if hasPackets { s.pacingDeadline = deadlineSendImmediately return nil } @@ -2010,7 +2153,10 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { } // Prioritize receiving of packets over sending out more packets. - if len(s.receivedPackets) > 0 { + s.receivedPacketMx.Lock() + hasPackets := !s.receivedPackets.Empty() + s.receivedPacketMx.Unlock() + if hasPackets { s.pacingDeadline = deadlineSendImmediately return nil } @@ -2076,7 +2222,7 @@ func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) + packet, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2087,7 +2233,7 @@ func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time if packet == nil { s.retransmissionQueue.AddPing(encLevel) var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), now, s.version) + packet, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version) if err != nil { return err } @@ -2113,6 +2259,21 @@ func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize proto } func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) { + if p.IsPathProbePacket { + s.sentPacketHandler.SentPacket( + now, + p.PacketNumber, + protocol.InvalidPacketNumber, + p.StreamFrames, + p.Frames, + protocol.Encryption1RTT, + ecn, + p.Length, + p.IsPathMTUProbePacket, + true, + ) + return + } if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -2121,7 +2282,18 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn pr if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket( + now, + p.PacketNumber, + largestAcked, + p.StreamFrames, + p.Frames, + protocol.Encryption1RTT, + ecn, + p.Length, + p.IsPathMTUProbePacket, + false, + ) s.connIDManager.SentPacket() } @@ -2135,7 +2307,18 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot if p.ack != nil { largestAcked = p.ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) + s.sentPacketHandler.SentPacket( + now, + p.header.PacketNumber, + largestAcked, + p.streamFrames, + p.frames, + p.EncryptionLevel(), + ecn, + p.length, + false, + false, + ) if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake && !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. @@ -2153,7 +2336,18 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket( + now, + p.PacketNumber, + largestAcked, + p.StreamFrames, + p.Frames, + protocol.Encryption1RTT, + ecn, + p.Length, + p.IsPathMTUProbePacket, + false, + ) } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer, 0, ecn) @@ -2299,11 +2493,6 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { s.framer.RemoveActiveStream(id) } -func (s *connection) onMTUIncreased(mtu protocol.ByteCount) { - s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu))) - s.sentPacketHandler.SetMaxDatagramSize(mtu) -} - func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") @@ -2314,7 +2503,7 @@ func (s *connection) SendDatagram(p []byte) error { // Under many circumstances we could send a few more bytes. maxDataLen := min( f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version), - protocol.ByteCount(s.maxPayloadSizeEstimate.Load()), + protocol.ByteCount(s.currentMTUEstimate.Load()), ) if protocol.ByteCount(len(p)) > maxDataLen { return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} diff --git a/vendor/github.com/quic-go/quic-go/http3/server.go b/vendor/github.com/quic-go/quic-go/http3/server.go index 097a8005..1479609c 100644 --- a/vendor/github.com/quic-go/quic-go/http3/server.go +++ b/vendor/github.com/quic-go/quic-go/http3/server.go @@ -684,7 +684,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat if logger == nil { logger = slog.Default() } - logger.Error("http: panic serving", "arg", p, "trace", string(buf)) + logger.Error("http3: panic serving", "arg", p, "trace", string(buf)) } }() handler.ServeHTTP(r, req) @@ -694,18 +694,6 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat return } - // only write response when there is no panic - if !panicked { - // response not written to the client yet, set Content-Length - if !r.headerWritten { - if _, haveCL := r.header["Content-Length"]; !haveCL { - r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) - } - } - r.Flush() - r.flushTrailers() - } - // abort the stream when there is a panic if panicked { str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) @@ -713,9 +701,17 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat return } + // response not written to the client yet, set Content-Length + if !r.headerWritten { + if _, haveCL := r.header["Content-Length"]; !haveCL { + r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) + } + } + r.Flush() + r.flushTrailers() + // If the EOF was read by the handler, CancelRead() is a no-op. str.CancelRead(quic.StreamErrorCode(ErrCodeNoError)) - str.Close() } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go index acf95426..5fcce44d 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go @@ -10,7 +10,7 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) @@ -34,6 +34,8 @@ type SentPacketHandler interface { GetLossDetectionTimeout() time.Time OnLossDetectionTimeout(now time.Time) error + + MigratedPath(now time.Time, initialMaxPacketSize protocol.ByteCount) } type sentPacketTracker interface { diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go index 5f43689b..c634939a 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go @@ -22,10 +22,11 @@ type packet struct { includedInBytesInFlight bool declaredLost bool skippedPacket bool + isPathProbePacket bool } func (p *packet) outstanding() bool { - return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket + return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket && !p.isPathProbePacket } var packetPool = sync.Pool{New: func() any { return &packet{} }} diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go index 5276fe19..7c3cf892 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go @@ -27,6 +27,9 @@ const ( maxPTODuration = 60 * time.Second ) +// Path probe packets are declared lost after this time. +const pathProbePacketLossTimeout = time.Second + type packetNumberSpace struct { history sentPacketHistory pns packetNumberGenerator @@ -174,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t if pnSpace == nil { return } - pnSpace.history.Iterate(func(p *packet) (bool, error) { + for p := range pnSpace.history.Packets() { h.removeFromBytesInFlight(p) - return true, nil - }) + } } // drop the packet history //nolint:exhaustive // Not every packet number space can be dropped. @@ -194,14 +196,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t // and not when the client drops 0-RTT keys when the handshake completes. // When 0-RTT is rejected, all application data sent so far becomes invalid. // Delete the packets from the history and remove them from bytes_in_flight. - h.appDataPackets.history.Iterate(func(p *packet) (bool, error) { + for p := range h.appDataPackets.history.Packets() { if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket { - return false, nil + break } h.removeFromBytesInFlight(p) h.appDataPackets.history.Remove(p.PacketNumber) - return true, nil - }) + } default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } @@ -249,11 +250,12 @@ func (h *sentPacketHandler) SentPacket( ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, + isPathProbePacket bool, ) { h.bytesSent += size pnSpace := h.getPacketNumberSpace(encLevel) - if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { + if h.logger.Debug() && (pnSpace.history.HasOutstandingPackets() || pnSpace.history.HasOutstandingPathProbes()) { for p := max(0, pnSpace.largestSent+1); p < pn; p++ { h.logger.Debugf("Skipping packet number %d", p) } @@ -262,6 +264,18 @@ func (h *sentPacketHandler) SentPacket( pnSpace.largestSent = pn isAckEliciting := len(streamFrames) > 0 || len(frames) > 0 + if isPathProbePacket { + p := getPacket() + p.SendTime = t + p.PacketNumber = pn + p.EncryptionLevel = encLevel + p.Length = size + p.Frames = frames + p.isPathProbePacket = true + pnSpace.history.SentPathProbePacket(p) + h.setLossDetectionTimer(t) + return + } if isAckEliciting { pnSpace.lastAckElicitingPacketTime = t h.bytesInFlight += size @@ -341,7 +355,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } // update the RTT, if the largest acked is newly acknowledged if len(ackedPackets) > 0 { - if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() { + if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() && !p.isPathProbePacket { // don't use the ack delay for Initial and Handshake packets var ackDelay time.Duration if encLevel == protocol.Encryption1RTT { @@ -365,8 +379,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked) - if err := h.detectLostPackets(rcvTime, encLevel); err != nil { - return false, err + h.detectLostPackets(rcvTime, encLevel) + if encLevel == protocol.Encryption1RTT { + h.detectLostPathProbes(rcvTime) } var acked1RTTPacket bool for _, p := range ackedPackets { @@ -377,7 +392,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En acked1RTTPacket = true } h.removeFromBytesInFlight(p) - putPacket(p) + if !p.isPathProbePacket { + putPacket(p) + } } // After this point, we must not use ackedPackets any longer! // We've already returned the buffers. @@ -411,14 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL ackRangeIndex := 0 lowestAcked := ack.LowestAcked() largestAcked := ack.LargestAcked() - err := pnSpace.history.Iterate(func(p *packet) (bool, error) { - // Ignore packets below the lowest acked + for p := range pnSpace.history.Packets() { + // ignore packets below the lowest acked if p.PacketNumber < lowestAcked { - return true, nil + continue } - // Break after largest acked is reached if p.PacketNumber > largestAcked { - return false, nil + break } if ack.HasMissingRanges() { @@ -430,21 +446,28 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range - return true, nil + continue } if p.PacketNumber > ackRange.Largest { - return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) + return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) } } if p.skippedPacket { - return false, &qerr.TransportError{ + return nil, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), } } + if p.isPathProbePacket { + probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber) + if probePacket == nil { + panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber)) + } + h.ackedPackets = append(h.ackedPackets, probePacket) + continue + } h.ackedPackets = append(h.ackedPackets, p) - return true, nil - }) + } if h.logger.Debug() && len(h.ackedPackets) > 0 { pns := make([]protocol.PacketNumber, len(h.ackedPackets)) for i, p := range h.ackedPackets { @@ -475,8 +498,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) } } - - return h.ackedPackets, err + return h.ackedPackets, nil } func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { @@ -507,7 +529,7 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration } // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime -func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { +func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel) { // We only send application data probe packets once the handshake is confirmed, // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { @@ -516,32 +538,35 @@ func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, en } t := now.Add(h.getScaledPTO(false)) if h.initialPackets != nil { - return t, protocol.EncryptionInitial, true + return t, protocol.EncryptionInitial } - return t, protocol.EncryptionHandshake, true + return t, protocol.EncryptionHandshake } - if h.initialPackets != nil { + if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() && + !h.initialPackets.lastAckElicitingPacketTime.IsZero() { encLevel = protocol.EncryptionInitial if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { pto = t.Add(h.getScaledPTO(false)) } } - if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { + if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() && + !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.EncryptionHandshake } } - if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { + if h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets() && + !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.Encryption1RTT } } - return pto, encLevel, true + return pto, encLevel } func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { @@ -573,8 +598,8 @@ func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) { func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer { // cancel the alarm if no packets are outstanding - if h.peerCompletedAddressValidation && - !h.hasOutstandingCryptoPackets() && !h.appDataPackets.history.HasOutstandingPackets() { + if h.peerCompletedAddressValidation && !h.hasOutstandingCryptoPackets() && + !h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.history.HasOutstandingPathProbes() { return alarmTimer{} } @@ -583,28 +608,62 @@ func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer { return alarmTimer{} } + var pathProbeLossTime time.Time + if h.appDataPackets.history.HasOutstandingPathProbes() { + if p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil { + pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout) + } + } + // early retransmit timer or time loss detection lossTime, encLevel := h.getLossTimeAndSpace() - if !lossTime.IsZero() { + if !lossTime.IsZero() && (pathProbeLossTime.IsZero() || lossTime.Before(pathProbeLossTime)) { return alarmTimer{ Time: lossTime, TimerType: logging.TimerTypeACK, EncryptionLevel: encLevel, } } - - ptoTime, encLevel, ok := h.getPTOTimeAndSpace(now) - if !ok { - return alarmTimer{} + ptoTime, encLevel := h.getPTOTimeAndSpace(now) + if !ptoTime.IsZero() && (pathProbeLossTime.IsZero() || ptoTime.Before(pathProbeLossTime)) { + return alarmTimer{ + Time: ptoTime, + TimerType: logging.TimerTypePTO, + EncryptionLevel: encLevel, + } } - return alarmTimer{ - Time: ptoTime, - TimerType: logging.TimerTypePTO, - EncryptionLevel: encLevel, + if !pathProbeLossTime.IsZero() { + return alarmTimer{ + Time: pathProbeLossTime, + TimerType: logging.TimerTypePathProbe, + EncryptionLevel: encLevel, + } + } + return alarmTimer{} +} + +func (h *sentPacketHandler) detectLostPathProbes(now time.Time) { + if !h.appDataPackets.history.HasOutstandingPathProbes() { + return + } + lossTime := now.Add(-pathProbePacketLossTimeout) + // RemovePathProbe cannot be called while iterating. + var lostPathProbes []*packet + for p := range h.appDataPackets.history.PathProbes() { + if !p.SendTime.After(lossTime) { + lostPathProbes = append(lostPathProbes, p) + } + } + for _, p := range lostPathProbes { + for _, f := range p.Frames { + f.Handler.OnLost(f.Frame) + } + h.appDataPackets.history.Remove(p.PacketNumber) + h.appDataPackets.history.RemovePathProbe(p.PacketNumber) } } -func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) { pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -618,15 +677,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E lostSendTime := now.Add(-lossDelay) priorInFlight := h.bytesInFlight - return pnSpace.history.Iterate(func(p *packet) (bool, error) { + for p := range pnSpace.history.Packets() { if p.PacketNumber > pnSpace.largestAcked { - return false, nil + break } + isRegularPacket := !p.skippedPacket && !p.isPathProbePacket var packetLost bool if !p.SendTime.After(lostSendTime) { packetLost = true - if !p.skippedPacket { + if isRegularPacket { if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } @@ -636,7 +696,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { packetLost = true - if !p.skippedPacket { + if isRegularPacket { if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) } @@ -654,7 +714,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } if packetLost { pnSpace.history.DeclareLost(p.PacketNumber) - if !p.skippedPacket { + if isRegularPacket { // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) @@ -666,12 +726,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } } } - return true, nil - }) + } } func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { defer h.setLossDetectionTimer(now) + + if h.handshakeConfirmed { + h.detectLostPathProbes(now) + } + earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -681,7 +745,8 @@ func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - return h.detectLostPackets(now, encLevel) + h.detectLostPackets(now, encLevel) + return nil } // PTO @@ -702,11 +767,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error { return nil } - _, encLevel, ok := h.getPTOTimeAndSpace(now) - if !ok { + ptoTime, encLevel := h.getPTOTimeAndSpace(now) + if ptoTime.IsZero() { return nil } - if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation { + ps := h.getPacketNumberSpace(encLevel) + if !ps.history.HasOutstandingPackets() && !ps.history.HasOutstandingPathProbes() && !h.peerCompletedAddressValidation { return nil } h.ptoCount++ @@ -868,24 +934,21 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { func (h *sentPacketHandler) ResetForRetry(now time.Time) { h.bytesInFlight = 0 var firstPacketSendTime time.Time - h.initialPackets.history.Iterate(func(p *packet) (bool, error) { + for p := range h.initialPackets.history.Packets() { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } - if p.declaredLost || p.skippedPacket { - return true, nil - } - h.queueFramesForRetransmission(p) - return true, nil - }) - // All application data packets sent at this point are 0-RTT packets. - // In the case of a Retry, we can assume that the server dropped all of them. - h.appDataPackets.history.Iterate(func(p *packet) (bool, error) { if !p.declaredLost && !p.skippedPacket { h.queueFramesForRetransmission(p) } - return true, nil - }) + } + // All application data packets sent at this point are 0-RTT packets. + // In the case of a Retry, we can assume that the server dropped all of them. + for p := range h.appDataPackets.history.Packets() { + if !p.declaredLost && !p.skippedPacket { + h.queueFramesForRetransmission(p) + } + } // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. // Otherwise, we don't know which Initial the Retry was sent in response to. @@ -913,3 +976,25 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) { } h.ptoCount = 0 } + +func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) { + h.rttStats.ResetForPathMigration() + for p := range h.appDataPackets.history.Packets() { + h.appDataPackets.history.DeclareLost(p.PacketNumber) + if !p.skippedPacket && !p.isPathProbePacket { + h.removeFromBytesInFlight(p) + h.queueFramesForRetransmission(p) + } + } + for p := range h.appDataPackets.history.PathProbes() { + h.appDataPackets.history.RemovePathProbe(p.PacketNumber) + } + h.congestion = congestion.NewCubicSender( + congestion.DefaultClock{}, + h.rttStats, + initialMaxDatagramSize, + true, // use Reno + h.tracer, + ) + h.setLossDetectionTimer(now) +} diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go index 9968df6a..0aabc6d9 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go @@ -2,12 +2,14 @@ package ackhandler import ( "fmt" + "iter" "github.com/quic-go/quic-go/internal/protocol" ) type sentPacketHistory struct { - packets []*packet + packets []*packet + pathProbePackets []*packet numOutstanding int @@ -32,11 +34,11 @@ func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNum panic("non-sequential packet number use") } } + h.highestPacketNumber = pn } func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) { h.checkSequentialPacketNumberUse(pn) - h.highestPacketNumber = pn h.packets = append(h.packets, &packet{ PacketNumber: pn, skippedPacket: true, @@ -45,7 +47,6 @@ func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) { func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) { h.checkSequentialPacketNumberUse(pn) - h.highestPacketNumber = pn if len(h.packets) > 0 { h.packets = append(h.packets, nil) } @@ -53,28 +54,42 @@ func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) { h.checkSequentialPacketNumberUse(p.PacketNumber) - h.highestPacketNumber = p.PacketNumber h.packets = append(h.packets, p) if p.outstanding() { h.numOutstanding++ } } -// Iterate iterates through all packets. -func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error { - for _, p := range h.packets { - if p == nil { - continue - } - cont, err := cb(p) - if err != nil { - return err - } - if !cont { - return nil +func (h *sentPacketHistory) SentPathProbePacket(p *packet) { + h.checkSequentialPacketNumberUse(p.PacketNumber) + h.packets = append(h.packets, &packet{ + PacketNumber: p.PacketNumber, + isPathProbePacket: true, + }) + h.pathProbePackets = append(h.pathProbePackets, p) +} + +func (h *sentPacketHistory) Packets() iter.Seq[*packet] { + return func(yield func(*packet) bool) { + for _, p := range h.packets { + if p == nil { + continue + } + if !yield(p) { + return + } + } + } +} + +func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] { + return func(yield func(*packet) bool) { + for _, p := range h.pathProbePackets { + if !yield(p) { + return + } } } - return nil } // FirstOutstanding returns the first outstanding packet. @@ -90,6 +105,14 @@ func (h *sentPacketHistory) FirstOutstanding() *packet { return nil } +// FirstOutstandingPathProbe returns the first outstanding path probe packet +func (h *sentPacketHistory) FirstOutstandingPathProbe() *packet { + if len(h.pathProbePackets) == 0 { + return nil + } + return h.pathProbePackets[0] +} + func (h *sentPacketHistory) Len() int { return len(h.packets) } @@ -125,6 +148,27 @@ func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error { return nil } +// RemovePathProbe removes a path probe packet. +// It scales O(N), but that's ok, since we don't expect to send many path probe packets. +// It is not valid to call this function in IteratePathProbes. +func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet { + var packetToDelete *packet + idx := -1 + for i, p := range h.pathProbePackets { + if p.PacketNumber == pn { + packetToDelete = p + idx = i + break + } + } + if idx != -1 { + // don't use slices.Delete, because it zeros the deleted element + copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:]) + h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1] + } + return packetToDelete +} + // getIndex gets the index of packet p in the packets slice. func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) { if len(h.packets) == 0 { @@ -145,6 +189,10 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool { return h.numOutstanding > 0 } +func (h *sentPacketHistory) HasOutstandingPathProbes() bool { + return len(h.pathProbePackets) > 0 +} + // delete all nil entries at the beginning of the packets slice func (h *sentPacketHistory) cleanupStart() { for i, p := range h.packets { diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go index c8e6cb33..1a86c675 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go @@ -12,7 +12,6 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" @@ -89,12 +88,13 @@ func NewCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &tls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf cs.allow0RTT = enable0RTT - cs.conn = tls.QUICClient(quicConf) + cs.conn = tls.QUICClient(&tls.QUICConfig{ + TLSConfig: tlsConf, + EnableSessionEvents: true, + }) cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) return cs @@ -123,9 +123,13 @@ func NewCryptoSetupServer( ) cs.allow0RTT = allow0RTT - tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket) + tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr) + cs.tlsConf = tlsConf - cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf}) + cs.conn = tls.QUICServer(&tls.QUICConfig{ + TLSConfig: tlsConf, + EnableSessionEvents: true, + }) return cs } @@ -178,11 +182,10 @@ func (h *cryptoSetup) StartHandshake(ctx context.Context) error { } for { ev := h.conn.NextEvent() - done, err := h.handleEvent(ev) - if err != nil { + if err := h.handleEvent(ev); err != nil { return wrapError(err) } - if done { + if ev.Kind == tls.QUICNoEvent { break } } @@ -213,53 +216,78 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev } func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { - if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil { + if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil { return err } for { ev := h.conn.NextEvent() - done, err := h.handleEvent(ev) - if err != nil { + if err := h.handleEvent(ev); err != nil { return err } - if done { + if ev.Kind == tls.QUICNoEvent { return nil } } } -func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { - //nolint:exhaustive - // Go 1.23 added new 0-RTT events, see https://github.com/quic-go/quic-go/issues/4272. - // We will start using these events when dropping support for Go 1.22. +func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) { switch ev.Kind { case tls.QUICNoEvent: - return true, nil + return nil case tls.QUICSetReadSecret: h.setReadKey(ev.Level, ev.Suite, ev.Data) - return false, nil + return nil case tls.QUICSetWriteSecret: h.setWriteKey(ev.Level, ev.Suite, ev.Data) - return false, nil + return nil case tls.QUICTransportParameters: - return false, h.handleTransportParameters(ev.Data) + return h.handleTransportParameters(ev.Data) case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) - return false, nil + return nil case tls.QUICRejectedEarlyData: h.rejected0RTT() - return false, nil + return nil case tls.QUICWriteData: h.writeRecord(ev.Level, ev.Data) - return false, nil + return nil case tls.QUICHandshakeDone: h.handshakeComplete() - return false, nil + return nil + case tls.QUICStoreSession: + if h.perspective == protocol.PerspectiveServer { + panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server") + } + ev.SessionState.Extra = append( + ev.SessionState.Extra, + addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)), + ) + return h.conn.StoreSession(ev.SessionState) + case tls.QUICResumeSession: + var allowEarlyData bool + switch h.perspective { + case protocol.PerspectiveClient: + // for clients, this event occurs when a session ticket is selected + allowEarlyData = h.handleDataFromSessionState( + findSessionStateExtraData(ev.SessionState.Extra), + ev.SessionState.EarlyData, + ) + case protocol.PerspectiveServer: + // for servers, this event occurs when receiving the client's session ticket + allowEarlyData = h.handleSessionTicket( + findSessionStateExtraData(ev.SessionState.Extra), + ev.SessionState.EarlyData, + ) + } + if ev.SessionState.EarlyData { + ev.SessionState.EarlyData = allowEarlyData + } + return nil default: // Unknown events should be ignored. // crypto/tls will ensure that this is safe to do. // See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details. - return false, nil + return nil } } @@ -350,7 +378,10 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte { // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: h.allow0RTT, + Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())}, + }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. @@ -376,9 +407,9 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { // It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT. // Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT: // A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT. -func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool { +func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) { var t sessionTicket - if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil { + if err := t.Unmarshal(data, using0RTT); err != nil { h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error()) return false } @@ -446,7 +477,7 @@ func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra } h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { - h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) + h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } @@ -497,7 +528,7 @@ func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr panic("unexpected write encryption level") } if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { - h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) + h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective) } } diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/conn.go b/vendor/github.com/quic-go/quic-go/internal/handshake/fake_conn.go similarity index 97% rename from vendor/github.com/quic-go/quic-go/internal/qtls/conn.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/fake_conn.go index 6660ac66..54af823b 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/conn.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/fake_conn.go @@ -1,4 +1,4 @@ -package qtls +package handshake import ( "net" diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go b/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go index 0caf1c8e..3da97cd8 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go @@ -8,8 +8,6 @@ import ( ) // hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1. -// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the -// hkdfExpandLabel in the standard library. func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { b := make([]byte, 3, 3+6+len(label)+1+len(context)) binary.BigEndian.PutUint16(b, uint16(length)) diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go index b67f0101..4da517fc 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" "errors" "fmt" "time" @@ -52,3 +53,20 @@ func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error { t.RTT = time.Duration(rtt) * time.Microsecond return nil } + +const extraPrefix = "quic-go1" + +func addSessionStateExtraPrefix(b []byte) []byte { + return append([]byte(extraPrefix), b...) +} + +func findSessionStateExtraData(extras [][]byte) []byte { + prefix := []byte(extraPrefix) + for _, extra := range extras { + if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) { + continue + } + return extra[len(prefix):] + } + return nil +} diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/tls_config.go b/vendor/github.com/quic-go/quic-go/internal/handshake/tls_config.go new file mode 100644 index 00000000..c4c0d22d --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/tls_config.go @@ -0,0 +1,39 @@ +package handshake + +import ( + "crypto/tls" + "net" +) + +func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config { + // Workaround for https://github.com/golang/go/issues/60506. + // This initializes the session tickets _before_ cloning the config. + _, _ = conf.DecryptTicket(nil, tls.ConnectionState{}) + + conf = conf.Clone() + conf.MinVersion = tls.VersionTLS13 + + // The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. + // Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn + // that allows the caller to get the local and the remote address. + if conf.GetConfigForClient != nil { + gcfc := conf.GetConfigForClient + conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + c, err := gcfc(info) + if c != nil { + // we're returning a tls.Config here, so we need to apply this recursively + c = setupConfigForServer(c, localAddr, remoteAddr) + } + return c, err + } + } + if conf.GetCertificate != nil { + gc := conf.GetCertificate + conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} + return gc(info) + } + } + return conf +} diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go b/vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go index 32d38ab1..40aa331a 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go @@ -1,5 +1,10 @@ package protocol +import ( + "crypto/tls" + "fmt" +) + // EncryptionLevel is the encryption level // Default value is Unencrypted type EncryptionLevel uint8 @@ -28,3 +33,33 @@ func (e EncryptionLevel) String() string { } return "unknown" } + +func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel { + switch e { + case EncryptionInitial: + return tls.QUICEncryptionLevelInitial + case EncryptionHandshake: + return tls.QUICEncryptionLevelHandshake + case Encryption1RTT: + return tls.QUICEncryptionLevelApplication + case Encryption0RTT: + return tls.QUICEncryptionLevelEarly + default: + panic(fmt.Sprintf("unexpected encryption level: %s", e)) + } +} + +func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel { + switch e { + case tls.QUICEncryptionLevelInitial: + return EncryptionInitial + case tls.QUICEncryptionLevelHandshake: + return EncryptionHandshake + case tls.QUICEncryptionLevelApplication: + return Encryption1RTT + case tls.QUICEncryptionLevelEarly: + return Encryption0RTT + default: + panic(fmt.Sprintf("unexpect encryption level: %s", e)) + } +} diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go b/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go deleted file mode 100644 index 32a921cd..00000000 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go +++ /dev/null @@ -1,52 +0,0 @@ -package qtls - -import ( - "crypto/tls" - "fmt" - "unsafe" -) - -//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 -var cipherSuitesTLS13 []unsafe.Pointer - -//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13 -var defaultCipherSuitesTLS13 []uint16 - -//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES -var defaultCipherSuitesTLS13NoAES []uint16 - -var cipherSuitesModified bool - -// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls -// such that it only contains the cipher suite with the chosen id. -// The reset function returned resets them back to the original value. -func SetCipherSuite(id uint16) (reset func()) { - if cipherSuitesModified { - panic("cipher suites modified multiple times without resetting") - } - cipherSuitesModified = true - - origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) - origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) - origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) - // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. - switch id { - case tls.TLS_AES_128_GCM_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[:1] - case tls.TLS_CHACHA20_POLY1305_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[1:2] - case tls.TLS_AES_256_GCM_SHA384: - cipherSuitesTLS13 = cipherSuitesTLS13[2:] - default: - panic(fmt.Sprintf("unexpected cipher suite: %d", id)) - } - defaultCipherSuitesTLS13 = []uint16{id} - defaultCipherSuitesTLS13NoAES = []uint16{id} - - return func() { - cipherSuitesTLS13 = origCipherSuitesTLS13 - defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 - defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES - cipherSuitesModified = false - } -} diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go b/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go deleted file mode 100644 index 4acac9e2..00000000 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go +++ /dev/null @@ -1,70 +0,0 @@ -package qtls - -import ( - "crypto/tls" - "sync" -) - -type clientSessionCache struct { - mx sync.Mutex - getData func(earlyData bool) []byte - setData func(data []byte, earlyData bool) (allowEarlyData bool) - wrapped tls.ClientSessionCache -} - -var _ tls.ClientSessionCache = &clientSessionCache{} - -func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) { - c.mx.Lock() - defer c.mx.Unlock() - - if cs == nil { - c.wrapped.Put(key, nil) - return - } - ticket, state, err := cs.ResumptionState() - if err != nil || state == nil { - c.wrapped.Put(key, cs) - return - } - state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData))) - newCS, err := tls.NewResumptionState(ticket, state) - if err != nil { - // It's not clear why this would error. Just save the original state. - c.wrapped.Put(key, cs) - return - } - c.wrapped.Put(key, newCS) -} - -func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { - c.mx.Lock() - defer c.mx.Unlock() - - cs, ok := c.wrapped.Get(key) - if !ok || cs == nil { - return cs, ok - } - ticket, state, err := cs.ResumptionState() - if err != nil { - // It's not clear why this would error. - // Remove the ticket from the session cache, so we don't run into this error over and over again - c.wrapped.Put(key, nil) - return nil, false - } - // restore QUIC transport parameters and RTT stored in state.Extra - if extra := findExtraData(state.Extra); extra != nil { - earlyData := c.setData(extra, state.EarlyData) - if state.EarlyData { - state.EarlyData = earlyData - } - } - session, err := tls.NewResumptionState(ticket, state) - if err != nil { - // It's not clear why this would error. - // Remove the ticket from the session cache, so we don't run into this error over and over again - c.wrapped.Put(key, nil) - return nil, false - } - return session, true -} diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go b/vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go deleted file mode 100644 index cdfe82a2..00000000 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go +++ /dev/null @@ -1,150 +0,0 @@ -package qtls - -import ( - "bytes" - "crypto/tls" - "fmt" - "net" - - "github.com/quic-go/quic-go/internal/protocol" -) - -func SetupConfigForServer( - conf *tls.Config, - localAddr, remoteAddr net.Addr, - getData func() []byte, - handleSessionTicket func([]byte, bool) bool, -) *tls.Config { - // Workaround for https://github.com/golang/go/issues/60506. - // This initializes the session tickets _before_ cloning the config. - _, _ = conf.DecryptTicket(nil, tls.ConnectionState{}) - - conf = conf.Clone() - conf.MinVersion = tls.VersionTLS13 - - // add callbacks to save transport parameters into the session ticket - origWrapSession := conf.WrapSession - conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { - // Add QUIC session ticket - state.Extra = append(state.Extra, addExtraPrefix(getData())) - - if origWrapSession != nil { - return origWrapSession(cs, state) - } - b, err := conf.EncryptTicket(cs, state) - return b, err - } - origUnwrapSession := conf.UnwrapSession - // UnwrapSession might be called multiple times, as the client can use multiple session tickets. - // However, using 0-RTT is only possible with the first session ticket. - // crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello. - var unwrapCount int - conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) { - unwrapCount++ - var state *tls.SessionState - var err error - if origUnwrapSession != nil { - state, err = origUnwrapSession(identity, connState) - } else { - state, err = conf.DecryptTicket(identity, connState) - } - if err != nil || state == nil { - return nil, err - } - - extra := findExtraData(state.Extra) - if extra != nil { - state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1) - } else { - state.EarlyData = false - } - - return state, nil - } - // The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo. - // Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn - // that allows the caller to get the local and the remote address. - if conf.GetConfigForClient != nil { - gcfc := conf.GetConfigForClient - conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - c, err := gcfc(info) - if c != nil { - // We're returning a tls.Config here, so we need to apply this recursively. - c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket) - } - return c, err - } - } - if conf.GetCertificate != nil { - gc := conf.GetCertificate - conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} - return gc(info) - } - } - return conf -} - -func SetupConfigForClient( - qconf *tls.QUICConfig, - getData func(earlyData bool) []byte, - setData func(data []byte, earlyData bool) (allowEarlyData bool), -) { - conf := qconf.TLSConfig - if conf.ClientSessionCache != nil { - origCache := conf.ClientSessionCache - conf.ClientSessionCache = &clientSessionCache{ - wrapped: origCache, - getData: getData, - setData: setData, - } - } -} - -func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel { - switch e { - case protocol.EncryptionInitial: - return tls.QUICEncryptionLevelInitial - case protocol.EncryptionHandshake: - return tls.QUICEncryptionLevelHandshake - case protocol.Encryption1RTT: - return tls.QUICEncryptionLevelApplication - case protocol.Encryption0RTT: - return tls.QUICEncryptionLevelEarly - default: - panic(fmt.Sprintf("unexpected encryption level: %s", e)) - } -} - -func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel { - switch e { - case tls.QUICEncryptionLevelInitial: - return protocol.EncryptionInitial - case tls.QUICEncryptionLevelHandshake: - return protocol.EncryptionHandshake - case tls.QUICEncryptionLevelApplication: - return protocol.Encryption1RTT - case tls.QUICEncryptionLevelEarly: - return protocol.Encryption0RTT - default: - panic(fmt.Sprintf("unexpect encryption level: %s", e)) - } -} - -const extraPrefix = "quic-go1" - -func addExtraPrefix(b []byte) []byte { - return append([]byte(extraPrefix), b...) -} - -func findExtraData(extras [][]byte) []byte { - prefix := []byte(extraPrefix) - for _, extra := range extras { - if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) { - continue - } - return extra[len(prefix):] - } - return nil -} diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index 92fec2e2..0efd8354 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -108,3 +108,12 @@ func (r *RTTStats) SetInitialRTT(t time.Duration) { r.smoothedRTT = t r.latestRTT = t } + +func (r *RTTStats) ResetForPathMigration() { + r.hasMeasurement = false + r.minRTT = 0 + r.latestRTT = 0 + r.smoothedRTT = 0 + r.meanDeviation = 0 + // max_ack_delay remains valid +} diff --git a/vendor/github.com/quic-go/quic-go/logging/types.go b/vendor/github.com/quic-go/quic-go/logging/types.go index 0d79b0a9..65da3559 100644 --- a/vendor/github.com/quic-go/quic-go/logging/types.go +++ b/vendor/github.com/quic-go/quic-go/logging/types.go @@ -63,9 +63,11 @@ type TimerType uint8 const ( // TimerTypeACK is the timer type for the early retransmit timer - TimerTypeACK TimerType = iota + TimerTypeACK TimerType = iota + 1 // TimerTypePTO is the timer type for the PTO retransmit timer TimerTypePTO + // TimerTypePathProbe is the timer type for the path probe retransmit timer + TimerTypePathProbe ) // TimeoutReason is the reason why a connection is closed diff --git a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go index ee636a6d..096eba14 100644 --- a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go +++ b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go @@ -17,6 +17,7 @@ type mtuDiscoverer interface { ShouldSendProbe(now time.Time) bool CurrentSize() protocol.ByteCount GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount) + Reset(now time.Time, start, max protocol.ByteCount) } const ( @@ -88,7 +89,6 @@ const ( type mtuFinder struct { lastProbeTime time.Time - mtuIncreased func(protocol.ByteCount) rttStats *utils.RTTStats @@ -99,6 +99,11 @@ type mtuFinder struct { lost [maxLostMTUProbes]protocol.ByteCount lastProbeWasLost bool + // The generation is used to ignore ACKs / losses for probe packets sent before a reset. + // Resets happen when the connection is migrated to a new path. + // We're therefore not concerned about overflows of this counter. + generation uint8 + tracer *logging.ConnectionTracer } @@ -107,16 +112,19 @@ var _ mtuDiscoverer = &mtuFinder{} func newMTUDiscoverer( rttStats *utils.RTTStats, start, max protocol.ByteCount, - mtuIncreased func(protocol.ByteCount), tracer *logging.ConnectionTracer, ) *mtuFinder { f := &mtuFinder{ - inFlight: protocol.InvalidByteCount, - min: start, - rttStats: rttStats, - mtuIncreased: mtuIncreased, - tracer: tracer, + inFlight: protocol.InvalidByteCount, + rttStats: rttStats, + tracer: tracer, } + f.init(start, max) + return f +} + +func (f *mtuFinder) init(start, max protocol.ByteCount) { + f.min = start for i := range f.lost { if i == 0 { f.lost[i] = max @@ -124,7 +132,6 @@ func newMTUDiscoverer( } f.lost[i] = protocol.InvalidByteCount } - return f } func (f *mtuFinder) done() bool { @@ -165,7 +172,7 @@ func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount f.inFlight = size return ackhandler.Frame{ Frame: &wire.PingFrame{}, - Handler: &mtuFinderAckHandler{f}, + Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation}, }, size } @@ -173,13 +180,26 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount { return f.min } +func (f *mtuFinder) Reset(now time.Time, start, max protocol.ByteCount) { + f.generation++ + f.lastProbeTime = now + f.lastProbeWasLost = false + f.inFlight = protocol.InvalidByteCount + f.init(start, max) +} + type mtuFinderAckHandler struct { *mtuFinder + generation uint8 } var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { + if h.generation != h.mtuFinder.generation { + // ACK for probe sent before reset + return + } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnAcked callback called although there's no MTU probe packet in flight") @@ -207,10 +227,13 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { if h.tracer != nil && h.tracer.UpdatedMTU != nil { h.tracer.UpdatedMTU(size, h.done()) } - h.mtuIncreased(size) } func (h *mtuFinderAckHandler) OnLost(wire.Frame) { + if h.generation != h.mtuFinder.generation { + // probe sent before reset received + return + } size := h.inFlight if size == protocol.InvalidByteCount { panic("OnLost callback called although there's no MTU probe packet in flight") diff --git a/vendor/github.com/quic-go/quic-go/packet_packer.go b/vendor/github.com/quic-go/quic-go/packet_packer.go index 7724b503..720f1958 100644 --- a/vendor/github.com/quic-go/quic-go/packet_packer.go +++ b/vendor/github.com/quic-go/quic-go/packet_packer.go @@ -22,9 +22,10 @@ type packer interface { PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) + MaybePackPTOProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackPathProbePacket(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) SetToken([]byte) @@ -57,6 +58,7 @@ type shortHeaderPacket struct { Ack *wire.AckFrame Length protocol.ByteCount IsPathMTUProbePacket bool + IsPathProbePacket bool // used for logging DestConnID protocol.ConnectionID @@ -269,17 +271,17 @@ func (p *packetPacker) packConnectionClose( if sealers[i] == nil { continue } - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) - } if encLevel == protocol.Encryption1RTT { - shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v) if err != nil { return nil, err } packet.shortHdrPacket = &shp } else { + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) + } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) if err != nil { return nil, err @@ -707,7 +709,7 @@ func (p *packetPacker) composeNextPacket( return pl } -func (p *packetPacker) MaybePackProbePacket( +func (p *packetPacker) MaybePackPTOProbePacket( encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, now time.Time, @@ -792,6 +794,26 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B return packet, buffer, err } +func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, f ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + buf := getPacketBuffer() + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return shortHeaderPacket{}, nil, err + } + payload := payload{ + frames: []ackhandler.Frame{f}, + length: f.Frame.Length(v), + } + padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) + packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v) + if err != nil { + return shortHeaderPacket{}, nil, err + } + packet.IsPathProbePacket = true + return packet, buf, err +} + func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ diff --git a/vendor/github.com/quic-go/quic-go/path_manager.go b/vendor/github.com/quic-go/quic-go/path_manager.go new file mode 100644 index 00000000..6d940921 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/path_manager.go @@ -0,0 +1,145 @@ +package quic + +import ( + "crypto/rand" + "net" + + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" +) + +type pathID int64 + +const maxPaths = 3 + +type path struct { + addr net.Addr + pathChallenge [8]byte + validated bool + rcvdNonProbing bool +} + +type pathManager struct { + nextPathID pathID + paths map[pathID]*path + + getConnID func(pathID) (_ protocol.ConnectionID, ok bool) + retireConnID func(pathID) + + logger utils.Logger +} + +func newPathManager( + getConnID func(pathID) (_ protocol.ConnectionID, ok bool), + retireConnID func(pathID), + logger utils.Logger, +) *pathManager { + return &pathManager{ + paths: make(map[pathID]*path), + getConnID: getConnID, + retireConnID: retireConnID, + logger: logger, + } +} + +// Returns a path challenge frame if one should be sent. +// May return nil. +func (pm *pathManager) HandlePacket(p receivedPacket, isNonProbing bool) (_ protocol.ConnectionID, _ ackhandler.Frame, shouldSwitch bool) { + for _, path := range pm.paths { + if addrsEqual(path.addr, p.remoteAddr) { + // already sent a PATH_CHALLENGE for this path + if isNonProbing { + path.rcvdNonProbing = true + } + if pm.logger.Debug() { + pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", p.remoteAddr, path.validated) + } + return protocol.ConnectionID{}, ackhandler.Frame{}, path.validated && path.rcvdNonProbing + } + } + + if len(pm.paths) >= maxPaths { + if pm.logger.Debug() { + pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", p.remoteAddr, len(pm.paths)) + } + return protocol.ConnectionID{}, ackhandler.Frame{}, false + } + + // previously unseen path, initiate path validation by sending a PATH_CHALLENGE + connID, ok := pm.getConnID(pm.nextPathID) + if !ok { + pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", p.remoteAddr) + return protocol.ConnectionID{}, ackhandler.Frame{}, false + } + var b [8]byte + rand.Read(b[:]) + pm.paths[pm.nextPathID] = &path{ + addr: p.remoteAddr, + pathChallenge: b, + rcvdNonProbing: isNonProbing, + } + pm.nextPathID++ + frame := ackhandler.Frame{ + Frame: &wire.PathChallengeFrame{Data: b}, + Handler: (*pathManagerAckHandler)(pm), + } + pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", p.remoteAddr) + return connID, frame, false +} + +func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) { + for _, p := range pm.paths { + if f.Data == p.pathChallenge { + // path validated + p.validated = true + pm.logger.Debugf("path %s validated", p.addr) + break + } + } +} + +// SwitchToPath is called when the connection switches to a new path +func (pm *pathManager) SwitchToPath(addr net.Addr) { + // retire all other paths + for id := range pm.paths { + if addrsEqual(pm.paths[id].addr, addr) { + pm.logger.Debugf("switching to path %d (%s)", id, addr) + continue + } + pm.retireConnID(id) + } + clear(pm.paths) +} + +type pathManagerAckHandler pathManager + +var _ ackhandler.FrameHandler = &pathManagerAckHandler{} + +// Acknowledging the frame doesn't validate the path, only receiving the PATH_RESPONSE does. +func (pm *pathManagerAckHandler) OnAcked(f wire.Frame) {} + +func (pm *pathManagerAckHandler) OnLost(f wire.Frame) { + // TODO: retransmit the packet the first time it is lost + pc := f.(*wire.PathChallengeFrame) + for id, path := range pm.paths { + if path.pathChallenge == pc.Data { + delete(pm.paths, id) + pm.retireConnID(id) + break + } + } +} + +func addrsEqual(addr1, addr2 net.Addr) bool { + if addr1 == nil || addr2 == nil { + return false + } + a1, ok1 := addr1.(*net.UDPAddr) + a2, ok2 := addr2.(*net.UDPAddr) + if ok1 && ok2 { + return a1.IP.Equal(a2.IP) && a1.Port == a2.Port + } + return addr1.String() == addr2.String() +} diff --git a/vendor/github.com/quic-go/quic-go/send_conn.go b/vendor/github.com/quic-go/quic-go/send_conn.go index 498ed112..402520c6 100644 --- a/vendor/github.com/quic-go/quic-go/send_conn.go +++ b/vendor/github.com/quic-go/quic-go/send_conn.go @@ -2,6 +2,7 @@ package quic import ( "net" + "sync/atomic" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" @@ -10,22 +11,29 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { Write(b []byte, gsoSize uint16, ecn protocol.ECN) error + WriteTo([]byte, net.Addr) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr + ChangeRemoteAddr(addr net.Addr, info packetInfo) capabilities() connCapabilities } +type remoteAddrInfo struct { + addr net.Addr + oob []byte +} + type sconn struct { rawConn - localAddr net.Addr - remoteAddr net.Addr + localAddr net.Addr + + remoteAddrInfo atomic.Pointer[remoteAddrInfo] logger utils.Logger - packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool // Used to catch the error sometimes returned by the first sendmsg call on Linux, @@ -49,22 +57,26 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] - return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - packetInfoOOB: oob, - logger: logger, + sc := &sconn{ + rawConn: c, + localAddr: localAddr, + logger: logger, } + sc.remoteAddrInfo.Store(&remoteAddrInfo{ + addr: remote, + oob: oob, + }) + return sc } func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { - err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) + ai := c.remoteAddrInfo.Load() + err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true if c.logger.Debug() { - c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + c.logger.Debugf("GSO failed when sending to %s", ai.addr) } // send out the packets one by one for len(p) > 0 { @@ -72,7 +84,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { if l > int(gsoSize) { l = int(gsoSize) } - if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { + if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil { return err } p = p[l:] @@ -91,6 +103,11 @@ func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16, return err } +func (c *sconn) WriteTo(b []byte, addr net.Addr) error { + _, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) + return err +} + func (c *sconn) capabilities() connCapabilities { capabilities := c.rawConn.capabilities() if capabilities.GSO { @@ -99,5 +116,12 @@ func (c *sconn) capabilities() connCapabilities { return capabilities } -func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) { + c.remoteAddrInfo.Store(&remoteAddrInfo{ + addr: addr, + oob: info.OOB(), + }) +} + +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddrInfo.Load().addr } func (c *sconn) LocalAddr() net.Addr { return c.localAddr } diff --git a/vendor/github.com/quic-go/quic-go/send_queue.go b/vendor/github.com/quic-go/quic-go/send_queue.go index bde02334..d19762be 100644 --- a/vendor/github.com/quic-go/quic-go/send_queue.go +++ b/vendor/github.com/quic-go/quic-go/send_queue.go @@ -1,9 +1,14 @@ package quic -import "github.com/quic-go/quic-go/internal/protocol" +import ( + "net" + + "github.com/quic-go/quic-go/internal/protocol" +) type sender interface { Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) + SendProbe(*packetBuffer, net.Addr) Run() error WouldBlock() bool Available() <-chan struct{} @@ -57,6 +62,10 @@ func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { } } +func (h *sendQueue) SendProbe(p *packetBuffer, addr net.Addr) { + h.conn.WriteTo(p.Data, addr) +} + func (h *sendQueue) WouldBlock() bool { return len(h.queue) == sendQueueCapacity } diff --git a/vendor/modules.txt b/vendor/modules.txt index 1337f186..db8a4148 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -107,8 +107,8 @@ github.com/powerman/deepequal # github.com/quic-go/qpack v0.5.1 ## explicit; go 1.22 github.com/quic-go/qpack -# github.com/quic-go/quic-go v0.49.0 -## explicit; go 1.22 +# github.com/quic-go/quic-go v0.50.0 +## explicit; go 1.23 github.com/quic-go/quic-go github.com/quic-go/quic-go/http3 github.com/quic-go/quic-go/internal/ackhandler @@ -117,7 +117,6 @@ github.com/quic-go/quic-go/internal/flowcontrol github.com/quic-go/quic-go/internal/handshake github.com/quic-go/quic-go/internal/protocol github.com/quic-go/quic-go/internal/qerr -github.com/quic-go/quic-go/internal/qtls github.com/quic-go/quic-go/internal/utils github.com/quic-go/quic-go/internal/utils/linkedlist github.com/quic-go/quic-go/internal/utils/ringbuffer