Update quic-go

This commit is contained in:
Frank Denis 2025-02-21 18:11:29 +01:00
parent f49196c6e8
commit 41bc703873
26 changed files with 1097 additions and 637 deletions

2
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/miekg/dns v1.1.63
github.com/opencoff/go-sieve v0.2.1
github.com/powerman/check v1.8.0
github.com/quic-go/quic-go v0.49.0
github.com/quic-go/quic-go v0.50.0
golang.org/x/crypto v0.33.0
golang.org/x/net v0.35.0
golang.org/x/sys v0.30.0

4
go.sum
View file

@ -75,8 +75,8 @@ github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+
github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.49.0 h1:w5iJHXwHxs1QxyBv1EHKuC50GX5to8mJAxvtnttJp94=
github.com/quic-go/quic-go v0.49.0/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo=
github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=

View file

@ -19,6 +19,9 @@ type newConnID struct {
type connIDManager struct {
queue list.List[newConnID]
highestProbingID uint64
pathProbing map[pathID]newConnID // initialized lazily
handshakeComplete bool
activeSequenceNumber uint64
highestRetired uint64
@ -76,13 +79,23 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
}
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: f.SequenceNumber,
})
return nil
}
if f.RetirePriorTo != 0 && h.pathProbing != nil {
for id, entry := range h.pathProbing {
if entry.SequenceNumber < f.RetirePriorTo {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
delete(h.pathProbing, id)
}
}
}
// Retire elements in the queue.
// Doesn't retire the active connection ID.
if f.RetirePriorTo > h.highestRetired {
@ -225,6 +238,50 @@ func (h *connIDManager) SetHandshakeComplete() {
h.handshakeComplete = true
}
// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one).
// Once a connection ID is allocated for a path, it cannot be used for a different path.
// When called with the same pathID, it will return the same connection ID,
// unless the peer requested that this connection ID be retired.
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return protocol.ConnectionID{}, true
}
if h.pathProbing == nil {
h.pathProbing = make(map[pathID]newConnID)
}
entry, ok := h.pathProbing[id]
if ok {
return entry.ConnectionID, true
}
if h.queue.Len() == 0 {
return protocol.ConnectionID{}, false
}
front := h.queue.Remove(h.queue.Front())
h.pathProbing[id] = front
h.highestProbingID = front.SequenceNumber
return front.ConnectionID, true
}
func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return
}
entry, ok := h.pathProbing[pathID]
if !ok {
return
}
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
delete(h.pathProbing, pathID)
}
// Using the connIDManager after it has been closed can have disastrous effects:
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
// leading to a memory leak of the connection struct.

File diff suppressed because it is too large Load diff

View file

@ -684,7 +684,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
if logger == nil {
logger = slog.Default()
}
logger.Error("http: panic serving", "arg", p, "trace", string(buf))
logger.Error("http3: panic serving", "arg", p, "trace", string(buf))
}
}()
handler.ServeHTTP(r, req)
@ -694,18 +694,6 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
return
}
// only write response when there is no panic
if !panicked {
// response not written to the client yet, set Content-Length
if !r.headerWritten {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
}
r.Flush()
r.flushTrailers()
}
// abort the stream when there is a panic
if panicked {
str.CancelRead(quic.StreamErrorCode(ErrCodeInternalError))
@ -713,9 +701,17 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
return
}
// response not written to the client yet, set Content-Length
if !r.headerWritten {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
}
r.Flush()
r.flushTrailers()
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
str.Close()
}

View file

@ -10,7 +10,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket, isPathProbePacket bool)
// ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
@ -34,6 +34,8 @@ type SentPacketHandler interface {
GetLossDetectionTimeout() time.Time
OnLossDetectionTimeout(now time.Time) error
MigratedPath(now time.Time, initialMaxPacketSize protocol.ByteCount)
}
type sentPacketTracker interface {

View file

@ -22,10 +22,11 @@ type packet struct {
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
isPathProbePacket bool
}
func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket && !p.isPathProbePacket
}
var packetPool = sync.Pool{New: func() any { return &packet{} }}

View file

@ -27,6 +27,9 @@ const (
maxPTODuration = 60 * time.Second
)
// Path probe packets are declared lost after this time.
const pathProbePacketLossTimeout = time.Second
type packetNumberSpace struct {
history sentPacketHistory
pns packetNumberGenerator
@ -174,10 +177,9 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) (bool, error) {
for p := range pnSpace.history.Packets() {
h.removeFromBytesInFlight(p)
return true, nil
})
}
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
@ -194,14 +196,13 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now t
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
for p := range h.appDataPackets.history.Packets() {
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false, nil
break
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
return true, nil
})
}
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
@ -249,11 +250,12 @@ func (h *sentPacketHandler) SentPacket(
ecn protocol.ECN,
size protocol.ByteCount,
isPathMTUProbePacket bool,
isPathProbePacket bool,
) {
h.bytesSent += size
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
if h.logger.Debug() && (pnSpace.history.HasOutstandingPackets() || pnSpace.history.HasOutstandingPathProbes()) {
for p := max(0, pnSpace.largestSent+1); p < pn; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
@ -262,6 +264,18 @@ func (h *sentPacketHandler) SentPacket(
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isPathProbePacket {
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.Frames = frames
p.isPathProbePacket = true
pnSpace.history.SentPathProbePacket(p)
h.setLossDetectionTimer(t)
return
}
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
@ -341,7 +355,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
}
// update the RTT, if the largest acked is newly acknowledged
if len(ackedPackets) > 0 {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() && !p.isPathProbePacket {
// don't use the ack delay for Initial and Handshake packets
var ackDelay time.Duration
if encLevel == protocol.Encryption1RTT {
@ -365,8 +379,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked)
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err
h.detectLostPackets(rcvTime, encLevel)
if encLevel == protocol.Encryption1RTT {
h.detectLostPathProbes(rcvTime)
}
var acked1RTTPacket bool
for _, p := range ackedPackets {
@ -377,7 +392,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
acked1RTTPacket = true
}
h.removeFromBytesInFlight(p)
putPacket(p)
if !p.isPathProbePacket {
putPacket(p)
}
}
// After this point, we must not use ackedPackets any longer!
// We've already returned the buffers.
@ -411,14 +428,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
// Ignore packets below the lowest acked
for p := range pnSpace.history.Packets() {
// ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
continue
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
break
}
if ack.HasMissingRanges() {
@ -430,21 +446,28 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
return true, nil
continue
}
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
return false, &qerr.TransportError{
return nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
}
if p.isPathProbePacket {
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
if probePacket == nil {
panic(fmt.Sprintf("path probe doesn't exist: %d", p.PacketNumber))
}
h.ackedPackets = append(h.ackedPackets, probePacket)
continue
}
h.ackedPackets = append(h.ackedPackets, p)
return true, nil
})
}
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
@ -475,8 +498,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
}
}
return h.ackedPackets, err
return h.ackedPackets, nil
}
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
@ -507,7 +529,7 @@ func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration
}
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel) {
// We only send application data probe packets once the handshake is confirmed,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
@ -516,32 +538,35 @@ func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, en
}
t := now.Add(h.getScaledPTO(false))
if h.initialPackets != nil {
return t, protocol.EncryptionInitial, true
return t, protocol.EncryptionInitial
}
return t, protocol.EncryptionHandshake, true
return t, protocol.EncryptionHandshake
}
if h.initialPackets != nil {
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() &&
!h.initialPackets.lastAckElicitingPacketTime.IsZero() {
encLevel = protocol.EncryptionInitial
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
pto = t.Add(h.getScaledPTO(false))
}
}
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() &&
!h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.EncryptionHandshake
}
}
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
if h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets() &&
!h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.Encryption1RTT
}
}
return pto, encLevel, true
return pto, encLevel
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
@ -573,8 +598,8 @@ func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) {
func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
// cancel the alarm if no packets are outstanding
if h.peerCompletedAddressValidation &&
!h.hasOutstandingCryptoPackets() && !h.appDataPackets.history.HasOutstandingPackets() {
if h.peerCompletedAddressValidation && !h.hasOutstandingCryptoPackets() &&
!h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.history.HasOutstandingPathProbes() {
return alarmTimer{}
}
@ -583,28 +608,62 @@ func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
return alarmTimer{}
}
var pathProbeLossTime time.Time
if h.appDataPackets.history.HasOutstandingPathProbes() {
if p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil {
pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout)
}
}
// early retransmit timer or time loss detection
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() {
if !lossTime.IsZero() && (pathProbeLossTime.IsZero() || lossTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: lossTime,
TimerType: logging.TimerTypeACK,
EncryptionLevel: encLevel,
}
}
ptoTime, encLevel, ok := h.getPTOTimeAndSpace(now)
if !ok {
return alarmTimer{}
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if !ptoTime.IsZero() && (pathProbeLossTime.IsZero() || ptoTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: ptoTime,
TimerType: logging.TimerTypePTO,
EncryptionLevel: encLevel,
}
}
return alarmTimer{
Time: ptoTime,
TimerType: logging.TimerTypePTO,
EncryptionLevel: encLevel,
if !pathProbeLossTime.IsZero() {
return alarmTimer{
Time: pathProbeLossTime,
TimerType: logging.TimerTypePathProbe,
EncryptionLevel: encLevel,
}
}
return alarmTimer{}
}
func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
if !h.appDataPackets.history.HasOutstandingPathProbes() {
return
}
lossTime := now.Add(-pathProbePacketLossTimeout)
// RemovePathProbe cannot be called while iterating.
var lostPathProbes []*packet
for p := range h.appDataPackets.history.PathProbes() {
if !p.SendTime.After(lossTime) {
lostPathProbes = append(lostPathProbes, p)
}
}
for _, p := range lostPathProbes {
for _, f := range p.Frames {
f.Handler.OnLost(f.Frame)
}
h.appDataPackets.history.Remove(p.PacketNumber)
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
@ -618,15 +677,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *packet) (bool, error) {
for p := range pnSpace.history.Packets() {
if p.PacketNumber > pnSpace.largestAcked {
return false, nil
break
}
isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
var packetLost bool
if !p.SendTime.After(lostSendTime) {
packetLost = true
if !p.skippedPacket {
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
}
@ -636,7 +696,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true
if !p.skippedPacket {
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
}
@ -654,7 +714,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
if packetLost {
pnSpace.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket {
if isRegularPacket {
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
@ -666,12 +726,16 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
}
}
}
return true, nil
})
}
}
func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
defer h.setLossDetectionTimer(now)
if h.handshakeConfirmed {
h.detectLostPathProbes(now)
}
earliestLossTime, encLevel := h.getLossTimeAndSpace()
if !earliestLossTime.IsZero() {
if h.logger.Debug() {
@ -681,7 +745,8 @@ func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
}
// Early retransmit or time loss detection
return h.detectLostPackets(now, encLevel)
h.detectLostPackets(now, encLevel)
return nil
}
// PTO
@ -702,11 +767,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
return nil
}
_, encLevel, ok := h.getPTOTimeAndSpace(now)
if !ok {
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if ptoTime.IsZero() {
return nil
}
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
ps := h.getPacketNumberSpace(encLevel)
if !ps.history.HasOutstandingPackets() && !ps.history.HasOutstandingPathProbes() && !h.peerCompletedAddressValidation {
return nil
}
h.ptoCount++
@ -868,24 +934,21 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
func (h *sentPacketHandler) ResetForRetry(now time.Time) {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
for p := range h.initialPackets.history.Packets() {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost || p.skippedPacket {
return true, nil
}
h.queueFramesForRetransmission(p)
return true, nil
})
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true, nil
})
}
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
for p := range h.appDataPackets.history.Packets() {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
}
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
@ -913,3 +976,25 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) {
}
h.ptoCount = 0
}
func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
h.rttStats.ResetForPathMigration()
for p := range h.appDataPackets.history.Packets() {
h.appDataPackets.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket && !p.isPathProbePacket {
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
}
}
for p := range h.appDataPackets.history.PathProbes() {
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
h.congestion = congestion.NewCubicSender(
congestion.DefaultClock{},
h.rttStats,
initialMaxDatagramSize,
true, // use Reno
h.tracer,
)
h.setLossDetectionTimer(now)
}

View file

@ -2,12 +2,14 @@ package ackhandler
import (
"fmt"
"iter"
"github.com/quic-go/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packets []*packet
packets []*packet
pathProbePackets []*packet
numOutstanding int
@ -32,11 +34,11 @@ func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNum
panic("non-sequential packet number use")
}
}
h.highestPacketNumber = pn
}
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
@ -45,7 +47,6 @@ func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
@ -53,28 +54,42 @@ func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber)
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.highestPacketNumber = p.PacketNumber
h.packets = append(h.packets, p)
if p.outstanding() {
h.numOutstanding++
}
}
// Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
for _, p := range h.packets {
if p == nil {
continue
}
cont, err := cb(p)
if err != nil {
return err
}
if !cont {
return nil
func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.packets = append(h.packets, &packet{
PacketNumber: p.PacketNumber,
isPathProbePacket: true,
})
h.pathProbePackets = append(h.pathProbePackets, p)
}
func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.packets {
if p == nil {
continue
}
if !yield(p) {
return
}
}
}
}
func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.pathProbePackets {
if !yield(p) {
return
}
}
}
return nil
}
// FirstOutstanding returns the first outstanding packet.
@ -90,6 +105,14 @@ func (h *sentPacketHistory) FirstOutstanding() *packet {
return nil
}
// FirstOutstandingPathProbe returns the first outstanding path probe packet
func (h *sentPacketHistory) FirstOutstandingPathProbe() *packet {
if len(h.pathProbePackets) == 0 {
return nil
}
return h.pathProbePackets[0]
}
func (h *sentPacketHistory) Len() int {
return len(h.packets)
}
@ -125,6 +148,27 @@ func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
return nil
}
// RemovePathProbe removes a path probe packet.
// It scales O(N), but that's ok, since we don't expect to send many path probe packets.
// It is not valid to call this function in IteratePathProbes.
func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet {
var packetToDelete *packet
idx := -1
for i, p := range h.pathProbePackets {
if p.PacketNumber == pn {
packetToDelete = p
idx = i
break
}
}
if idx != -1 {
// don't use slices.Delete, because it zeros the deleted element
copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:])
h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1]
}
return packetToDelete
}
// getIndex gets the index of packet p in the packets slice.
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
@ -145,6 +189,10 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
func (h *sentPacketHistory) HasOutstandingPathProbes() bool {
return len(h.pathProbePackets) > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {

View file

@ -12,7 +12,6 @@ import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
@ -89,12 +88,13 @@ func NewCryptoSetupClient(
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.allow0RTT = enable0RTT
cs.conn = tls.QUICClient(quicConf)
cs.conn = tls.QUICClient(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs
@ -123,9 +123,13 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
cs.conn = tls.QUICServer(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
return cs
}
@ -178,11 +182,10 @@ func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
if err := h.handleEvent(ev); err != nil {
return wrapError(err)
}
if done {
if ev.Kind == tls.QUICNoEvent {
break
}
}
@ -213,53 +216,78 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev
}
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
if err := h.handleEvent(ev); err != nil {
return err
}
if done {
if ev.Kind == tls.QUICNoEvent {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
//nolint:exhaustive
// Go 1.23 added new 0-RTT events, see https://github.com/quic-go/quic-go/issues/4272.
// We will start using these events when dropping support for Go 1.22.
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) {
switch ev.Kind {
case tls.QUICNoEvent:
return true, nil
return nil
case tls.QUICSetReadSecret:
h.setReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
return nil
case tls.QUICSetWriteSecret:
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
return nil
case tls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
return h.handleTransportParameters(ev.Data)
case tls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return false, nil
return nil
case tls.QUICRejectedEarlyData:
h.rejected0RTT()
return false, nil
return nil
case tls.QUICWriteData:
h.writeRecord(ev.Level, ev.Data)
return false, nil
return nil
case tls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
return nil
case tls.QUICStoreSession:
if h.perspective == protocol.PerspectiveServer {
panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server")
}
ev.SessionState.Extra = append(
ev.SessionState.Extra,
addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)),
)
return h.conn.StoreSession(ev.SessionState)
case tls.QUICResumeSession:
var allowEarlyData bool
switch h.perspective {
case protocol.PerspectiveClient:
// for clients, this event occurs when a session ticket is selected
allowEarlyData = h.handleDataFromSessionState(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
case protocol.PerspectiveServer:
// for servers, this event occurs when receiving the client's session ticket
allowEarlyData = h.handleSessionTicket(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
}
if ev.SessionState.EarlyData {
ev.SessionState.EarlyData = allowEarlyData
}
return nil
default:
// Unknown events should be ignored.
// crypto/tls will ensure that this is safe to do.
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
return false, nil
return nil
}
}
@ -350,7 +378,10 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())},
}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
@ -376,9 +407,9 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT.
// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT:
// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT.
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
if err := t.Unmarshal(data, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
@ -446,7 +477,7 @@ func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra
}
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
@ -497,7 +528,7 @@ func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr
panic("unexpected write encryption level")
}
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective)
}
}

View file

@ -1,4 +1,4 @@
package qtls
package handshake
import (
"net"

View file

@ -8,8 +8,6 @@ import (
)
// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))

View file

@ -1,6 +1,7 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
@ -52,3 +53,20 @@ func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
t.RTT = time.Duration(rtt) * time.Microsecond
return nil
}
const extraPrefix = "quic-go1"
func addSessionStateExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findSessionStateExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View file

@ -0,0 +1,39 @@
package handshake
import (
"crypto/tls"
"net"
)
func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// we're returning a tls.Config here, so we need to apply this recursively
c = setupConfigForServer(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}

View file

@ -1,5 +1,10 @@
package protocol
import (
"crypto/tls"
"fmt"
)
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel uint8
@ -28,3 +33,33 @@ func (e EncryptionLevel) String() string {
}
return "unknown"
}
func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel {
switch e {
case EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}

View file

@ -1,52 +0,0 @@
package qtls
import (
"crypto/tls"
"fmt"
"unsafe"
)
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}

View file

@ -1,70 +0,0 @@
package qtls
import (
"crypto/tls"
"sync"
)
type clientSessionCache struct {
mx sync.Mutex
getData func(earlyData bool) []byte
setData func(data []byte, earlyData bool) (allowEarlyData bool)
wrapped tls.ClientSessionCache
}
var _ tls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
c.mx.Lock()
defer c.mx.Unlock()
if cs == nil {
c.wrapped.Put(key, nil)
return
}
ticket, state, err := cs.ResumptionState()
if err != nil || state == nil {
c.wrapped.Put(key, cs)
return
}
state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData)))
newCS, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error. Just save the original state.
c.wrapped.Put(key, cs)
return
}
c.wrapped.Put(key, newCS)
}
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
c.mx.Lock()
defer c.mx.Unlock()
cs, ok := c.wrapped.Get(key)
if !ok || cs == nil {
return cs, ok
}
ticket, state, err := cs.ResumptionState()
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
// restore QUIC transport parameters and RTT stored in state.Extra
if extra := findExtraData(state.Extra); extra != nil {
earlyData := c.setData(extra, state.EarlyData)
if state.EarlyData {
state.EarlyData = earlyData
}
}
session, err := tls.NewResumptionState(ticket, state)
if err != nil {
// It's not clear why this would error.
// Remove the ticket from the session cache, so we don't run into this error over and over again
c.wrapped.Put(key, nil)
return nil, false
}
return session, true
}

View file

@ -1,150 +0,0 @@
package qtls
import (
"bytes"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC session ticket
state.Extra = append(state.Extra, addExtraPrefix(getData()))
if origWrapSession != nil {
return origWrapSession(cs, state)
}
b, err := conf.EncryptTicket(cs, state)
return b, err
}
origUnwrapSession := conf.UnwrapSession
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
// However, using 0-RTT is only possible with the first session ticket.
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
var unwrapCount int
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
unwrapCount++
var state *tls.SessionState
var err error
if origUnwrapSession != nil {
state, err = origUnwrapSession(identity, connState)
} else {
state, err = conf.DecryptTicket(identity, connState)
}
if err != nil || state == nil {
return nil, err
}
extra := findExtraData(state.Extra)
if extra != nil {
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
} else {
state.EarlyData = false
}
return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
func SetupConfigForClient(
qconf *tls.QUICConfig,
getData func(earlyData bool) []byte,
setData func(data []byte, earlyData bool) (allowEarlyData bool),
) {
conf := qconf.TLSConfig
if conf.ClientSessionCache != nil {
origCache := conf.ClientSessionCache
conf.ClientSessionCache = &clientSessionCache{
wrapped: origCache,
getData: getData,
setData: setData,
}
}
}
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
switch e {
case protocol.EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case protocol.EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case protocol.Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case protocol.Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return protocol.EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return protocol.EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return protocol.Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return protocol.Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}
const extraPrefix = "quic-go1"
func addExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}

View file

@ -108,3 +108,12 @@ func (r *RTTStats) SetInitialRTT(t time.Duration) {
r.smoothedRTT = t
r.latestRTT = t
}
func (r *RTTStats) ResetForPathMigration() {
r.hasMeasurement = false
r.minRTT = 0
r.latestRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
// max_ack_delay remains valid
}

View file

@ -63,9 +63,11 @@ type TimerType uint8
const (
// TimerTypeACK is the timer type for the early retransmit timer
TimerTypeACK TimerType = iota
TimerTypeACK TimerType = iota + 1
// TimerTypePTO is the timer type for the PTO retransmit timer
TimerTypePTO
// TimerTypePathProbe is the timer type for the path probe retransmit timer
TimerTypePathProbe
)
// TimeoutReason is the reason why a connection is closed

View file

@ -17,6 +17,7 @@ type mtuDiscoverer interface {
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount)
Reset(now time.Time, start, max protocol.ByteCount)
}
const (
@ -88,7 +89,6 @@ const (
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
@ -99,6 +99,11 @@ type mtuFinder struct {
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
// The generation is used to ignore ACKs / losses for probe packets sent before a reset.
// Resets happen when the connection is migrated to a new path.
// We're therefore not concerned about overflows of this counter.
generation uint8
tracer *logging.ConnectionTracer
}
@ -107,16 +112,19 @@ var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
min: start,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
inFlight: protocol.InvalidByteCount,
rttStats: rttStats,
tracer: tracer,
}
f.init(start, max)
return f
}
func (f *mtuFinder) init(start, max protocol.ByteCount) {
f.min = start
for i := range f.lost {
if i == 0 {
f.lost[i] = max
@ -124,7 +132,6 @@ func newMTUDiscoverer(
}
f.lost[i] = protocol.InvalidByteCount
}
return f
}
func (f *mtuFinder) done() bool {
@ -165,7 +172,7 @@ func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount
f.inFlight = size
return ackhandler.Frame{
Frame: &wire.PingFrame{},
Handler: &mtuFinderAckHandler{f},
Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation},
}, size
}
@ -173,13 +180,26 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.min
}
func (f *mtuFinder) Reset(now time.Time, start, max protocol.ByteCount) {
f.generation++
f.lastProbeTime = now
f.lastProbeWasLost = false
f.inFlight = protocol.InvalidByteCount
f.init(start, max)
}
type mtuFinderAckHandler struct {
*mtuFinder
generation uint8
}
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// ACK for probe sent before reset
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnAcked callback called although there's no MTU probe packet in flight")
@ -207,10 +227,13 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
h.mtuIncreased(size)
}
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// probe sent before reset received
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")

View file

@ -22,9 +22,10 @@ type packer interface {
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)
MaybePackPTOProbePacket(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackPathProbePacket(protocol.ConnectionID, ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte)
@ -57,6 +58,7 @@ type shortHeaderPacket struct {
Ack *wire.AckFrame
Length protocol.ByteCount
IsPathMTUProbePacket bool
IsPathProbePacket bool
// used for logging
DestConnID protocol.ConnectionID
@ -269,17 +271,17 @@ func (p *packetPacker) packConnectionClose(
if sealers[i] == nil {
continue
}
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
}
if encLevel == protocol.Encryption1RTT {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v)
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
} else {
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
}
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
if err != nil {
return nil, err
@ -707,7 +709,7 @@ func (p *packetPacker) composeNextPacket(
return pl
}
func (p *packetPacker) MaybePackProbePacket(
func (p *packetPacker) MaybePackPTOProbePacket(
encLevel protocol.EncryptionLevel,
maxPacketSize protocol.ByteCount,
now time.Time,
@ -792,6 +794,26 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
return packet, buffer, err
}
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, f ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
buf := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, nil, err
}
payload := payload{
frames: []ackhandler.Frame{f},
length: f.Frame.Length(v),
}
padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v)
if err != nil {
return shortHeaderPacket{}, nil, err
}
packet.IsPathProbePacket = true
return packet, buf, err
}
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{

145
vendor/github.com/quic-go/quic-go/path_manager.go generated vendored Normal file
View file

@ -0,0 +1,145 @@
package quic
import (
"crypto/rand"
"net"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type pathID int64
const maxPaths = 3
type path struct {
addr net.Addr
pathChallenge [8]byte
validated bool
rcvdNonProbing bool
}
type pathManager struct {
nextPathID pathID
paths map[pathID]*path
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
logger utils.Logger
}
func newPathManager(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
logger utils.Logger,
) *pathManager {
return &pathManager{
paths: make(map[pathID]*path),
getConnID: getConnID,
retireConnID: retireConnID,
logger: logger,
}
}
// Returns a path challenge frame if one should be sent.
// May return nil.
func (pm *pathManager) HandlePacket(p receivedPacket, isNonProbing bool) (_ protocol.ConnectionID, _ ackhandler.Frame, shouldSwitch bool) {
for _, path := range pm.paths {
if addrsEqual(path.addr, p.remoteAddr) {
// already sent a PATH_CHALLENGE for this path
if isNonProbing {
path.rcvdNonProbing = true
}
if pm.logger.Debug() {
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", p.remoteAddr, path.validated)
}
return protocol.ConnectionID{}, ackhandler.Frame{}, path.validated && path.rcvdNonProbing
}
}
if len(pm.paths) >= maxPaths {
if pm.logger.Debug() {
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", p.remoteAddr, len(pm.paths))
}
return protocol.ConnectionID{}, ackhandler.Frame{}, false
}
// previously unseen path, initiate path validation by sending a PATH_CHALLENGE
connID, ok := pm.getConnID(pm.nextPathID)
if !ok {
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", p.remoteAddr)
return protocol.ConnectionID{}, ackhandler.Frame{}, false
}
var b [8]byte
rand.Read(b[:])
pm.paths[pm.nextPathID] = &path{
addr: p.remoteAddr,
pathChallenge: b,
rcvdNonProbing: isNonProbing,
}
pm.nextPathID++
frame := ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: b},
Handler: (*pathManagerAckHandler)(pm),
}
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", p.remoteAddr)
return connID, frame, false
}
func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) {
for _, p := range pm.paths {
if f.Data == p.pathChallenge {
// path validated
p.validated = true
pm.logger.Debugf("path %s validated", p.addr)
break
}
}
}
// SwitchToPath is called when the connection switches to a new path
func (pm *pathManager) SwitchToPath(addr net.Addr) {
// retire all other paths
for id := range pm.paths {
if addrsEqual(pm.paths[id].addr, addr) {
pm.logger.Debugf("switching to path %d (%s)", id, addr)
continue
}
pm.retireConnID(id)
}
clear(pm.paths)
}
type pathManagerAckHandler pathManager
var _ ackhandler.FrameHandler = &pathManagerAckHandler{}
// Acknowledging the frame doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerAckHandler) OnAcked(f wire.Frame) {}
func (pm *pathManagerAckHandler) OnLost(f wire.Frame) {
// TODO: retransmit the packet the first time it is lost
pc := f.(*wire.PathChallengeFrame)
for id, path := range pm.paths {
if path.pathChallenge == pc.Data {
delete(pm.paths, id)
pm.retireConnID(id)
break
}
}
}
func addrsEqual(addr1, addr2 net.Addr) bool {
if addr1 == nil || addr2 == nil {
return false
}
a1, ok1 := addr1.(*net.UDPAddr)
a2, ok2 := addr2.(*net.UDPAddr)
if ok1 && ok2 {
return a1.IP.Equal(a2.IP) && a1.Port == a2.Port
}
return addr1.String() == addr2.String()
}

View file

@ -2,6 +2,7 @@ package quic
import (
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
@ -10,22 +11,29 @@ import (
// A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface {
Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
WriteTo([]byte, net.Addr) error
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
ChangeRemoteAddr(addr net.Addr, info packetInfo)
capabilities() connCapabilities
}
type remoteAddrInfo struct {
addr net.Addr
oob []byte
}
type sconn struct {
rawConn
localAddr net.Addr
remoteAddr net.Addr
localAddr net.Addr
remoteAddrInfo atomic.Pointer[remoteAddrInfo]
logger utils.Logger
packetInfoOOB []byte
// If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled.
gotGSOError bool
// Used to catch the error sometimes returned by the first sendmsg call on Linux,
@ -49,22 +57,26 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
// increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob)
oob = append(oob, make([]byte, 64)...)[:l]
return &sconn{
rawConn: c,
localAddr: localAddr,
remoteAddr: remote,
packetInfoOOB: oob,
logger: logger,
sc := &sconn{
rawConn: c,
localAddr: localAddr,
logger: logger,
}
sc.remoteAddrInfo.Store(&remoteAddrInfo{
addr: remote,
oob: oob,
})
return sc
}
func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn)
ai := c.remoteAddrInfo.Load()
err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn)
if err != nil && isGSOError(err) {
// disable GSO for future calls
c.gotGSOError = true
if c.logger.Debug() {
c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr)
c.logger.Debugf("GSO failed when sending to %s", ai.addr)
}
// send out the packets one by one
for len(p) > 0 {
@ -72,7 +84,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
if l > int(gsoSize) {
l = int(gsoSize)
}
if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil {
if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil {
return err
}
p = p[l:]
@ -91,6 +103,11 @@ func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16,
return err
}
func (c *sconn) WriteTo(b []byte, addr net.Addr) error {
_, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
return err
}
func (c *sconn) capabilities() connCapabilities {
capabilities := c.rawConn.capabilities()
if capabilities.GSO {
@ -99,5 +116,12 @@ func (c *sconn) capabilities() connCapabilities {
return capabilities
}
func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) {
c.remoteAddrInfo.Store(&remoteAddrInfo{
addr: addr,
oob: info.OOB(),
})
}
func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddrInfo.Load().addr }
func (c *sconn) LocalAddr() net.Addr { return c.localAddr }

View file

@ -1,9 +1,14 @@
package quic
import "github.com/quic-go/quic-go/internal/protocol"
import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
type sender interface {
Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
SendProbe(*packetBuffer, net.Addr)
Run() error
WouldBlock() bool
Available() <-chan struct{}
@ -57,6 +62,10 @@ func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
}
}
func (h *sendQueue) SendProbe(p *packetBuffer, addr net.Addr) {
h.conn.WriteTo(p.Data, addr)
}
func (h *sendQueue) WouldBlock() bool {
return len(h.queue) == sendQueueCapacity
}

5
vendor/modules.txt vendored
View file

@ -107,8 +107,8 @@ github.com/powerman/deepequal
# github.com/quic-go/qpack v0.5.1
## explicit; go 1.22
github.com/quic-go/qpack
# github.com/quic-go/quic-go v0.49.0
## explicit; go 1.22
# github.com/quic-go/quic-go v0.50.0
## explicit; go 1.23
github.com/quic-go/quic-go
github.com/quic-go/quic-go/http3
github.com/quic-go/quic-go/internal/ackhandler
@ -117,7 +117,6 @@ github.com/quic-go/quic-go/internal/flowcontrol
github.com/quic-go/quic-go/internal/handshake
github.com/quic-go/quic-go/internal/protocol
github.com/quic-go/quic-go/internal/qerr
github.com/quic-go/quic-go/internal/qtls
github.com/quic-go/quic-go/internal/utils
github.com/quic-go/quic-go/internal/utils/linkedlist
github.com/quic-go/quic-go/internal/utils/ringbuffer