Update deps

This commit is contained in:
Frank Denis 2023-06-22 10:36:33 +02:00
parent 4f3ce0cbae
commit f42b7dad17
93 changed files with 4303 additions and 1673 deletions

14
go.mod
View file

@ -18,9 +18,9 @@ require (
github.com/jedisct1/xsecretbox v0.0.0-20230513092623-8c0b2dff5e24 github.com/jedisct1/xsecretbox v0.0.0-20230513092623-8c0b2dff5e24
github.com/k-sone/critbitgo v1.4.0 github.com/k-sone/critbitgo v1.4.0
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.54 github.com/miekg/dns v1.1.55
github.com/powerman/check v1.7.0 github.com/powerman/check v1.7.0
github.com/quic-go/quic-go v0.35.1 github.com/quic-go/quic-go v0.36.0
golang.org/x/crypto v0.10.0 golang.org/x/crypto v0.10.0
golang.org/x/net v0.11.0 golang.org/x/net v0.11.0
golang.org/x/sys v0.9.0 golang.org/x/sys v0.9.0
@ -29,12 +29,12 @@ require (
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/mock v1.6.0 // indirect github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/hashicorp/go-syslog v1.0.0 // indirect github.com/hashicorp/go-syslog v1.0.0 // indirect
github.com/onsi/ginkgo/v2 v2.2.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/powerman/deepequal v0.1.0 // indirect github.com/powerman/deepequal v0.1.0 // indirect
@ -43,9 +43,9 @@ require (
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/smartystreets/goconvey v1.7.2 // indirect github.com/smartystreets/goconvey v1.7.2 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.10.0 // indirect
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.9.1 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
google.golang.org/grpc v1.53.0 // indirect google.golang.org/grpc v1.53.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect

40
go.sum
View file

@ -12,13 +12,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 h1:3T8ZyTDp5QxTx3NU48JVb2u+75xc040fofcBaN+6jPA= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 h1:3T8ZyTDp5QxTx3NU48JVb2u+75xc040fofcBaN+6jPA=
github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185/go.mod h1:cFRxtTwTOJkz2x3rQUNCYKWC93yP1VKjR8NUhqFxZNU= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185/go.mod h1:cFRxtTwTOJkz2x3rQUNCYKWC93yP1VKjR8NUhqFxZNU=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
@ -54,11 +55,11 @@ github.com/k-sone/critbitgo v1.4.0 h1:l71cTyBGeh6X5ATh6Fibgw3+rtNT80BA0uNNWgkPrb
github.com/k-sone/critbitgo v1.4.0/go.mod h1:7E6pyoyADnFxlUBEKcnfS49b7SUAQGMK+OAp/UQvo0s= github.com/k-sone/critbitgo v1.4.0/go.mod h1:7E6pyoyADnFxlUBEKcnfS49b7SUAQGMK+OAp/UQvo0s=
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60= github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/miekg/dns v1.1.54 h1:5jon9mWcb0sFJGpnI99tOMhCPyJ+RPVz5b63MQG0VWI= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.54/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -73,15 +74,15 @@ github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc8
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= github.com/quic-go/quic-go v0.36.0 h1:JIrO7p7Ug6hssFcARjWDiqS2RAKJHCiwPxBAA989rbI=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.36.0/go.mod h1:zPetvwDlILVxt15n3hr3Gf/I3mDf7LpLKPhR4Ez0AZQ=
github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs=
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM= github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -90,8 +91,8 @@ golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o=
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -100,7 +101,7 @@ golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -120,8 +121,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -137,6 +138,5 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -183,14 +183,13 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
// This allows users of the library to implement their own connection management, // This allows users of the library to implement their own connection management,
// as opposed to Exchange, which will always use new connections and incur the added overhead // as opposed to Exchange, which will always use new connections and incur the added overhead
// that entails when using "tcp" and especially "tcp-tls" clients. // that entails when using "tcp" and especially "tcp-tls" clients.
//
// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
// prevent one cancellation from canceling all outstanding requests.
func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
return c.exchangeWithConnContext(context.Background(), m, conn) return c.ExchangeWithConnContext(context.Background(), m, conn)
} }
func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { // ExchangeWithConnContext has the same behaviour as ExchangeWithConn and
// additionally obeys deadlines from the passed Context.
func (c *Client) ExchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
opt := m.IsEdns0() opt := m.IsEdns0()
// If EDNS0 is used use that for size. // If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize { if opt != nil && opt.UDPSize() >= MinMsgSize {
@ -460,5 +459,5 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg,
} }
defer conn.Close() defer conn.Close()
return c.exchangeWithConnContext(ctx, m, conn) return c.ExchangeWithConnContext(ctx, m, conn)
} }

View file

@ -3,7 +3,7 @@ package dns
import "fmt" import "fmt"
// Version is current version of this library. // Version is current version of this library.
var Version = v{1, 1, 54} var Version = v{1, 1, 55}
// v holds the version of this library. // v holds the version of this library.
type v struct { type v struct {

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
"strconv"
"strings" "strings"
) )
@ -50,6 +51,37 @@ func NewWithNoColorBool(noColor bool) Formatter {
} }
func New(colorMode ColorMode) Formatter { func New(colorMode ColorMode) Formatter {
colorAliases := map[string]int{
"black": 0,
"red": 1,
"green": 2,
"yellow": 3,
"blue": 4,
"magenta": 5,
"cyan": 6,
"white": 7,
}
for colorAlias, n := range colorAliases {
colorAliases[fmt.Sprintf("bright-%s", colorAlias)] = n + 8
}
getColor := func(color, defaultEscapeCode string) string {
color = strings.ToUpper(strings.ReplaceAll(color, "-", "_"))
envVar := fmt.Sprintf("GINKGO_CLI_COLOR_%s", color)
envVarColor := os.Getenv(envVar)
if envVarColor == "" {
return defaultEscapeCode
}
if colorCode, ok := colorAliases[envVarColor]; ok {
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
colorCode, err := strconv.Atoi(envVarColor)
if err != nil || colorCode < 0 || colorCode > 255 {
return defaultEscapeCode
}
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
f := Formatter{ f := Formatter{
ColorMode: colorMode, ColorMode: colorMode,
colors: map[string]string{ colors: map[string]string{
@ -57,18 +89,18 @@ func New(colorMode ColorMode) Formatter {
"bold": "\x1b[1m", "bold": "\x1b[1m",
"underline": "\x1b[4m", "underline": "\x1b[4m",
"red": "\x1b[38;5;9m", "red": getColor("red", "\x1b[38;5;9m"),
"orange": "\x1b[38;5;214m", "orange": getColor("orange", "\x1b[38;5;214m"),
"coral": "\x1b[38;5;204m", "coral": getColor("coral", "\x1b[38;5;204m"),
"magenta": "\x1b[38;5;13m", "magenta": getColor("magenta", "\x1b[38;5;13m"),
"green": "\x1b[38;5;10m", "green": getColor("green", "\x1b[38;5;10m"),
"dark-green": "\x1b[38;5;28m", "dark-green": getColor("dark-green", "\x1b[38;5;28m"),
"yellow": "\x1b[38;5;11m", "yellow": getColor("yellow", "\x1b[38;5;11m"),
"light-yellow": "\x1b[38;5;228m", "light-yellow": getColor("light-yellow", "\x1b[38;5;228m"),
"cyan": "\x1b[38;5;14m", "cyan": getColor("cyan", "\x1b[38;5;14m"),
"gray": "\x1b[38;5;243m", "gray": getColor("gray", "\x1b[38;5;243m"),
"light-gray": "\x1b[38;5;246m", "light-gray": getColor("light-gray", "\x1b[38;5;246m"),
"blue": "\x1b[38;5;12m", "blue": getColor("blue", "\x1b[38;5;12m"),
}, },
} }
colors := []string{} colors := []string{}
@ -88,7 +120,10 @@ func (f Formatter) Fi(indentation uint, format string, args ...interface{}) stri
} }
func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string { func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string {
out := fmt.Sprintf(f.style(format), args...) out := f.style(format)
if len(args) > 0 {
out = fmt.Sprintf(out, args...)
}
if indentation == 0 && maxWidth == 0 { if indentation == 0 && maxWidth == 0 {
return out return out

View file

@ -39,6 +39,8 @@ func buildSpecs(args []string, cliConfig types.CLIConfig, goFlagsConfig types.Go
command.AbortWith("Found no test suites") command.AbortWith("Found no test suites")
} }
internal.VerifyCLIAndFrameworkVersion(suites)
opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers()) opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers())
opc.StartCompiling(suites, goFlagsConfig) opc.StartCompiling(suites, goFlagsConfig)

View file

@ -2,6 +2,7 @@ package generators
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"os" "os"
"text/template" "text/template"
@ -25,6 +26,9 @@ func BuildBootstrapCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate", {Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file", UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"}, Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the bootstrap template"},
}, },
&conf, &conf,
types.GinkgoFlagSections{}, types.GinkgoFlagSections{},
@ -57,6 +61,7 @@ type bootstrapData struct {
GomegaImport string GomegaImport string
GinkgoPackage string GinkgoPackage string
GomegaPackage string GomegaPackage string
CustomData map[string]any
} }
func generateBootstrap(conf GeneratorsConfig) { func generateBootstrap(conf GeneratorsConfig) {
@ -95,17 +100,32 @@ func generateBootstrap(conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate) tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom bootstrap file:", err) command.AbortIfError("Failed to read custom bootstrap file:", err)
templateText = string(tpl) templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom boostrap data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti { } else if conf.Agouti {
templateText = agoutiBootstrapText templateText = agoutiBootstrapText
} else { } else {
templateText = bootstrapText templateText = bootstrapText
} }
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) //Setting the option to explicitly fail if template is rendered trying to access missing key
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to parse bootstrap template:", err) command.AbortIfError("Failed to parse bootstrap template:", err)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
bootstrapTemplate.Execute(buf, data) //Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = bootstrapTemplate.Execute(buf, data)
command.AbortIfError("Failed to render bootstrap template:", err)
buf.WriteTo(f) buf.WriteTo(f)

View file

@ -2,6 +2,7 @@ package generators
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -28,6 +29,9 @@ func BuildGenerateCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate", {Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file", UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the test file template"}, Usage: "If specified, generate will use the contents of the file passed as the test file template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the test file template"},
}, },
&conf, &conf,
types.GinkgoFlagSections{}, types.GinkgoFlagSections{},
@ -64,6 +68,7 @@ type specData struct {
GomegaImport string GomegaImport string
GinkgoPackage string GinkgoPackage string
GomegaPackage string GomegaPackage string
CustomData map[string]any
} }
func generateTestFiles(conf GeneratorsConfig, args []string) { func generateTestFiles(conf GeneratorsConfig, args []string) {
@ -122,16 +127,31 @@ func generateTestFileForSubject(subject string, conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate) tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom template file:", err) command.AbortIfError("Failed to read custom template file:", err)
templateText = string(tpl) templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom template data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti { } else if conf.Agouti {
templateText = agoutiSpecText templateText = agoutiSpecText
} else { } else {
templateText = specText templateText = specText
} }
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText) //Setting the option to explicitly fail if template is rendered trying to access missing key
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to read parse test template:", err) command.AbortIfError("Failed to read parse test template:", err)
specTemplate.Execute(f, data) //Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = specTemplate.Execute(f, data)
command.AbortIfError("Failed to render bootstrap template:", err)
internal.GoFmt(targetFile) internal.GoFmt(targetFile)
} }

View file

@ -13,6 +13,7 @@ import (
type GeneratorsConfig struct { type GeneratorsConfig struct {
Agouti, NoDot, Internal bool Agouti, NoDot, Internal bool
CustomTemplate string CustomTemplate string
CustomTemplateData string
} }
func getPackageAndFormattedName() (string, string, string) { func getPackageAndFormattedName() (string, string, string) {

View file

@ -25,7 +25,16 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite
return suite return suite
} }
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./") ginkgoInvocationPath, _ := os.Getwd()
ginkgoInvocationPath, _ = filepath.Abs(ginkgoInvocationPath)
packagePath := suite.AbsPath()
pathToInvocationPath, err := filepath.Rel(packagePath, ginkgoInvocationPath)
if err != nil {
suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error())
return suite
}
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./", pathToInvocationPath)
if err != nil { if err != nil {
suite.State = TestSuiteStateFailedToCompile suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error()) suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error())

View file

@ -6,6 +6,7 @@ import (
"io" "io"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
"syscall" "syscall"
@ -63,6 +64,12 @@ func checkForNoTestsWarning(buf *bytes.Buffer) bool {
} }
func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite { func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite {
// As we run the go test from the suite directory, make sure the cover profile is absolute
// and placed into the expected output directory when one is configured.
if goFlagsConfig.Cover && !filepath.IsAbs(goFlagsConfig.CoverProfile) {
goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0)
}
args, err := types.GenerateGoTestRunArgs(goFlagsConfig) args, err := types.GenerateGoTestRunArgs(goFlagsConfig)
command.AbortIfError("Failed to generate test run arguments", err) command.AbortIfError("Failed to generate test run arguments", err)
cmd, buf := buildAndStartCommand(suite, args, true) cmd, buf := buildAndStartCommand(suite, args, true)

View file

@ -0,0 +1,54 @@
package internal
import (
"fmt"
"os/exec"
"regexp"
"strings"
"github.com/onsi/ginkgo/v2/formatter"
"github.com/onsi/ginkgo/v2/types"
)
var versiorRe = regexp.MustCompile(`v(\d+\.\d+\.\d+)`)
func VerifyCLIAndFrameworkVersion(suites TestSuites) {
cliVersion := types.VERSION
mismatches := map[string][]string{}
for _, suite := range suites {
cmd := exec.Command("go", "list", "-m", "github.com/onsi/ginkgo/v2")
cmd.Dir = suite.Path
output, err := cmd.CombinedOutput()
if err != nil {
continue
}
components := strings.Split(string(output), " ")
if len(components) != 2 {
continue
}
matches := versiorRe.FindStringSubmatch(components[1])
if matches == nil || len(matches) != 2 {
continue
}
libraryVersion := matches[1]
if cliVersion != libraryVersion {
mismatches[libraryVersion] = append(mismatches[libraryVersion], suite.PackageName)
}
}
if len(mismatches) == 0 {
return
}
fmt.Println(formatter.F("{{red}}{{bold}}Ginkgo detected a version mismatch between the Ginkgo CLI and the version of Ginkgo imported by your packages:{{/}}"))
fmt.Println(formatter.Fi(1, "Ginkgo CLI Version:"))
fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}}", cliVersion))
fmt.Println(formatter.Fi(1, "Mismatched package versions found:"))
for version, packages := range mismatches {
fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}} used by %s", version, strings.Join(packages, ", ")))
}
fmt.Println("")
fmt.Println(formatter.Fiw(1, formatter.COLS, "{{gray}}Ginkgo will continue to attempt to run but you may see errors (including flag parsing errors) and should either update your go.mod or your version of the Ginkgo CLI to match.\n\nTo install the matching version of the CLI run\n {{bold}}go install github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file. Alternatively you can use\n {{bold}}go run github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file to invoke the matching version of the Ginkgo CLI.\n\nIf you are attempting to test multiple packages that each have a different version of the Ginkgo library with a single Ginkgo CLI that is currently unsupported.\n{{/}}"))
}

View file

@ -1,6 +1,7 @@
package outline package outline
import ( import (
"github.com/onsi/ginkgo/v2/types"
"go/ast" "go/ast"
"go/token" "go/token"
"strconv" "strconv"
@ -25,9 +26,10 @@ type ginkgoMetadata struct {
// End is the position of first character immediately after the spec or container block // End is the position of first character immediately after the spec or container block
End int `json:"end"` End int `json:"end"`
Spec bool `json:"spec"` Spec bool `json:"spec"`
Focused bool `json:"focused"` Focused bool `json:"focused"`
Pending bool `json:"pending"` Pending bool `json:"pending"`
Labels []string `json:"labels"`
} }
// ginkgoNode is used to construct the outline as a tree // ginkgoNode is used to construct the outline as a tree
@ -145,27 +147,35 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage
case "It", "Specify", "Entry": case "It", "Specify", "Entry":
n.Spec = true n.Spec = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FIt", "FSpecify", "FEntry": case "FIt", "FSpecify", "FEntry":
n.Spec = true n.Spec = true
n.Focused = true n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry": case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry":
n.Spec = true n.Spec = true
n.Pending = true n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "Context", "Describe", "When", "DescribeTable": case "Context", "Describe", "When", "DescribeTable":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FContext", "FDescribe", "FWhen", "FDescribeTable": case "FContext", "FDescribe", "FWhen", "FDescribeTable":
n.Focused = true n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable": case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable":
n.Pending = true n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "By": case "By":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
@ -216,3 +226,77 @@ func textFromCallExpr(ce *ast.CallExpr) (string, bool) {
return text.Value, true return text.Value, true
} }
} }
// labelFromCallExpr returns the labels attached to a Ginkgo node's call
// expression by scanning its trailing arguments for Label(...) calls.
// The first argument is the node's text, never a label, so scanning starts
// at Args[1]; a node with fewer than two arguments has no labels.
func labelFromCallExpr(ce *ast.CallExpr) []string {
	labels := []string{}
	if len(ce.Args) < 2 {
		return labels
	}
	for _, arg := range ce.Args[1:] {
		callExpr, ok := arg.(*ast.CallExpr)
		if !ok {
			continue
		}
		id, ok := callExpr.Fun.(*ast.Ident)
		if !ok {
			// skip cases where Fun is actually an *ast.SelectorExpr (e.g. pkg.Label(...))
			continue
		}
		if id.Name == "Label" {
			labels = append(labels, extractLabels(callExpr)...)
		}
	}
	return labels
}
// extractLabels pulls the string arguments out of a Label(...) call,
// unquoting each string literal and keeping only those values that pass
// Ginkgo's label validation/cleanup. Non-literal and non-string arguments
// are ignored.
func extractLabels(expr *ast.CallExpr) []string {
	out := []string{}
	for _, arg := range expr.Args {
		lit, isLit := arg.(*ast.BasicLit)
		if !isLit || lit.Kind != token.STRING {
			continue
		}
		unquoted, err := strconv.Unquote(lit.Value)
		if err != nil {
			// fall back to the raw token text when it is not a valid quoted string
			unquoted = lit.Value
		}
		if validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{}); err == nil {
			out = append(out, validated)
		}
	}
	return out
}
func pendingFromCallExpr(ce *ast.CallExpr) bool {
pending := false
if len(ce.Args) < 2 {
return pending
}
for _, arg := range ce.Args[1:] {
switch expr := arg.(type) {
case *ast.CallExpr:
id, ok := expr.Fun.(*ast.Ident)
if !ok {
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
continue
}
if id.Name == "Pending" {
pending = true
}
case *ast.Ident:
if expr.Name == "Pending" {
pending = true
}
}
}
return pending
}

View file

@ -47,7 +47,7 @@ func packageNameForImport(f *ast.File, path string) *string {
// or nil otherwise. // or nil otherwise.
func importSpec(f *ast.File, path string) *ast.ImportSpec { func importSpec(f *ast.File, path string) *ast.ImportSpec {
for _, s := range f.Imports { for _, s := range f.Imports {
if importPath(s) == path { if strings.HasPrefix(importPath(s), path) {
return s return s
} }
} }

View file

@ -85,12 +85,19 @@ func (o *outline) String() string {
// one 'width' of spaces for every level of nesting. // one 'width' of spaces for every level of nesting.
func (o *outline) StringIndent(width int) string { func (o *outline) StringIndent(width int) string {
var b strings.Builder var b strings.Builder
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n") b.WriteString("Name,Text,Start,End,Spec,Focused,Pending,Labels\n")
currentIndent := 0 currentIndent := 0
pre := func(n *ginkgoNode) { pre := func(n *ginkgoNode) {
b.WriteString(fmt.Sprintf("%*s", currentIndent, "")) b.WriteString(fmt.Sprintf("%*s", currentIndent, ""))
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending)) var labels string
if len(n.Labels) == 1 {
labels = n.Labels[0]
} else {
labels = strings.Join(n.Labels, ", ")
}
		//enclose labels in a double-quoted, comma-separated list so that when imported into a CSV app the Labels column reads as comma-separated strings
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t,\"%s\"\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending, labels))
currentIndent += width currentIndent += width
} }
post := func(n *ginkgoNode) { post := func(n *ginkgoNode) {

View file

@ -24,7 +24,7 @@ func BuildRunCommand() command.Command {
panic(err) panic(err)
} }
interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) interruptHandler := interrupt_handler.NewInterruptHandler(nil)
interrupt_handler.SwallowSigQuit() interrupt_handler.SwallowSigQuit()
return command.Command{ return command.Command{
@ -69,6 +69,8 @@ func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) {
skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter) skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter)
suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter) suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter)
internal.VerifyCLIAndFrameworkVersion(suites)
if len(skippedSuites) > 0 { if len(skippedSuites) > 0 {
fmt.Println("Will skip:") fmt.Println("Will skip:")
for _, skippedSuite := range skippedSuites { for _, skippedSuite := range skippedSuites {
@ -115,7 +117,7 @@ OUTER_LOOP:
} }
suites[suiteIdx] = suite suites[suiteIdx] = suite
if r.interruptHandler.Status().Interrupted { if r.interruptHandler.Status().Interrupted() {
opc.StopAndDrain() opc.StopAndDrain()
break OUTER_LOOP break OUTER_LOOP
} }

View file

@ -22,7 +22,7 @@ func BuildWatchCommand() command.Command {
if err != nil { if err != nil {
panic(err) panic(err)
} }
interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) interruptHandler := interrupt_handler.NewInterruptHandler(nil)
interrupt_handler.SwallowSigQuit() interrupt_handler.SwallowSigQuit()
return command.Command{ return command.Command{
@ -65,6 +65,8 @@ type SpecWatcher struct {
func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) {
suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter)
internal.VerifyCLIAndFrameworkVersion(suites)
if len(suites) == 0 { if len(suites) == 0 {
command.AbortWith("Found no test suites") command.AbortWith("Found no test suites")
} }
@ -127,7 +129,7 @@ func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) {
w.updateSeed() w.updateSeed()
w.computeSuccinctMode(len(suites)) w.computeSuccinctMode(len(suites))
for idx := range suites { for idx := range suites {
if w.interruptHandler.Status().Interrupted { if w.interruptHandler.Status().Interrupted() {
return return
} }
deltaTracker.WillRun(suites[idx]) deltaTracker.WillRun(suites[idx])
@ -156,7 +158,7 @@ func (w *SpecWatcher) compileAndRun(suite internal.TestSuite, additionalArgs []s
fmt.Println(suite.CompilationError.Error()) fmt.Println(suite.CompilationError.Error())
return suite return suite
} }
if w.interruptHandler.Status().Interrupted { if w.interruptHandler.Status().Interrupted() {
return suite return suite
} }
suite = internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs) suite = internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs)

View file

@ -1,7 +1,6 @@
package interrupt_handler package interrupt_handler
import ( import (
"fmt"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
@ -11,27 +10,29 @@ import (
"github.com/onsi/ginkgo/v2/internal/parallel_support" "github.com/onsi/ginkgo/v2/internal/parallel_support"
) )
const TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION = 30 * time.Second var ABORT_POLLING_INTERVAL = 500 * time.Millisecond
const TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT = 10
const ABORT_POLLING_INTERVAL = 500 * time.Millisecond
const ABORT_REPEAT_INTERRUPT_DURATION = 30 * time.Second
type InterruptCause uint type InterruptCause uint
const ( const (
InterruptCauseInvalid InterruptCause = iota InterruptCauseInvalid InterruptCause = iota
InterruptCauseSignal InterruptCauseSignal
InterruptCauseTimeout
InterruptCauseAbortByOtherProcess InterruptCauseAbortByOtherProcess
) )
type InterruptLevel uint
const (
InterruptLevelUninterrupted InterruptLevel = iota
InterruptLevelCleanupAndReport
InterruptLevelReportOnly
InterruptLevelBailOut
)
func (ic InterruptCause) String() string { func (ic InterruptCause) String() string {
switch ic { switch ic {
case InterruptCauseSignal: case InterruptCauseSignal:
return "Interrupted by User" return "Interrupted by User"
case InterruptCauseTimeout:
return "Interrupted by Timeout"
case InterruptCauseAbortByOtherProcess: case InterruptCauseAbortByOtherProcess:
return "Interrupted by Other Ginkgo Process" return "Interrupted by Other Ginkgo Process"
} }
@ -39,37 +40,51 @@ func (ic InterruptCause) String() string {
} }
type InterruptStatus struct { type InterruptStatus struct {
Interrupted bool Channel chan interface{}
Channel chan interface{} Level InterruptLevel
Cause InterruptCause Cause InterruptCause
}
func (s InterruptStatus) Interrupted() bool {
return s.Level != InterruptLevelUninterrupted
}
func (s InterruptStatus) Message() string {
return s.Cause.String()
}
func (s InterruptStatus) ShouldIncludeProgressReport() bool {
return s.Cause != InterruptCauseAbortByOtherProcess
} }
type InterruptHandlerInterface interface { type InterruptHandlerInterface interface {
Status() InterruptStatus Status() InterruptStatus
SetInterruptPlaceholderMessage(string)
ClearInterruptPlaceholderMessage()
InterruptMessage() (string, bool)
} }
type InterruptHandler struct { type InterruptHandler struct {
c chan interface{} c chan interface{}
lock *sync.Mutex lock *sync.Mutex
interrupted bool level InterruptLevel
interruptPlaceholderMessage string cause InterruptCause
interruptCause InterruptCause client parallel_support.Client
client parallel_support.Client stop chan interface{}
stop chan interface{} signals []os.Signal
requestAbortCheck chan interface{}
} }
func NewInterruptHandler(timeout time.Duration, client parallel_support.Client) *InterruptHandler { func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
handler := &InterruptHandler{ if len(signals) == 0 {
c: make(chan interface{}), signals = []os.Signal{os.Interrupt, syscall.SIGTERM}
lock: &sync.Mutex{},
interrupted: false,
stop: make(chan interface{}),
client: client,
} }
handler.registerForInterrupts(timeout) handler := &InterruptHandler{
c: make(chan interface{}),
lock: &sync.Mutex{},
stop: make(chan interface{}),
requestAbortCheck: make(chan interface{}),
client: client,
signals: signals,
}
handler.registerForInterrupts()
return handler return handler
} }
@ -77,30 +92,28 @@ func (handler *InterruptHandler) Stop() {
close(handler.stop) close(handler.stop)
} }
func (handler *InterruptHandler) registerForInterrupts(timeout time.Duration) { func (handler *InterruptHandler) registerForInterrupts() {
// os signal handling // os signal handling
signalChannel := make(chan os.Signal, 1) signalChannel := make(chan os.Signal, 1)
signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM) signal.Notify(signalChannel, handler.signals...)
// timeout handling
var timeoutChannel <-chan time.Time
var timeoutTimer *time.Timer
if timeout > 0 {
timeoutTimer = time.NewTimer(timeout)
timeoutChannel = timeoutTimer.C
}
// cross-process abort handling // cross-process abort handling
var abortChannel chan bool var abortChannel chan interface{}
if handler.client != nil { if handler.client != nil {
abortChannel = make(chan bool) abortChannel = make(chan interface{})
go func() { go func() {
pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL) pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL)
for { for {
select { select {
case <-pollTicker.C: case <-pollTicker.C:
if handler.client.ShouldAbort() { if handler.client.ShouldAbort() {
abortChannel <- true close(abortChannel)
pollTicker.Stop()
return
}
case <-handler.requestAbortCheck:
if handler.client.ShouldAbort() {
close(abortChannel)
pollTicker.Stop() pollTicker.Stop()
return return
} }
@ -112,85 +125,53 @@ func (handler *InterruptHandler) registerForInterrupts(timeout time.Duration) {
}() }()
} }
// listen for any interrupt signals go func(abortChannel chan interface{}) {
// note that some (timeouts, cross-process aborts) will only trigger once
// for these we set up a ticker to keep interrupting the suite until it ends
// this ensures any `AfterEach` or `AfterSuite`s that get stuck cleaning up
// get interrupted eventually
go func() {
var interruptCause InterruptCause var interruptCause InterruptCause
var repeatChannel <-chan time.Time
var repeatTicker *time.Ticker
for { for {
select { select {
case <-signalChannel: case <-signalChannel:
interruptCause = InterruptCauseSignal interruptCause = InterruptCauseSignal
case <-timeoutChannel:
interruptCause = InterruptCauseTimeout
repeatInterruptTimeout := timeout / time.Duration(TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT)
if repeatInterruptTimeout > TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION {
repeatInterruptTimeout = TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION
}
timeoutTimer.Stop()
repeatTicker = time.NewTicker(repeatInterruptTimeout)
repeatChannel = repeatTicker.C
case <-abortChannel: case <-abortChannel:
interruptCause = InterruptCauseAbortByOtherProcess interruptCause = InterruptCauseAbortByOtherProcess
repeatTicker = time.NewTicker(ABORT_REPEAT_INTERRUPT_DURATION)
repeatChannel = repeatTicker.C
case <-repeatChannel:
//do nothing, just interrupt again using the same interruptCause
case <-handler.stop: case <-handler.stop:
if timeoutTimer != nil {
timeoutTimer.Stop()
}
if repeatTicker != nil {
repeatTicker.Stop()
}
signal.Stop(signalChannel) signal.Stop(signalChannel)
return return
} }
abortChannel = nil
handler.lock.Lock() handler.lock.Lock()
handler.interruptCause = interruptCause oldLevel := handler.level
if handler.interruptPlaceholderMessage != "" { handler.cause = interruptCause
fmt.Println(handler.interruptPlaceholderMessage) if handler.level == InterruptLevelUninterrupted {
handler.level = InterruptLevelCleanupAndReport
} else if handler.level == InterruptLevelCleanupAndReport {
handler.level = InterruptLevelReportOnly
} else if handler.level == InterruptLevelReportOnly {
handler.level = InterruptLevelBailOut
}
if handler.level != oldLevel {
close(handler.c)
handler.c = make(chan interface{})
} }
handler.interrupted = true
close(handler.c)
handler.c = make(chan interface{})
handler.lock.Unlock() handler.lock.Unlock()
} }
}() }(abortChannel)
} }
func (handler *InterruptHandler) Status() InterruptStatus { func (handler *InterruptHandler) Status() InterruptStatus {
handler.lock.Lock() handler.lock.Lock()
defer handler.lock.Unlock() status := InterruptStatus{
Level: handler.level,
return InterruptStatus{ Channel: handler.c,
Interrupted: handler.interrupted, Cause: handler.cause,
Channel: handler.c,
Cause: handler.interruptCause,
} }
} handler.lock.Unlock()
func (handler *InterruptHandler) SetInterruptPlaceholderMessage(message string) { if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
handler.lock.Lock() close(handler.requestAbortCheck)
defer handler.lock.Unlock() <-status.Channel
return handler.Status()
handler.interruptPlaceholderMessage = message }
}
return status
func (handler *InterruptHandler) ClearInterruptPlaceholderMessage() {
handler.lock.Lock()
defer handler.lock.Unlock()
handler.interruptPlaceholderMessage = ""
}
func (handler *InterruptHandler) InterruptMessage() (string, bool) {
handler.lock.Lock()
out := fmt.Sprintf("%s", handler.interruptCause.String())
defer handler.lock.Unlock()
return out, handler.interruptCause != InterruptCauseAbortByOtherProcess
} }

View file

@ -42,6 +42,8 @@ type Client interface {
PostSuiteWillBegin(report types.Report) error PostSuiteWillBegin(report types.Report) error
PostDidRun(report types.SpecReport) error PostDidRun(report types.SpecReport) error
PostSuiteDidEnd(report types.Report) error PostSuiteDidEnd(report types.Report) error
PostReportBeforeSuiteCompleted(state types.SpecState) error
BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error)
PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error
BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error)
BlockUntilNonprimaryProcsHaveFinished() error BlockUntilNonprimaryProcsHaveFinished() error

View file

@ -98,6 +98,19 @@ func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) er
return client.post("/progress-report", report) return client.post("/progress-report", report)
} }
// PostReportBeforeSuiteCompleted notifies the parallel-support server that the
// ReportBeforeSuite node has completed with the given state.
func (client *httpClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
	return client.post("/report-before-suite-completed", state)
}
// BlockUntilReportBeforeSuiteCompleted polls the server until the outcome of
// the ReportBeforeSuite node is available. If the server reports ErrorGone,
// the node is treated as failed (with no error) rather than surfacing the
// polling error.
func (client *httpClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
	var state types.SpecState
	if err := client.poll("/report-before-suite-state", &state); err != nil {
		if err == ErrorGone {
			return types.SpecStateFailed, nil
		}
		return state, err
	}
	return state, nil
}
func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{ beforeSuiteState := BeforeSuiteState{
State: state, State: state,

View file

@ -26,7 +26,7 @@ type httpServer struct {
handler *ServerHandler handler *ServerHandler
} }
//Create a new server, automatically selecting a port // Create a new server, automatically selecting a port
func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) { func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -38,7 +38,7 @@ func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer,
}, nil }, nil
} }
//Start the server. You don't need to `go s.Start()`, just `s.Start()` // Start the server. You don't need to `go s.Start()`, just `s.Start()`
func (server *httpServer) Start() { func (server *httpServer) Start() {
httpServer := &http.Server{} httpServer := &http.Server{}
mux := http.NewServeMux() mux := http.NewServeMux()
@ -52,6 +52,8 @@ func (server *httpServer) Start() {
mux.HandleFunc("/progress-report", server.emitProgressReport) mux.HandleFunc("/progress-report", server.emitProgressReport)
//synchronization endpoints //synchronization endpoints
mux.HandleFunc("/report-before-suite-completed", server.handleReportBeforeSuiteCompleted)
mux.HandleFunc("/report-before-suite-state", server.handleReportBeforeSuiteState)
mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted) mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted)
mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState) mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState)
mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished) mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished)
@ -63,12 +65,12 @@ func (server *httpServer) Start() {
go httpServer.Serve(server.listener) go httpServer.Serve(server.listener)
} }
//Stop the server // Stop the server
func (server *httpServer) Close() { func (server *httpServer) Close() {
server.listener.Close() server.listener.Close()
} }
//The address the server can be reached it. Pass this into the `ForwardingReporter`. // The address the server can be reached it. Pass this into the `ForwardingReporter`.
func (server *httpServer) Address() string { func (server *httpServer) Address() string {
return "http://" + server.listener.Addr().String() return "http://" + server.listener.Addr().String()
} }
@ -93,7 +95,7 @@ func (server *httpServer) RegisterAlive(node int, alive func() bool) {
// Streaming Endpoints // Streaming Endpoints
// //
//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` // The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool { func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool {
defer request.Body.Close() defer request.Body.Close()
if json.NewDecoder(request.Body).Decode(object) != nil { if json.NewDecoder(request.Body).Decode(object) != nil {
@ -164,6 +166,23 @@ func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request
server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer) server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer)
} }
// handleReportBeforeSuiteCompleted records the ReportBeforeSuite outcome
// posted to the server. Decode failures are reported to the client by decode
// itself, in which case nothing is recorded.
func (server *httpServer) handleReportBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
	var state types.SpecState
	if server.decode(writer, request, &state) {
		server.handleError(server.handler.ReportBeforeSuiteCompleted(state, voidReceiver), writer)
	}
}
// handleReportBeforeSuiteState serves the recorded ReportBeforeSuite outcome.
// When the handler returns an error, handleError writes the response and the
// state is not encoded.
func (server *httpServer) handleReportBeforeSuiteState(writer http.ResponseWriter, request *http.Request) {
	var state types.SpecState
	hadError := server.handleError(server.handler.ReportBeforeSuiteState(voidSender, &state), writer)
	if !hadError {
		json.NewEncoder(writer).Encode(state)
	}
}
func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) { func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
var beforeSuiteState BeforeSuiteState var beforeSuiteState BeforeSuiteState
if !server.decode(writer, request, &beforeSuiteState) { if !server.decode(writer, request, &beforeSuiteState) {

View file

@ -76,6 +76,19 @@ func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) err
return client.client.Call("Server.EmitProgressReport", report, voidReceiver) return client.client.Call("Server.EmitProgressReport", report, voidReceiver)
} }
// PostReportBeforeSuiteCompleted forwards the ReportBeforeSuite outcome to the
// parallel-support server over RPC.
func (client *rpcClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
	return client.client.Call("Server.ReportBeforeSuiteCompleted", state, voidReceiver)
}
// BlockUntilReportBeforeSuiteCompleted polls the RPC server until the outcome
// of the ReportBeforeSuite node is available. ErrorGone is translated into a
// failed state with no error, mirroring the HTTP client's behavior.
func (client *rpcClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
	var state types.SpecState
	if err := client.poll("Server.ReportBeforeSuiteState", &state); err != nil {
		if err == ErrorGone {
			return types.SpecStateFailed, nil
		}
		return state, err
	}
	return state, nil
}
func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{ beforeSuiteState := BeforeSuiteState{
State: state, State: state,

View file

@ -18,16 +18,17 @@ var voidSender Void
// It handles all the business logic to avoid duplication between the two servers // It handles all the business logic to avoid duplication between the two servers
type ServerHandler struct { type ServerHandler struct {
done chan interface{} done chan interface{}
outputDestination io.Writer outputDestination io.Writer
reporter reporters.Reporter reporter reporters.Reporter
alives []func() bool alives []func() bool
lock *sync.Mutex lock *sync.Mutex
beforeSuiteState BeforeSuiteState beforeSuiteState BeforeSuiteState
parallelTotal int reportBeforeSuiteState types.SpecState
counter int parallelTotal int
counterLock *sync.Mutex counter int
shouldAbort bool counterLock *sync.Mutex
shouldAbort bool
numSuiteDidBegins int numSuiteDidBegins int
numSuiteDidEnds int numSuiteDidEnds int
@ -37,11 +38,12 @@ type ServerHandler struct {
func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler { func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler {
return &ServerHandler{ return &ServerHandler{
reporter: reporter, reporter: reporter,
lock: &sync.Mutex{}, lock: &sync.Mutex{},
counterLock: &sync.Mutex{}, counterLock: &sync.Mutex{},
alives: make([]func() bool, parallelTotal), alives: make([]func() bool, parallelTotal),
beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid}, beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid},
parallelTotal: parallelTotal, parallelTotal: parallelTotal,
outputDestination: os.Stdout, outputDestination: os.Stdout,
done: make(chan interface{}), done: make(chan interface{}),
@ -140,6 +142,29 @@ func (handler *ServerHandler) haveNonprimaryProcsFinished() bool {
return true return true
} }
// ReportBeforeSuiteCompleted records the state reported for the
// ReportBeforeSuite node so that other processes can query it via
// ReportBeforeSuiteState. Guarded by handler.lock.
func (handler *ServerHandler) ReportBeforeSuiteCompleted(reportBeforeSuiteState types.SpecState, _ *Void) error {
	handler.lock.Lock()
	defer handler.lock.Unlock()
	handler.reportBeforeSuiteState = reportBeforeSuiteState
	return nil
}
// ReportBeforeSuiteState returns the recorded ReportBeforeSuite outcome.
// While the state is still unset (SpecStateInvalid) it returns ErrorEarly if
// process 1 is alive (callers should keep polling) or ErrorGone if process 1
// has exited without reporting.
func (handler *ServerHandler) ReportBeforeSuiteState(_ Void, reportBeforeSuiteState *types.SpecState) error {
	// Probe liveness before taking the lock so the lock is not held during the check.
	proc1IsAlive := handler.procIsAlive(1)
	handler.lock.Lock()
	defer handler.lock.Unlock()
	if handler.reportBeforeSuiteState == types.SpecStateInvalid {
		// Early-return instead of if/else: the happy path stays left-aligned.
		if proc1IsAlive {
			return ErrorEarly
		}
		return ErrorGone
	}
	*reportBeforeSuiteState = handler.reportBeforeSuiteState
	return nil
}
func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error { func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error {
handler.lock.Lock() handler.lock.Lock()
defer handler.lock.Unlock() defer handler.lock.Unlock()

View file

@ -12,6 +12,7 @@ import (
"io" "io"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/onsi/ginkgo/v2/formatter" "github.com/onsi/ginkgo/v2/formatter"
@ -23,13 +24,16 @@ type DefaultReporter struct {
writer io.Writer writer io.Writer
// managing the emission stream // managing the emission stream
lastChar string lastCharWasNewline bool
lastEmissionWasDelimiter bool lastEmissionWasDelimiter bool
// rendering // rendering
specDenoter string specDenoter string
retryDenoter string retryDenoter string
formatter formatter.Formatter formatter formatter.Formatter
runningInParallel bool
lock *sync.Mutex
} }
func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter { func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter {
@ -44,12 +48,13 @@ func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultRep
conf: conf, conf: conf,
writer: writer, writer: writer,
lastChar: "\n", lastCharWasNewline: true,
lastEmissionWasDelimiter: false, lastEmissionWasDelimiter: false,
specDenoter: "•", specDenoter: "•",
retryDenoter: "↺", retryDenoter: "↺",
formatter: formatter.NewWithNoColorBool(conf.NoColor), formatter: formatter.NewWithNoColorBool(conf.NoColor),
lock: &sync.Mutex{},
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
reporter.specDenoter = "+" reporter.specDenoter = "+"
@ -97,166 +102,10 @@ func (r *DefaultReporter) SuiteWillBegin(report types.Report) {
} }
} }
func (r *DefaultReporter) WillRun(report types.SpecReport) {
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) {
return
}
r.emitDelimiter()
indentation := uint(0)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText))
} else {
if len(report.ContainerHierarchyTexts) > 0 {
r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " "))
indentation = 1
}
line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText)
labels := report.Labels()
if len(labels) > 0 {
line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", "))
}
r.emitBlock(line)
}
r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation))
}
// DidRun reports a completed spec. Depending on verbosity and outcome it
// either streams a single-character denoter (the "dot" output) or emits a
// delimited block containing the spec's header, code location, captured
// stdout/stderr, GinkgoWriter output, report entries, and failure details.
func (r *DefaultReporter) DidRun(report types.SpecReport) {
	v := r.conf.Verbosity()
	var header, highlightColor string
	// stream == true means: emit only the denoter, with no trailing block
	includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter
	succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct)

	hasGW := report.CapturedGinkgoWriterOutput != ""
	hasStd := report.CapturedStdOutErr != ""
	// entries are emittable when always-visible, or failure/verbose-visible
	// and we actually failed or are running at >= verbose
	hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose)))

	// suite-level nodes (e.g. [BeforeSuite]) are denoted by their node type
	if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
		denoter = fmt.Sprintf("[%s]", report.LeafNodeType)
	}

	switch report.State {
	case types.SpecStatePassed:
		highlightColor, succinctLocationBlock = "{{green}}", v.LT(types.VerbosityLevelVerbose)
		emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW
		if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
			if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports {
				header = fmt.Sprintf("%s PASSED", denoter)
			} else {
				// a quiet passing suite node prints nothing at all
				return
			}
		} else {
			header, stream = denoter, true
			// flaky and slow passing specs get a full header rather than a bare denoter
			if report.NumAttempts > 1 {
				header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false
			}
			if report.RunTime > r.conf.SlowSpecThreshold {
				header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false
			}
		}
		// any captured content forces the full block form
		if hasStd || emitGinkgoWriterOutput || hasEmittableReports {
			stream = false
		}
	case types.SpecStatePending:
		highlightColor = "{{yellow}}"
		includeRuntime, emitGinkgoWriterOutput = false, false
		if v.Is(types.VerbosityLevelSuccinct) {
			header, stream = "P", true
		} else {
			header, succinctLocationBlock = "P [PENDING]", v.LT(types.VerbosityLevelVeryVerbose)
		}
	case types.SpecStateSkipped:
		highlightColor = "{{cyan}}"
		// a skip with a message (e.g. via Skip("...")) is reported in full
		if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) {
			header = "S [SKIPPED]"
		} else {
			header, stream = "S", true
		}
	case types.SpecStateFailed:
		highlightColor, header = "{{red}}", fmt.Sprintf("%s [FAILED]", denoter)
	case types.SpecStatePanicked:
		highlightColor, header = "{{magenta}}", fmt.Sprintf("%s! [PANICKED]", denoter)
	case types.SpecStateInterrupted:
		highlightColor, header = "{{orange}}", fmt.Sprintf("%s! [INTERRUPTED]", denoter)
	case types.SpecStateAborted:
		highlightColor, header = "{{coral}}", fmt.Sprintf("%s! [ABORTED]", denoter)
	}

	// Emit stream and return
	if stream {
		r.emit(r.f(highlightColor + header + "{{/}}"))
		return
	}

	// Emit header
	r.emitDelimiter()
	if includeRuntime {
		header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
	}
	r.emitBlock(r.f(highlightColor + header + "{{/}}"))

	// Emit Code Location Block
	r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false))

	//Emit Stdout/Stderr Output
	if hasStd {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}"))
		r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr))
		r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}"))
	}

	//Emit Captured GinkgoWriter Output
	if emitGinkgoWriterOutput && hasGW {
		r.emitBlock("\n")
		r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0)
	}

	if hasEmittableReports {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}"))
		// widen visibility to failure/verbose entries when appropriate
		reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways)
		if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) {
			reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose)
		}
		for _, entry := range reportEntries {
			r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT)))
			if representation := entry.StringRepresentation(); representation != "" {
				r.emitBlock(r.fi(3, representation))
			}
		}
		r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}"))
	}

	// Emit Failure Message
	if !report.Failure.IsZero() {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.Message))
		r.emitBlock(r.fi(1, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.Location))
		if report.Failure.ForwardedPanic != "" {
			r.emitBlock("\n")
			r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.ForwardedPanic))
		}

		// panics always include the full trace, even without -trace
		if r.conf.FullTrace || report.Failure.ForwardedPanic != "" {
			r.emitBlock("\n")
			r.emitBlock(r.fi(1, highlightColor+"Full Stack Trace{{/}}"))
			r.emitBlock(r.fi(2, "%s", report.Failure.Location.FullStackTrace))
		}

		if !report.Failure.ProgressReport.IsZero() {
			r.emitBlock("\n")
			r.emitProgressReport(1, false, report.Failure.ProgressReport)
		}
	}

	r.emitDelimiter()
}
func (r *DefaultReporter) SuiteDidEnd(report types.Report) { func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
failures := report.SpecReports.WithState(types.SpecStateFailureStates) failures := report.SpecReports.WithState(types.SpecStateFailureStates)
if len(failures) > 0 { if len(failures) > 0 {
r.emitBlock("\n\n") r.emitBlock("\n")
if len(failures) > 1 { if len(failures) > 1 {
r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures))) r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures)))
} else { } else {
@ -269,10 +118,12 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
highlightColor, heading = "{{magenta}}", "[PANICKED!]" highlightColor, heading = "{{magenta}}", "[PANICKED!]"
case types.SpecStateAborted: case types.SpecStateAborted:
highlightColor, heading = "{{coral}}", "[ABORTED]" highlightColor, heading = "{{coral}}", "[ABORTED]"
case types.SpecStateTimedout:
highlightColor, heading = "{{orange}}", "[TIMEDOUT]"
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
highlightColor, heading = "{{orange}}", "[INTERRUPTED]" highlightColor, heading = "{{orange}}", "[INTERRUPTED]"
} }
locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true) locationBlock := r.codeLocationBlock(specReport, highlightColor, false, true)
r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock)) r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock))
} }
} }
@ -313,28 +164,294 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
if specs.CountOfFlakedSpecs() > 0 { if specs.CountOfFlakedSpecs() > 0 {
r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs())) r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs()))
} }
if specs.CountOfRepeatedSpecs() > 0 {
r.emit(r.f("{{light-yellow}}{{bold}}%d Repeated{{/}} | ", specs.CountOfRepeatedSpecs()))
}
r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending))) r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending)))
r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped))) r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped)))
} }
} }
// WillRun announces an upcoming spec by printing its code-location block.
// Nothing is printed below verbose verbosity, for pending/skipped specs, or
// when running in parallel (interleaved announcements would be unreadable).
func (r *DefaultReporter) WillRun(report types.SpecReport) {
	v := r.conf.Verbosity()
	if v.LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) || report.RunningInParallel {
		return
	}

	r.emitDelimiter(0)
	// codeLocationBlock returns fully rendered text; do not route it back
	// through r.f as a format string — a spec name containing a '%' verb
	// would be mangled by the underlying Sprintf.
	r.emitBlock(r.codeLocationBlock(report, "{{/}}", v.Is(types.VerbosityLevelVeryVerbose), false))
}
// DidRun reports a completed spec. It decides between streaming a bare
// denoter and emitting a full delimited block with header, code location,
// captured stdout/stderr, report entries, the spec timeline, and failure
// details — based on verbosity, outcome, and whether the timeline was
// already streamed live (serial verbose runs).
func (r *DefaultReporter) DidRun(report types.SpecReport) {
	v := r.conf.Verbosity()
	inParallel := report.RunningInParallel

	// suite-level nodes (e.g. [BeforeSuite]) are denoted by their node type
	header := r.specDenoter
	if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
		header = fmt.Sprintf("[%s]", report.LeafNodeType)
	}
	highlightColor := r.highlightColorForState(report.State)

	// have we already been streaming the timeline?
	timelineHasBeenStreaming := v.GTE(types.VerbosityLevelVerbose) && !inParallel

	// should we show the timeline?
	var timeline types.Timeline
	showTimeline := !timelineHasBeenStreaming && (v.GTE(types.VerbosityLevelVerbose) || report.Failed())
	if showTimeline {
		timeline = report.Timeline().WithoutHiddenReportEntries()
		keepVeryVerboseSpecEvents := v.Is(types.VerbosityLevelVeryVerbose) ||
			(v.Is(types.VerbosityLevelVerbose) && r.conf.ShowNodeEvents) ||
			(report.Failed() && r.conf.ShowNodeEvents)
		if !keepVeryVerboseSpecEvents {
			timeline = timeline.WithoutVeryVerboseSpecEvents()
		}
		if len(timeline) == 0 && report.CapturedGinkgoWriterOutput == "" {
			// the timeline is completely empty - don't show it
			showTimeline = false
		}
		if v.LT(types.VerbosityLevelVeryVerbose) && report.CapturedGinkgoWriterOutput == "" && len(timeline) > 0 {
			//if we aren't -vv and the timeline only has a single failure, don't show it as it will appear at the end of the report
			failure, isFailure := timeline[0].(types.Failure)
			if isFailure && (len(timeline) == 1 || (len(timeline) == 2 && failure.AdditionalFailure != nil)) {
				showTimeline = false
			}
		}
	}

	// should we have a separate section for always-visible reports?
	showSeparateVisibilityAlwaysReportsSection := !timelineHasBeenStreaming && !showTimeline && report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways)

	// should we have a separate section for captured stdout/stderr
	showSeparateStdSection := inParallel && (report.CapturedStdOutErr != "")

	// given all that - do we have any actual content to show? or are we a single denoter in a stream?
	reportHasContent := v.Is(types.VerbosityLevelVeryVerbose) || showTimeline || showSeparateVisibilityAlwaysReportsSection || showSeparateStdSection || report.Failed() || (v.Is(types.VerbosityLevelVerbose) && !report.State.Is(types.SpecStateSkipped))

	// should we show a runtime?
	includeRuntime := !report.State.Is(types.SpecStateSkipped|types.SpecStatePending) || (report.State.Is(types.SpecStateSkipped) && report.Failure.Message != "")

	// should we show the codelocation block?
	showCodeLocation := !timelineHasBeenStreaming || !report.State.Is(types.SpecStatePassed)

	// finalize the header text per state
	switch report.State {
	case types.SpecStatePassed:
		if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) && !reportHasContent {
			return
		}
		if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
			header = fmt.Sprintf("%s PASSED", header)
		}
		if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
			header, reportHasContent = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), true
		}
	case types.SpecStatePending:
		header = "P"
		if v.GT(types.VerbosityLevelSuccinct) {
			header, reportHasContent = "P [PENDING]", true
		}
	case types.SpecStateSkipped:
		header = "S"
		if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && report.Failure.Message != "") {
			header, reportHasContent = "S [SKIPPED]", true
		}
	default:
		// failure states: e.g. "• [FAILED]", possibly tagged with the repetition number
		header = fmt.Sprintf("%s [%s]", header, r.humanReadableState(report.State))
		if report.MaxMustPassRepeatedly > 1 {
			header = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts)
		}
	}

	// If we have no content to show, just emit the header and return
	if !reportHasContent {
		r.emit(r.f(highlightColor + header + "{{/}}"))
		return
	}

	if includeRuntime {
		header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
	}

	// Emit header
	if !timelineHasBeenStreaming {
		r.emitDelimiter(0)
	}
	r.emitBlock(r.f(highlightColor + header + "{{/}}"))

	if showCodeLocation {
		r.emitBlock(r.codeLocationBlock(report, highlightColor, v.Is(types.VerbosityLevelVeryVerbose), false))
	}

	//Emit Stdout/Stderr Output
	if showSeparateStdSection {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, "{{gray}}Captured StdOut/StdErr Output >>{{/}}"))
		r.emitBlock(r.fi(1, "%s", report.CapturedStdOutErr))
		r.emitBlock(r.fi(1, "{{gray}}<< Captured StdOut/StdErr Output{{/}}"))
	}

	if showSeparateVisibilityAlwaysReportsSection {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, "{{gray}}Report Entries >>{{/}}"))
		for _, entry := range report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) {
			r.emitReportEntry(1, entry)
		}
		r.emitBlock(r.fi(1, "{{gray}}<< Report Entries{{/}}"))
	}

	if showTimeline {
		r.emitBlock("\n")
		r.emitBlock(r.fi(1, "{{gray}}Timeline >>{{/}}"))
		r.emitTimeline(1, report, timeline)
		r.emitBlock(r.fi(1, "{{gray}}<< Timeline{{/}}"))
	}

	// Emit Failure Message
	// (at -vv the failure already appeared in full inside the timeline)
	if !report.Failure.IsZero() && !v.Is(types.VerbosityLevelVeryVerbose) {
		r.emitBlock("\n")
		r.emitFailure(1, report.State, report.Failure, true)
		if len(report.AdditionalFailures) > 0 {
			r.emitBlock(r.fi(1, "\nThere were {{bold}}{{red}}additional failures{{/}} detected. To view them in detail run {{bold}}ginkgo -vv{{/}}"))
		}
	}

	r.emitDelimiter(0)
}
// highlightColorForState maps a spec state to the formatter color token used
// when rendering that state's headers and failure output. Unknown states
// fall back to gray.
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
	switch state {
	case types.SpecStatePassed:
		return "{{green}}"
	case types.SpecStatePending:
		return "{{yellow}}"
	case types.SpecStateSkipped:
		return "{{cyan}}"
	case types.SpecStateFailed:
		return "{{red}}"
	case types.SpecStateTimedout, types.SpecStateInterrupted:
		return "{{orange}}"
	case types.SpecStatePanicked:
		return "{{magenta}}"
	case types.SpecStateAborted:
		return "{{coral}}"
	}
	return "{{gray}}"
}
// humanReadableState renders a spec state as an upper-cased label, e.g. "FAILED".
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
	label := state.String()
	return strings.ToUpper(label)
}
// emitTimeline renders a spec's timeline: the captured GinkgoWriter output
// interleaved, in chronological order, with the timeline entries (failures,
// report entries, progress reports, and spec events). Each entry's
// TimelineLocation.Offset indexes into the GinkgoWriter output; cursor
// tracks how much of that output has been flushed so far.
func (r *DefaultReporter) emitTimeline(indent uint, report types.SpecReport, timeline types.Timeline) {
	isVeryVerbose := r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose)
	gw := report.CapturedGinkgoWriterOutput
	cursor := 0
	for _, entry := range timeline {
		tl := entry.GetTimelineLocation()
		if tl.Offset < len(gw) {
			// flush GinkgoWriter output up to this entry's position
			r.emit(r.fi(indent, "%s", gw[cursor:tl.Offset]))
			cursor = tl.Offset
		} else if cursor < len(gw) {
			// entry lies past the end of the captured output - flush the rest
			r.emit(r.fi(indent, "%s", gw[cursor:]))
			cursor = len(gw)
		}
		switch x := entry.(type) {
		case types.Failure:
			// failures render in full at -vv, as a one-liner otherwise
			if isVeryVerbose {
				r.emitFailure(indent, report.State, x, false)
			} else {
				r.emitShortFailure(indent, report.State, x)
			}
		case types.AdditionalFailure:
			if isVeryVerbose {
				r.emitFailure(indent, x.State, x.Failure, true)
			} else {
				r.emitShortFailure(indent, x.State, x.Failure)
			}
		case types.ReportEntry:
			r.emitReportEntry(indent, x)
		case types.ProgressReport:
			r.emitProgressReport(indent, false, x)
		case types.SpecEvent:
			if isVeryVerbose || !x.IsOnlyVisibleAtVeryVerbose() || r.conf.ShowNodeEvents {
				r.emitSpecEvent(indent, x, isVeryVerbose)
			}
		}
	}
	// flush any GinkgoWriter output that follows the final entry
	if cursor < len(gw) {
		r.emit(r.fi(indent, "%s", gw[cursor:]))
	}
}
// EmitFailure streams a failure as it occurs: a one-line summary at -v,
// the full failure block at -vv, and nothing at lower verbosity.
func (r *DefaultReporter) EmitFailure(state types.SpecState, failure types.Failure) {
	switch v := r.conf.Verbosity(); {
	case v.Is(types.VerbosityLevelVerbose):
		r.emitShortFailure(1, state, failure)
	case v.Is(types.VerbosityLevelVeryVerbose):
		r.emitFailure(1, state, failure, true)
	}
}
// emitShortFailure prints a one-line failure summary: the state, the failing
// node type, its code location, and the timeline timestamp.
func (r *DefaultReporter) emitShortFailure(indent uint, state types.SpecState, failure types.Failure) {
	format := r.highlightColorForState(state) + "[%s]{{/}} in [%s] - %s {{gray}}@ %s{{/}}"
	line := r.fi(indent, format,
		r.humanReadableState(state),
		failure.FailureNodeType,
		failure.Location,
		failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT),
	)
	r.emitBlock(line)
}
// emitFailure renders a full failure block: state and message, the failing
// node and location (with timeline timestamp), any forwarded panic, the full
// stack trace (when configured, or always after a panic), an attached
// progress report, and - optionally - any additional failure chained onto
// this one, rendered recursively.
func (r *DefaultReporter) emitFailure(indent uint, state types.SpecState, failure types.Failure, includeAdditionalFailure bool) {
	highlightColor := r.highlightColorForState(state)
	r.emitBlock(r.fi(indent, highlightColor+"[%s] %s{{/}}", r.humanReadableState(state), failure.Message))
	r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}} {{gray}}@ %s{{/}}\n", failure.FailureNodeType, failure.Location, failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
	if failure.ForwardedPanic != "" {
		r.emitBlock("\n")
		r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
	}

	// panics always include the full trace, even without -trace
	if r.conf.FullTrace || failure.ForwardedPanic != "" {
		r.emitBlock("\n")
		r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
		r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
	}

	if !failure.ProgressReport.IsZero() {
		r.emitBlock("\n")
		r.emitProgressReport(indent, false, failure.ProgressReport)
	}

	// recurse into the chained failure at the same indent level
	if failure.AdditionalFailure != nil && includeAdditionalFailure {
		r.emitBlock("\n")
		r.emitFailure(indent, failure.AdditionalFailure.State, failure.AdditionalFailure.Failure, true)
	}
}
func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) { func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) {
r.emitDelimiter() r.emitDelimiter(1)
if report.RunningInParallel { if report.RunningInParallel {
r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess)) r.emit(r.fi(1, "{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
} }
r.emitProgressReport(0, true, report) shouldEmitGW := report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)
r.emitDelimiter() r.emitProgressReport(1, shouldEmitGW, report)
r.emitDelimiter(1)
} }
func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) { func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) {
if report.Message != "" {
r.emitBlock(r.fi(indent, report.Message+"\n"))
indent += 1
}
if report.LeafNodeText != "" { if report.LeafNodeText != "" {
subjectIndent := indent
if len(report.ContainerHierarchyTexts) > 0 { if len(report.ContainerHierarchyTexts) > 0 {
r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " "))) r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " ")))
r.emit(" ") r.emit(" ")
subjectIndent = 0
} }
r.emit(r.f("{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond))) r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time().Sub(report.SpecStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation))
indent += 1 indent += 1
} }
@ -344,12 +461,12 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText)) r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText))
} }
r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond))) r.emit(r.f(" (Node Runtime: %s)\n", report.Time().Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation))
indent += 1 indent += 1
} }
if report.CurrentStepText != "" { if report.CurrentStepText != "" {
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond))) r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time().Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation))
indent += 1 indent += 1
} }
@ -358,9 +475,19 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
indent -= 1 indent -= 1
} }
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) { if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" {
r.emit("\n") r.emit("\n")
r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10) r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
limit, lines := 10, strings.Split(report.CapturedGinkgoWriterOutput, "\n")
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", report.CapturedGinkgoWriterOutput))
} else {
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
for _, line := range lines[len(lines)-limit-1:] {
r.emitBlock(r.fi(indent+1, "%s", line))
}
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
} }
if !report.SpecGoroutine().IsZero() { if !report.SpecGoroutine().IsZero() {
@ -369,6 +496,18 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emitGoroutines(indent, report.SpecGoroutine()) r.emitGoroutines(indent, report.SpecGoroutine())
} }
if len(report.AdditionalReports) > 0 {
r.emit("\n")
r.emitBlock(r.fi(indent, "{{gray}}Begin Additional Progress Reports >>{{/}}"))
for i, additionalReport := range report.AdditionalReports {
r.emit(r.fi(indent+1, additionalReport))
if i < len(report.AdditionalReports)-1 {
r.emitBlock(r.fi(indent+1, "{{gray}}%s{{/}}", strings.Repeat("-", 10)))
}
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Additional Progress Reports{{/}}"))
}
highlightedGoroutines := report.HighlightedGoroutines() highlightedGoroutines := report.HighlightedGoroutines()
if len(highlightedGoroutines) > 0 { if len(highlightedGoroutines) > 0 {
r.emit("\n") r.emit("\n")
@ -384,22 +523,48 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
} }
} }
func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) { func (r *DefaultReporter) EmitReportEntry(entry types.ReportEntry) {
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}")) if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || entry.Visibility == types.ReportEntryVisibilityNever {
if limit == 0 { return
r.emitBlock(r.fi(indent+1, "%s", output)) }
} else { r.emitReportEntry(1, entry)
lines := strings.Split(output, "\n") }
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", output)) func (r *DefaultReporter) emitReportEntry(indent uint, entry types.ReportEntry) {
} else { r.emitBlock(r.fi(indent, "{{bold}}"+entry.Name+"{{gray}} "+fmt.Sprintf("- %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))))
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}")) if representation := entry.StringRepresentation(); representation != "" {
for _, line := range lines[len(lines)-limit-1:] { r.emitBlock(r.fi(indent+1, representation))
r.emitBlock(r.fi(indent+1, "%s", line)) }
} }
}
// EmitSpecEvent streams a spec event (node enter/exit, By steps, retries) as
// it occurs, subject to verbosity and the ShowNodeEvents setting. Locations
// are included only at -vv.
func (r *DefaultReporter) EmitSpecEvent(event types.SpecEvent) {
	v := r.conf.Verbosity()
	veryVerbose := v.Is(types.VerbosityLevelVeryVerbose)
	verboseAndVisible := v.Is(types.VerbosityLevelVerbose) && (r.conf.ShowNodeEvents || !event.IsOnlyVisibleAtVeryVerbose())
	if veryVerbose || verboseAndVisible {
		r.emitSpecEvent(1, event, veryVerbose)
	}
}
func (r *DefaultReporter) emitSpecEvent(indent uint, event types.SpecEvent, includeLocation bool) {
location := ""
if includeLocation {
location = fmt.Sprintf("- %s ", event.CodeLocation.String())
}
switch event.SpecEventType {
case types.SpecEventInvalid:
return
case types.SpecEventByStart:
r.emitBlock(r.fi(indent, "{{bold}}STEP:{{/}} %s {{gray}}%s@ %s{{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventByEnd:
r.emitBlock(r.fi(indent, "{{bold}}END STEP:{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventNodeStart:
r.emitBlock(r.fi(indent, "> Enter {{bold}}[%s]{{/}} %s {{gray}}%s@ %s{{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventNodeEnd:
r.emitBlock(r.fi(indent, "< Exit {{bold}}[%s]{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventSpecRepeat:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{green}}Passed{{/}}{{bold}}. Repeating %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventSpecRetry:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{red}}Failed{{/}}{{bold}}. Retrying %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
} }
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
} }
func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) { func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) {
@ -457,31 +622,37 @@ func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) {
/* Emitting to the writer */ /* Emitting to the writer */
func (r *DefaultReporter) emit(s string) { func (r *DefaultReporter) emit(s string) {
if len(s) > 0 { r._emit(s, false, false)
r.lastChar = s[len(s)-1:]
r.lastEmissionWasDelimiter = false
r.writer.Write([]byte(s))
}
} }
func (r *DefaultReporter) emitBlock(s string) { func (r *DefaultReporter) emitBlock(s string) {
if len(s) > 0 { r._emit(s, true, false)
if r.lastChar != "\n" {
r.emit("\n")
}
r.emit(s)
if r.lastChar != "\n" {
r.emit("\n")
}
}
} }
func (r *DefaultReporter) emitDelimiter() { func (r *DefaultReporter) emitDelimiter(indent uint) {
if r.lastEmissionWasDelimiter { r._emit(r.fi(indent, "{{gray}}%s{{/}}", strings.Repeat("-", 30)), true, true)
}
// a bit ugly - but we're trying to minimize locking on this hot codepath
func (r *DefaultReporter) _emit(s string, block bool, isDelimiter bool) {
if len(s) == 0 {
return return
} }
r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30))) r.lock.Lock()
r.lastEmissionWasDelimiter = true defer r.lock.Unlock()
if isDelimiter && r.lastEmissionWasDelimiter {
return
}
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
}
r.lastCharWasNewline = (s[len(s)-1:] == "\n")
r.writer.Write([]byte(s))
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
r.lastCharWasNewline = true
}
r.lastEmissionWasDelimiter = isDelimiter
} }
/* Rendering text */ /* Rendering text */
@ -497,13 +668,14 @@ func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string {
return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"}) return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"})
} }
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string { func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, veryVerbose bool, usePreciseFailureLocation bool) string {
texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{} texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{}
texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...) texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText)) texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText))
} else { } else {
texts = append(texts, report.LeafNodeText) texts = append(texts, r.f(report.LeafNodeText))
} }
labels = append(labels, report.LeafNodeLabels) labels = append(labels, report.LeafNodeLabels)
locations = append(locations, report.LeafNodeLocation) locations = append(locations, report.LeafNodeLocation)
@ -513,24 +685,58 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
failureLocation = report.Failure.Location failureLocation = report.Failure.Location
} }
highlightIndex := -1
switch report.Failure.FailureNodeContext { switch report.Failure.FailureNodeContext {
case types.FailureNodeAtTopLevel: case types.FailureNodeAtTopLevel:
texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...) texts = append([]string{fmt.Sprintf("TOP-LEVEL [%s]", report.Failure.FailureNodeType)}, texts...)
locations = append([]types.CodeLocation{failureLocation}, locations...) locations = append([]types.CodeLocation{failureLocation}, locations...)
labels = append([][]string{{}}, labels...) labels = append([][]string{{}}, labels...)
highlightIndex = 0
case types.FailureNodeInContainer: case types.FailureNodeInContainer:
i := report.Failure.FailureNodeContainerIndex i := report.Failure.FailureNodeContainerIndex
texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType) texts[i] = fmt.Sprintf("%s [%s]", texts[i], report.Failure.FailureNodeType)
locations[i] = failureLocation locations[i] = failureLocation
highlightIndex = i
case types.FailureNodeIsLeafNode: case types.FailureNodeIsLeafNode:
i := len(texts) - 1 i := len(texts) - 1
texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText) texts[i] = fmt.Sprintf("[%s] %s", report.LeafNodeType, report.LeafNodeText)
locations[i] = failureLocation locations[i] = failureLocation
highlightIndex = i
default:
		// there is no failure, so we highlight the leaf node
highlightIndex = len(texts) - 1
} }
out := "" out := ""
if succinct { if veryVerbose {
out += r.f("%s", r.cycleJoin(texts, " ")) for i := range texts {
if i == highlightIndex {
out += r.fi(uint(i), highlightColor+"{{bold}}%s{{/}}", texts[i])
} else {
out += r.fi(uint(i), "%s", texts[i])
}
if len(labels[i]) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
}
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
}
} else {
for i := range texts {
style := "{{/}}"
if i%2 == 1 {
style = "{{gray}}"
}
if i == highlightIndex {
style = highlightColor + "{{bold}}"
}
out += r.f(style+"%s", texts[i])
if i < len(texts)-1 {
out += " "
} else {
out += r.f("{{/}}")
}
}
flattenedLabels := report.Labels() flattenedLabels := report.Labels()
if len(flattenedLabels) > 0 { if len(flattenedLabels) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", ")) out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", "))
@ -539,17 +745,15 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
if usePreciseFailureLocation { if usePreciseFailureLocation {
out += r.f("{{gray}}%s{{/}}", failureLocation) out += r.f("{{gray}}%s{{/}}", failureLocation)
} else { } else {
out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1]) leafLocation := locations[len(locations)-1]
} if (report.Failure.FailureNodeLocation != types.CodeLocation{}) && (report.Failure.FailureNodeLocation != leafLocation) {
} else { out += r.fi(1, highlightColor+"[%s]{{/}} {{gray}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.FailureNodeLocation)
for i := range texts { out += r.fi(1, "{{gray}}[%s] %s{{/}}", report.LeafNodeType, leafLocation)
out += r.fi(uint(i), "%s", texts[i]) } else {
if len(labels[i]) > 0 { out += r.f("{{gray}}%s{{/}}", leafLocation)
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
} }
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
} }
} }
return out return out
} }

View file

@ -35,7 +35,7 @@ func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Repor
FailOnPending: report.SuiteConfig.FailOnPending, FailOnPending: report.SuiteConfig.FailOnPending,
FailFast: report.SuiteConfig.FailFast, FailFast: report.SuiteConfig.FailFast,
FlakeAttempts: report.SuiteConfig.FlakeAttempts, FlakeAttempts: report.SuiteConfig.FlakeAttempts,
EmitSpecProgress: report.SuiteConfig.EmitSpecProgress, EmitSpecProgress: false,
DryRun: report.SuiteConfig.DryRun, DryRun: report.SuiteConfig.DryRun,
ParallelNode: report.SuiteConfig.ParallelProcess, ParallelNode: report.SuiteConfig.ParallelProcess,
ParallelTotal: report.SuiteConfig.ParallelTotal, ParallelTotal: report.SuiteConfig.ParallelTotal,

View file

@ -15,12 +15,32 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"time"
"github.com/onsi/ginkgo/v2/config" "github.com/onsi/ginkgo/v2/config"
"github.com/onsi/ginkgo/v2/types" "github.com/onsi/ginkgo/v2/types"
) )
type JunitReportConfig struct {
// Spec States for which no timeline should be emitted for system-err
// set this to types.SpecStatePassed|types.SpecStateSkipped|types.SpecStatePending to only match failing specs
OmitTimelinesForSpecState types.SpecState
// Enable OmitFailureMessageAttr to prevent failure messages appearing in the "message" attribute of the Failure and Error tags
OmitFailureMessageAttr bool
//Enable OmitCapturedStdOutErr to prevent captured stdout/stderr appearing in system-out
OmitCapturedStdOutErr bool
// Enable OmitSpecLabels to prevent labels from appearing in the spec name
OmitSpecLabels bool
// Enable OmitLeafNodeType to prevent the spec leaf node type from appearing in the spec name
OmitLeafNodeType bool
// Enable OmitSuiteSetupNodes to prevent the creation of testcase entries for setup nodes
OmitSuiteSetupNodes bool
}
type JUnitTestSuites struct { type JUnitTestSuites struct {
XMLName xml.Name `xml:"testsuites"` XMLName xml.Name `xml:"testsuites"`
// Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite) // Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite)
@ -128,6 +148,10 @@ type JUnitFailure struct {
} }
func GenerateJUnitReport(report types.Report, dst string) error { func GenerateJUnitReport(report types.Report, dst string) error {
return GenerateJUnitReportWithConfig(report, dst, JunitReportConfig{})
}
func GenerateJUnitReportWithConfig(report types.Report, dst string, config JunitReportConfig) error {
suite := JUnitTestSuite{ suite := JUnitTestSuite{
Name: report.SuiteDescription, Name: report.SuiteDescription,
Package: report.SuitePath, Package: report.SuitePath,
@ -149,7 +173,6 @@ func GenerateJUnitReport(report types.Report, dst string) error {
{"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)}, {"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)},
{"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)}, {"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)},
{"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)}, {"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)},
{"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)},
{"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)}, {"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)},
{"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)}, {"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)},
{"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode}, {"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode},
@ -157,22 +180,33 @@ func GenerateJUnitReport(report types.Report, dst string) error {
}, },
} }
for _, spec := range report.SpecReports { for _, spec := range report.SpecReports {
if config.OmitSuiteSetupNodes && spec.LeafNodeType != types.NodeTypeIt {
continue
}
name := fmt.Sprintf("[%s]", spec.LeafNodeType) name := fmt.Sprintf("[%s]", spec.LeafNodeType)
if config.OmitLeafNodeType {
name = ""
}
if spec.FullText() != "" { if spec.FullText() != "" {
name = name + " " + spec.FullText() name = name + " " + spec.FullText()
} }
labels := spec.Labels() labels := spec.Labels()
if len(labels) > 0 { if len(labels) > 0 && !config.OmitSpecLabels {
name = name + " [" + strings.Join(labels, ", ") + "]" name = name + " [" + strings.Join(labels, ", ") + "]"
} }
name = strings.TrimSpace(name)
test := JUnitTestCase{ test := JUnitTestCase{
Name: name, Name: name,
Classname: report.SuiteDescription, Classname: report.SuiteDescription,
Status: spec.State.String(), Status: spec.State.String(),
Time: spec.RunTime.Seconds(), Time: spec.RunTime.Seconds(),
SystemOut: systemOutForUnstructuredReporters(spec), }
SystemErr: systemErrForUnstructuredReporters(spec), if !spec.State.Is(config.OmitTimelinesForSpecState) {
test.SystemErr = systemErrForUnstructuredReporters(spec)
}
if !config.OmitCapturedStdOutErr {
test.SystemOut = systemOutForUnstructuredReporters(spec)
} }
suite.Tests += 1 suite.Tests += 1
@ -191,28 +225,50 @@ func GenerateJUnitReport(report types.Report, dst string) error {
test.Failure = &JUnitFailure{ test.Failure = &JUnitFailure{
Message: spec.Failure.Message, Message: spec.Failure.Message,
Type: "failed", Type: "failed",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
}
suite.Failures += 1
case types.SpecStateTimedout:
test.Failure = &JUnitFailure{
Message: spec.Failure.Message,
Type: "timedout",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
} }
suite.Failures += 1 suite.Failures += 1
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
test.Error = &JUnitError{ test.Error = &JUnitError{
Message: "interrupted", Message: spec.Failure.Message,
Type: "interrupted", Type: "interrupted",
Description: interruptDescriptionForUnstructuredReporters(spec.Failure), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
case types.SpecStateAborted: case types.SpecStateAborted:
test.Failure = &JUnitFailure{ test.Failure = &JUnitFailure{
Message: spec.Failure.Message, Message: spec.Failure.Message,
Type: "aborted", Type: "aborted",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
case types.SpecStatePanicked: case types.SpecStatePanicked:
test.Error = &JUnitError{ test.Error = &JUnitError{
Message: spec.Failure.ForwardedPanic, Message: spec.Failure.ForwardedPanic,
Type: "panicked", Type: "panicked",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
} }
@ -278,52 +334,27 @@ func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error)
return messages, f.Close() return messages, f.Close()
} }
func interruptDescriptionForUnstructuredReporters(failure types.Failure) string { func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string {
out := &strings.Builder{} out := &strings.Builder{}
out.WriteString(failure.Message + "\n") NewDefaultReporter(types.ReporterConfig{NoColor: true, VeryVerbose: true}, out).emitFailure(0, spec.State, spec.Failure, true)
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(failure.ProgressReport) if len(spec.AdditionalFailures) > 0 {
out.WriteString("\nThere were additional failures detected after the initial failure. These are visible in the timeline\n")
}
return out.String() return out.String()
} }
func systemErrForUnstructuredReporters(spec types.SpecReport) string { func systemErrForUnstructuredReporters(spec types.SpecReport) string {
return RenderTimeline(spec, true)
}
func RenderTimeline(spec types.SpecReport, noColor bool) string {
out := &strings.Builder{} out := &strings.Builder{}
gw := spec.CapturedGinkgoWriterOutput NewDefaultReporter(types.ReporterConfig{NoColor: noColor, VeryVerbose: true}, out).emitTimeline(0, spec, spec.Timeline())
cursor := 0
for _, pr := range spec.ProgressReports {
if cursor < pr.GinkgoWriterOffset {
if pr.GinkgoWriterOffset < len(gw) {
out.WriteString(gw[cursor:pr.GinkgoWriterOffset])
cursor = pr.GinkgoWriterOffset
} else if cursor < len(gw) {
out.WriteString(gw[cursor:])
cursor = len(gw)
}
}
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr)
}
if cursor < len(gw) {
out.WriteString(gw[cursor:])
}
return out.String() return out.String()
} }
func systemOutForUnstructuredReporters(spec types.SpecReport) string { func systemOutForUnstructuredReporters(spec types.SpecReport) string {
systemOut := spec.CapturedStdOutErr return spec.CapturedStdOutErr
if len(spec.ReportEntries) > 0 {
systemOut += "\nReport Entries:\n"
for i, entry := range spec.ReportEntries {
systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano))
if representation := entry.StringRepresentation(); representation != "" {
systemOut += representation + "\n"
}
if i+1 < len(spec.ReportEntries) {
systemOut += "--\n"
}
}
}
return systemOut
} }
// Deprecated JUnitReporter (so folks can still compile their suites) // Deprecated JUnitReporter (so folks can still compile their suites)

View file

@ -9,13 +9,21 @@ type Reporter interface {
WillRun(report types.SpecReport) WillRun(report types.SpecReport)
DidRun(report types.SpecReport) DidRun(report types.SpecReport)
SuiteDidEnd(report types.Report) SuiteDidEnd(report types.Report)
//Timeline emission
EmitFailure(state types.SpecState, failure types.Failure)
EmitProgressReport(progressReport types.ProgressReport) EmitProgressReport(progressReport types.ProgressReport)
EmitReportEntry(entry types.ReportEntry)
EmitSpecEvent(event types.SpecEvent)
} }
type NoopReporter struct{} type NoopReporter struct{}
func (n NoopReporter) SuiteWillBegin(report types.Report) {} func (n NoopReporter) SuiteWillBegin(report types.Report) {}
func (n NoopReporter) WillRun(report types.SpecReport) {} func (n NoopReporter) WillRun(report types.SpecReport) {}
func (n NoopReporter) DidRun(report types.SpecReport) {} func (n NoopReporter) DidRun(report types.SpecReport) {}
func (n NoopReporter) SuiteDidEnd(report types.Report) {} func (n NoopReporter) SuiteDidEnd(report types.Report) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {} func (n NoopReporter) EmitFailure(state types.SpecState, failure types.Failure) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {}
func (n NoopReporter) EmitReportEntry(entry types.ReportEntry) {}
func (n NoopReporter) EmitSpecEvent(event types.SpecEvent) {}

View file

@ -60,15 +60,19 @@ func GenerateTeamcityReport(report types.Report, dst string) error {
} }
fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", name, tcEscape(message)) fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", name, tcEscape(message))
case types.SpecStateFailed: case types.SpecStateFailed:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStatePanicked: case types.SpecStatePanicked:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details))
case types.SpecStateTimedout:
details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='timedout - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted' details='%s']\n", name, tcEscape(interruptDescriptionForUnstructuredReporters(spec.Failure))) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStateAborted: case types.SpecStateAborted:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
} }

View file

@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
) )
type CodeLocation struct { type CodeLocation struct {
@ -38,6 +39,73 @@ func (codeLocation CodeLocation) ContentsOfLine() string {
return lines[codeLocation.LineNumber-1] return lines[codeLocation.LineNumber-1]
} }
type codeLocationLocator struct {
pcs map[uintptr]bool
helpers map[string]bool
lock *sync.Mutex
}
func (c *codeLocationLocator) addHelper(pc uintptr) {
c.lock.Lock()
defer c.lock.Unlock()
if c.pcs[pc] {
return
}
c.lock.Unlock()
f := runtime.FuncForPC(pc)
c.lock.Lock()
if f == nil {
return
}
c.helpers[f.Name()] = true
c.pcs[pc] = true
}
func (c *codeLocationLocator) hasHelper(name string) bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.helpers[name]
}
func (c *codeLocationLocator) getCodeLocation(skip int) CodeLocation {
pc := make([]uintptr, 40)
n := runtime.Callers(skip+2, pc)
if n == 0 {
return CodeLocation{}
}
pc = pc[:n]
frames := runtime.CallersFrames(pc)
for {
frame, more := frames.Next()
if !c.hasHelper(frame.Function) {
return CodeLocation{FileName: frame.File, LineNumber: frame.Line}
}
if !more {
break
}
}
return CodeLocation{}
}
var clLocator = &codeLocationLocator{
pcs: map[uintptr]bool{},
helpers: map[string]bool{},
lock: &sync.Mutex{},
}
// MarkAsHelper is used by GinkgoHelper to mark the caller (appropriately offset by skip)as a helper. You can use this directly if you need to provide an optional `skip` to mark functions further up the call stack as helpers.
func MarkAsHelper(optionalSkip ...int) {
skip := 1
if len(optionalSkip) > 0 {
skip += optionalSkip[0]
}
pc, _, _, ok := runtime.Caller(skip)
if ok {
clLocator.addHelper(pc)
}
}
func NewCustomCodeLocation(message string) CodeLocation { func NewCustomCodeLocation(message string) CodeLocation {
return CodeLocation{ return CodeLocation{
CustomMessage: message, CustomMessage: message,
@ -45,14 +113,13 @@ func NewCustomCodeLocation(message string) CodeLocation {
} }
func NewCodeLocation(skip int) CodeLocation { func NewCodeLocation(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1) return clLocator.getCodeLocation(skip + 1)
return CodeLocation{FileName: file, LineNumber: line}
} }
func NewCodeLocationWithStackTrace(skip int) CodeLocation { func NewCodeLocationWithStackTrace(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1) cl := clLocator.getCodeLocation(skip + 1)
stackTrace := PruneStack(string(debug.Stack()), skip+1) cl.FullStackTrace = PruneStack(string(debug.Stack()), skip+1)
return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} return cl
} }
// PruneStack removes references to functions that are internal to Ginkgo // PruneStack removes references to functions that are internal to Ginkgo

View file

@ -8,6 +8,7 @@ package types
import ( import (
"flag" "flag"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -26,13 +27,14 @@ type SuiteConfig struct {
FailOnPending bool FailOnPending bool
FailFast bool FailFast bool
FlakeAttempts int FlakeAttempts int
EmitSpecProgress bool
DryRun bool DryRun bool
PollProgressAfter time.Duration PollProgressAfter time.Duration
PollProgressInterval time.Duration PollProgressInterval time.Duration
Timeout time.Duration Timeout time.Duration
EmitSpecProgress bool // this is deprecated but its removal is causing compile issue for some users that were setting it manually
OutputInterceptorMode string OutputInterceptorMode string
SourceRoots []string SourceRoots []string
GracePeriod time.Duration
ParallelProcess int ParallelProcess int
ParallelTotal int ParallelTotal int
@ -45,6 +47,7 @@ func NewDefaultSuiteConfig() SuiteConfig {
Timeout: time.Hour, Timeout: time.Hour,
ParallelProcess: 1, ParallelProcess: 1,
ParallelTotal: 1, ParallelTotal: 1,
GracePeriod: 30 * time.Second,
} }
} }
@ -79,13 +82,12 @@ func (vl VerbosityLevel) LT(comp VerbosityLevel) bool {
// Configuration for Ginkgo's reporter // Configuration for Ginkgo's reporter
type ReporterConfig struct { type ReporterConfig struct {
NoColor bool NoColor bool
SlowSpecThreshold time.Duration Succinct bool
Succinct bool Verbose bool
Verbose bool VeryVerbose bool
VeryVerbose bool FullTrace bool
FullTrace bool ShowNodeEvents bool
AlwaysEmitGinkgoWriter bool
JSONReport string JSONReport string
JUnitReport string JUnitReport string
@ -108,9 +110,7 @@ func (rc ReporterConfig) WillGenerateReport() bool {
} }
func NewDefaultReporterConfig() ReporterConfig { func NewDefaultReporterConfig() ReporterConfig {
return ReporterConfig{ return ReporterConfig{}
SlowSpecThreshold: 5 * time.Second,
}
} }
// Configuration for the Ginkgo CLI // Configuration for the Ginkgo CLI
@ -233,6 +233,9 @@ type deprecatedConfig struct {
SlowSpecThresholdWithFLoatUnits float64 SlowSpecThresholdWithFLoatUnits float64
Stream bool Stream bool
Notify bool Notify bool
EmitSpecProgress bool
SlowSpecThreshold time.Duration
AlwaysEmitGinkgoWriter bool
} }
// Flags // Flags
@ -273,8 +276,6 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags", {KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."}, Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."},
{KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug",
Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."},
{KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0", {KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0",
Usage: "Emit node progress reports periodically if node hasn't completed after this duration."}, Usage: "Emit node progress reports periodically if node hasn't completed after this duration."},
{KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s", {KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s",
@ -283,6 +284,8 @@ var SuiteConfigFlags = GinkgoFlags{
Usage: "The location to look for source code when generating progress reports. You can pass multiple --source-root flags."}, Usage: "The location to look for source code when generating progress reports. You can pass multiple --source-root flags."},
{KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h", {KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h",
Usage: "Test suite fails if it does not complete within the specified timeout."}, Usage: "Test suite fails if it does not complete within the specified timeout."},
{KeyPath: "S.GracePeriod", Name: "grace-period", SectionKey: "debug", UsageDefaultValue: "30s",
Usage: "When interrupted, Ginkgo will wait for GracePeriod for the current running node to exit before moving on to the next one."},
{KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none", {KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none",
Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."}, Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."},
@ -299,6 +302,8 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.EmitSpecProgress", DeprecatedName: "progress", SectionKey: "debug",
DeprecatedVersion: "2.5.0", Usage: ". The functionality provided by --progress was confusing and is no longer needed. Use --show-node-events instead to see node entry and exit events included in the timeline of failed and verbose specs. Or you can run with -vv to always see all node events. Lastly, --poll-progress-after and the PollProgressAfter decorator now provide a better mechanism for debugging specs that tend to get stuck."},
} }
// ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI) // ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI)
@ -315,8 +320,6 @@ var ParallelConfigFlags = GinkgoFlags{
var ReporterConfigFlags = GinkgoFlags{ var ReporterConfigFlags = GinkgoFlags{
{KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags", {KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, suppress color output in default reporter."}, Usage: "If set, suppress color output in default reporter."},
{KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s",
Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."},
{KeyPath: "R.Verbose", Name: "v", SectionKey: "output", {KeyPath: "R.Verbose", Name: "v", SectionKey: "output",
Usage: "If set, emits more output including GinkgoWriter contents."}, Usage: "If set, emits more output including GinkgoWriter contents."},
{KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output", {KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output",
@ -325,8 +328,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "If set, default reporter prints out a very succinct report"}, Usage: "If set, default reporter prints out a very succinct report"},
{KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output", {KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output",
Usage: "If set, default reporter prints out the full stack trace when a failure occurs"}, Usage: "If set, default reporter prints out the full stack trace when a failure occurs"},
{KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed", {KeyPath: "R.ShowNodeEvents", Name: "show-node-events", SectionKey: "output",
Usage: "If set, default reporter prints out captured output of passed tests."}, Usage: "If set, default reporter prints node > Enter and < Exit events when specs fail"},
{KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output", {KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output",
Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."}, Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."},
@ -339,6 +342,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"}, Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"},
{KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.SlowSpecThreshold", DeprecatedName: "slow-spec-threshold", SectionKey: "output", Usage: "--slow-spec-threshold has been deprecated and will be removed in a future version of Ginkgo. This feature has proved to be more noisy than useful. You can use --poll-progress-after, instead, to get more actionable feedback about potentially slow specs and understand where they might be getting stuck.", DeprecatedVersion: "2.5.0"},
{KeyPath: "D.AlwaysEmitGinkgoWriter", DeprecatedName: "always-emit-ginkgo-writer", SectionKey: "output", Usage: " - use -v instead, or one of Ginkgo's machine-readable report formats to get GinkgoWriter output for passing specs."},
} }
// BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process // BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process
@ -390,6 +395,10 @@ func VetConfig(flagSet GinkgoFlagSet, suiteConfig SuiteConfig, reporterConfig Re
errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration()) errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration())
} }
if suiteConfig.GracePeriod <= 0 {
errors = append(errors, GinkgoErrors.GracePeriodCannotBeZero())
}
if len(suiteConfig.FocusFiles) > 0 { if len(suiteConfig.FocusFiles) > 0 {
_, err := ParseFileFilters(suiteConfig.FocusFiles) _, err := ParseFileFilters(suiteConfig.FocusFiles)
if err != nil { if err != nil {
@ -592,13 +601,29 @@ func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo
} }
// GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test // GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) { func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string, pathToInvocationPath string) ([]string, error) {
// if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure // if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure
// the built test binary can generate a coverprofile // the built test binary can generate a coverprofile
if goFlagsConfig.CoverProfile != "" { if goFlagsConfig.CoverProfile != "" {
goFlagsConfig.Cover = true goFlagsConfig.Cover = true
} }
if goFlagsConfig.CoverPkg != "" {
coverPkgs := strings.Split(goFlagsConfig.CoverPkg, ",")
adjustedCoverPkgs := make([]string, len(coverPkgs))
for i, coverPkg := range coverPkgs {
coverPkg = strings.Trim(coverPkg, " ")
if strings.HasPrefix(coverPkg, "./") {
// this is a relative coverPkg - we need to reroot it
adjustedCoverPkgs[i] = "./" + filepath.Join(pathToInvocationPath, strings.TrimPrefix(coverPkg, "./"))
} else {
// this is a package name - don't touch it
adjustedCoverPkgs[i] = coverPkg
}
}
goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",")
}
args := []string{"test", "-c", "-o", destination, packageToBuild} args := []string{"test", "-c", "-o", destination, packageToBuild}
goArgs, err := GenerateFlagArgs( goArgs, err := GenerateFlagArgs(
GoBuildFlags, GoBuildFlags,

View file

@ -38,7 +38,7 @@ func (d deprecations) Async() Deprecation {
func (d deprecations) Measure() Deprecation { func (d deprecations) Measure() Deprecation {
return Deprecation{ return Deprecation{
Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.", Message: "Measure is deprecated and has been removed from Ginkgo V2. Any Measure tests in your spec will not run. Please migrate to gomega/gmeasure.",
DocLink: "removed-measure", DocLink: "removed-measure",
Version: "1.16.3", Version: "1.16.3",
} }
@ -83,6 +83,13 @@ func (d deprecations) Nodot() Deprecation {
} }
} }
func (d deprecations) SuppressProgressReporting() Deprecation {
return Deprecation{
Message: "Improvements to how reporters emit timeline information means that SuppressProgressReporting is no longer necessary and has been deprecated.",
Version: "2.5.0",
}
}
type DeprecationTracker struct { type DeprecationTracker struct {
deprecations map[Deprecation][]CodeLocation deprecations map[Deprecation][]CodeLocation
lock *sync.Mutex lock *sync.Mutex

View file

@ -108,8 +108,8 @@ Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/
func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error { func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) { if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite" docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
} }
return GinkgoError{ return GinkgoError{
@ -125,8 +125,8 @@ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocatio
func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error { func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) { if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite" docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
} }
return GinkgoError{ return GinkgoError{
@ -180,6 +180,15 @@ func (g ginkgoErrors) InvalidDeclarationOfFocusedAndPending(cl CodeLocation, nod
} }
} }
func (g ginkgoErrors) InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid Combination of Decorators: FlakeAttempts and MustPassRepeatedly",
Message: formatter.F(`[%s] node was decorated with both FlakeAttempts and MustPassRepeatedly. At most one is allowed.`, nodeType),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error { func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error {
return GinkgoError{ return GinkgoError{
Heading: "Unknown Decorator", Heading: "Unknown Decorator",
@ -189,20 +198,55 @@ func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decor
} }
} }
func (g ginkgoErrors) InvalidBodyTypeForContainer(t reflect.Type, cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. You passed {{bold}}%s{{/}} instead.`, nodeType, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error {
mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}"
if nodeType.Is(NodeTypeContainer) {
mustGet = "{{bold}}func(){{/}}"
}
return GinkgoError{ return GinkgoError{
Heading: "Invalid Function", Heading: "Invalid Function",
Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. Message: formatter.F(`[%s] node must be passed `+mustGet+`.
You passed {{bold}}%s{{/}} instead.`, nodeType, t), You passed {{bold}}%s{{/}} instead.`, nodeType, t),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
} }
func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t reflect.Type, cl CodeLocation) error {
mustGet := "{{bold}}func() []byte{{/}}, {{bold}}func(ctx SpecContext) []byte{{/}}, or {{bold}}func(ctx context.Context) []byte{{/}}, {{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}"
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its first function.
You passed {{bold}}%s{{/}} instead.`, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t reflect.Type, cl CodeLocation) error {
mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}, {{bold}}func([]byte){{/}}, {{bold}}func(ctx SpecContext, []byte){{/}}, or {{bold}}func(ctx context.Context, []byte){{/}}"
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its second function.
You passed {{bold}}%s{{/}} instead.`, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: "Multiple Functions", Heading: "Multiple Functions",
Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but more than one was passed in.`, nodeType), Message: formatter.F(`[%s] node must be passed a single function - but more than one was passed in.`, nodeType),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
@ -211,12 +255,30 @@ func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType)
func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: "Missing Functions", Heading: "Missing Functions",
Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but none was passed in.`, nodeType), Message: formatter.F(`[%s] node must be passed a single function - but none was passed in.`, nodeType),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
} }
func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextNode(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod",
Message: formatter.F(`[%s] was passed NodeTimeout, SpecTimeout, or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`, nodeType),
CodeLocation: cl,
DocLink: "spec-timeouts-and-interruptible-nodes",
}
}
func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextCleanupNode(cl CodeLocation) error {
return GinkgoError{
Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod",
Message: formatter.F(`[DeferCleanup] was passed NodeTimeout or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`),
CodeLocation: cl,
DocLink: "spec-timeouts-and-interruptible-nodes",
}
}
/* Ordered Container errors */ /* Ordered Container errors */
func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
@ -236,6 +298,15 @@ func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType N
} }
} }
func (g ginkgoErrors) InvalidContinueOnFailureDecoration(cl CodeLocation) error {
return GinkgoError{
Heading: "ContinueOnFailure not decorating an outermost Ordered Container",
Message: "ContinueOnFailure can only decorate an Ordered container, and this Ordered container must be the outermost Ordered container.",
CodeLocation: cl,
DocLink: "ordered-containers",
}
}
/* DeferCleanup errors */ /* DeferCleanup errors */
func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error { func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error {
return GinkgoError{ return GinkgoError{
@ -258,7 +329,7 @@ func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation)
func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType), Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType),
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.", Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a Reporting node.",
CodeLocation: cl, CodeLocation: cl,
DocLink: "cleaning-up-our-cleanup-code-defercleanup", DocLink: "cleaning-up-our-cleanup-code-defercleanup",
} }
@ -380,6 +451,15 @@ func (g ginkgoErrors) InvalidEntryDescription(cl CodeLocation) error {
} }
} }
func (g ginkgoErrors) MissingParametersForTableFunction(cl CodeLocation) error {
return GinkgoError{
Heading: fmt.Sprintf("No parameters have been passed to the Table Function"),
Message: fmt.Sprintf("The Table Function expected at least 1 parameter"),
CodeLocation: cl,
DocLink: "table-specs",
}
}
func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error { func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error {
return GinkgoError{ return GinkgoError{
Heading: "DescribeTable passed incorrect parameter type", Heading: "DescribeTable passed incorrect parameter type",
@ -498,6 +578,13 @@ func (g ginkgoErrors) DryRunInParallelConfiguration() error {
} }
} }
func (g ginkgoErrors) GracePeriodCannotBeZero() error {
return GinkgoError{
Heading: "Ginkgo requires a positive --grace-period.",
Message: "Please set --grace-period to a positive duration. The default is 30s.",
}
}
func (g ginkgoErrors) ConflictingVerbosityConfiguration() error { func (g ginkgoErrors) ConflictingVerbosityConfiguration() error {
return GinkgoError{ return GinkgoError{
Heading: "Conflicting reporter verbosity settings.", Heading: "Conflicting reporter verbosity settings.",

View file

@ -272,12 +272,23 @@ func tokenize(input string) func() (*treeNode, error) {
} }
} }
func MustParseLabelFilter(input string) LabelFilter {
filter, err := ParseLabelFilter(input)
if err != nil {
panic(err)
}
return filter
}
func ParseLabelFilter(input string) (LabelFilter, error) { func ParseLabelFilter(input string) (LabelFilter, error) {
if DEBUG_LABEL_FILTER_PARSING { if DEBUG_LABEL_FILTER_PARSING {
fmt.Println("\n==============") fmt.Println("\n==============")
fmt.Println("Input: ", input) fmt.Println("Input: ", input)
fmt.Print("Tokens: ") fmt.Print("Tokens: ")
} }
if input == "" {
return func(_ []string) bool { return true }, nil
}
nextToken := tokenize(input) nextToken := tokenize(input)
root := &treeNode{token: lfTokenRoot} root := &treeNode{token: lfTokenRoot}

View file

@ -6,8 +6,8 @@ import (
"time" "time"
) )
//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports // ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
//and across the network connection when running in parallel // and across the network connection when running in parallel
type ReportEntryValue struct { type ReportEntryValue struct {
raw interface{} //unexported to prevent gob from freaking out about unregistered structs raw interface{} //unexported to prevent gob from freaking out about unregistered structs
AsJSON string AsJSON string
@ -85,10 +85,12 @@ func (rev *ReportEntryValue) GobDecode(data []byte) error {
type ReportEntry struct { type ReportEntry struct {
// Visibility captures the visibility policy for this ReportEntry // Visibility captures the visibility policy for this ReportEntry
Visibility ReportEntryVisibility Visibility ReportEntryVisibility
// Time captures the time the AddReportEntry was called
Time time.Time
// Location captures the location of the AddReportEntry call // Location captures the location of the AddReportEntry call
Location CodeLocation Location CodeLocation
Time time.Time //need this for backwards compatibility
TimelineLocation TimelineLocation
// Name captures the name of this report // Name captures the name of this report
Name string Name string
// Value captures the (optional) object passed into AddReportEntry - this can be // Value captures the (optional) object passed into AddReportEntry - this can be
@ -120,7 +122,9 @@ func (entry ReportEntry) GetRawValue() interface{} {
return entry.Value.GetRawValue() return entry.Value.GetRawValue()
} }
func (entry ReportEntry) GetTimelineLocation() TimelineLocation {
return entry.TimelineLocation
}
type ReportEntries []ReportEntry type ReportEntries []ReportEntry

View file

@ -2,6 +2,8 @@ package types
import ( import (
"encoding/json" "encoding/json"
"fmt"
"sort"
"strings" "strings"
"time" "time"
) )
@ -56,19 +58,20 @@ type Report struct {
SuiteConfig SuiteConfig SuiteConfig SuiteConfig
//SpecReports is a list of all SpecReports generated by this test run //SpecReports is a list of all SpecReports generated by this test run
//It is empty when the SuiteReport is provided to ReportBeforeSuite
SpecReports SpecReports SpecReports SpecReports
} }
//PreRunStats contains a set of stats captured before the test run begins. This is primarily used // PreRunStats contains a set of stats captured before the test run begins. This is primarily used
//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) // by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. // and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
type PreRunStats struct { type PreRunStats struct {
TotalSpecs int TotalSpecs int
SpecsThatWillRun int SpecsThatWillRun int
} }
//Add is ued by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes // Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
//to form a complete final report. // to form a complete final report.
func (report Report) Add(other Report) Report { func (report Report) Add(other Report) Report {
report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded
@ -147,14 +150,24 @@ type SpecReport struct {
// ParallelProcess captures the parallel process that this spec ran on // ParallelProcess captures the parallel process that this spec ran on
ParallelProcess int ParallelProcess int
// RunningInParallel captures whether this spec is part of a suite that ran in parallel
RunningInParallel bool
//Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip()) //Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip())
//It includes detailed information about the Failure //It includes detailed information about the Failure
Failure Failure Failure Failure
// NumAttempts captures the number of times this Spec was run. Flakey specs can be retried with // NumAttempts captures the number of times this Spec was run.
// ginkgo --flake-attempts=N // Flakey specs can be retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator.
// Repeated specs can be retried with the use of the MustPassRepeatedly decorator
NumAttempts int NumAttempts int
// MaxFlakeAttempts captures whether the spec has been retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator.
MaxFlakeAttempts int
// MaxMustPassRepeatedly captures whether the spec has the MustPassRepeatedly decorator
MaxMustPassRepeatedly int
// CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter // CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter
CapturedGinkgoWriterOutput string CapturedGinkgoWriterOutput string
@ -168,6 +181,12 @@ type SpecReport struct {
// ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator // ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator
ProgressReports []ProgressReport ProgressReports []ProgressReport
// AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode.
AdditionalFailures []AdditionalFailure
// SpecEvents capture additional events that occur during the spec run
SpecEvents SpecEvents
} }
func (report SpecReport) MarshalJSON() ([]byte, error) { func (report SpecReport) MarshalJSON() ([]byte, error) {
@ -187,10 +206,14 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
ParallelProcess int ParallelProcess int
Failure *Failure `json:",omitempty"` Failure *Failure `json:",omitempty"`
NumAttempts int NumAttempts int
CapturedGinkgoWriterOutput string `json:",omitempty"` MaxFlakeAttempts int
CapturedStdOutErr string `json:",omitempty"` MaxMustPassRepeatedly int
ReportEntries ReportEntries `json:",omitempty"` CapturedGinkgoWriterOutput string `json:",omitempty"`
ProgressReports []ProgressReport `json:",omitempty"` CapturedStdOutErr string `json:",omitempty"`
ReportEntries ReportEntries `json:",omitempty"`
ProgressReports []ProgressReport `json:",omitempty"`
AdditionalFailures []AdditionalFailure `json:",omitempty"`
SpecEvents SpecEvents `json:",omitempty"`
}{ }{
ContainerHierarchyTexts: report.ContainerHierarchyTexts, ContainerHierarchyTexts: report.ContainerHierarchyTexts,
ContainerHierarchyLocations: report.ContainerHierarchyLocations, ContainerHierarchyLocations: report.ContainerHierarchyLocations,
@ -207,6 +230,8 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
Failure: nil, Failure: nil,
ReportEntries: nil, ReportEntries: nil,
NumAttempts: report.NumAttempts, NumAttempts: report.NumAttempts,
MaxFlakeAttempts: report.MaxFlakeAttempts,
MaxMustPassRepeatedly: report.MaxMustPassRepeatedly,
CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput, CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput,
CapturedStdOutErr: report.CapturedStdOutErr, CapturedStdOutErr: report.CapturedStdOutErr,
} }
@ -220,6 +245,12 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
if len(report.ProgressReports) > 0 { if len(report.ProgressReports) > 0 {
out.ProgressReports = report.ProgressReports out.ProgressReports = report.ProgressReports
} }
if len(report.AdditionalFailures) > 0 {
out.AdditionalFailures = report.AdditionalFailures
}
if len(report.SpecEvents) > 0 {
out.SpecEvents = report.SpecEvents
}
return json.Marshal(out) return json.Marshal(out)
} }
@ -237,13 +268,13 @@ func (report SpecReport) CombinedOutput() string {
return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput
} }
//Failed returns true if report.State is one of the SpecStateFailureStates // Failed returns true if report.State is one of the SpecStateFailureStates
// (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted) // (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted)
func (report SpecReport) Failed() bool { func (report SpecReport) Failed() bool {
return report.State.Is(SpecStateFailureStates) return report.State.Is(SpecStateFailureStates)
} }
//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText // FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
func (report SpecReport) FullText() string { func (report SpecReport) FullText() string {
texts := []string{} texts := []string{}
texts = append(texts, report.ContainerHierarchyTexts...) texts = append(texts, report.ContainerHierarchyTexts...)
@ -253,7 +284,7 @@ func (report SpecReport) FullText() string {
return strings.Join(texts, " ") return strings.Join(texts, " ")
} }
//Labels returns a deduped set of all the spec's Labels. // Labels returns a deduped set of all the spec's Labels.
func (report SpecReport) Labels() []string { func (report SpecReport) Labels() []string {
out := []string{} out := []string{}
seen := map[string]bool{} seen := map[string]bool{}
@ -275,7 +306,7 @@ func (report SpecReport) Labels() []string {
return out return out
} }
//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query // MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
func (report SpecReport) MatchesLabelFilter(query string) (bool, error) { func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
filter, err := ParseLabelFilter(query) filter, err := ParseLabelFilter(query)
if err != nil { if err != nil {
@ -284,29 +315,54 @@ func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
return filter(report.Labels()), nil return filter(report.Labels()), nil
} }
//FileName() returns the name of the file containing the spec // FileName() returns the name of the file containing the spec
func (report SpecReport) FileName() string { func (report SpecReport) FileName() string {
return report.LeafNodeLocation.FileName return report.LeafNodeLocation.FileName
} }
//LineNumber() returns the line number of the leaf node // LineNumber() returns the line number of the leaf node
func (report SpecReport) LineNumber() int { func (report SpecReport) LineNumber() int {
return report.LeafNodeLocation.LineNumber return report.LeafNodeLocation.LineNumber
} }
//FailureMessage() returns the failure message (or empty string if the test hasn't failed) // FailureMessage() returns the failure message (or empty string if the test hasn't failed)
func (report SpecReport) FailureMessage() string { func (report SpecReport) FailureMessage() string {
return report.Failure.Message return report.Failure.Message
} }
//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed) // FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
func (report SpecReport) FailureLocation() CodeLocation { func (report SpecReport) FailureLocation() CodeLocation {
return report.Failure.Location return report.Failure.Location
} }
// Timeline() returns a timeline view of the report
func (report SpecReport) Timeline() Timeline {
timeline := Timeline{}
if !report.Failure.IsZero() {
timeline = append(timeline, report.Failure)
if report.Failure.AdditionalFailure != nil {
timeline = append(timeline, *(report.Failure.AdditionalFailure))
}
}
for _, additionalFailure := range report.AdditionalFailures {
timeline = append(timeline, additionalFailure)
}
for _, reportEntry := range report.ReportEntries {
timeline = append(timeline, reportEntry)
}
for _, progressReport := range report.ProgressReports {
timeline = append(timeline, progressReport)
}
for _, specEvent := range report.SpecEvents {
timeline = append(timeline, specEvent)
}
sort.Sort(timeline)
return timeline
}
type SpecReports []SpecReport type SpecReports []SpecReport
//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes // WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports { func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
count := 0 count := 0
for i := range reports { for i := range reports {
@ -326,7 +382,7 @@ func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
return out return out
} }
//WithState returns the subset of SpecReports with State matching one of the requested SpecStates // WithState returns the subset of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) WithState(states SpecState) SpecReports { func (reports SpecReports) WithState(states SpecState) SpecReports {
count := 0 count := 0
for i := range reports { for i := range reports {
@ -345,7 +401,7 @@ func (reports SpecReports) WithState(states SpecState) SpecReports {
return out return out
} }
//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates // CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) CountWithState(states SpecState) int { func (reports SpecReports) CountWithState(states SpecState) int {
n := 0 n := 0
for i := range reports { for i := range reports {
@ -356,17 +412,75 @@ func (reports SpecReports) CountWithState(states SpecState) int {
return n return n
} }
//CountWithState returns the number of SpecReports that passed after multiple attempts // If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
func (reports SpecReports) CountOfFlakedSpecs() int { func (reports SpecReports) CountOfFlakedSpecs() int {
n := 0 n := 0
for i := range reports { for i := range reports {
if reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 { if reports[i].MaxFlakeAttempts > 1 && reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 {
n += 1 n += 1
} }
} }
return n return n
} }
// If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
func (reports SpecReports) CountOfRepeatedSpecs() int {
n := 0
for i := range reports {
if reports[i].MaxMustPassRepeatedly > 1 && reports[i].State.Is(SpecStateFailureStates) && reports[i].NumAttempts > 1 {
n += 1
}
}
return n
}
// TimelineLocation captures the location of an event in the spec's timeline
type TimelineLocation struct {
//Offset is the offset (in bytes) of the event relative to the GinkgoWriter stream
Offset int `json:",omitempty"`
//Order is the order of the event with respect to other events. The absolute value of Order
//is irrelevant. All that matters is that an event with a lower Order occurs before ane vent with a higher Order
Order int `json:",omitempty"`
Time time.Time
}
// TimelineEvent represent an event on the timeline
// consumers of Timeline will need to check the concrete type of each entry to determine how to handle it
type TimelineEvent interface {
GetTimelineLocation() TimelineLocation
}
type Timeline []TimelineEvent
func (t Timeline) Len() int { return len(t) }
func (t Timeline) Less(i, j int) bool {
return t[i].GetTimelineLocation().Order < t[j].GetTimelineLocation().Order
}
func (t Timeline) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
func (t Timeline) WithoutHiddenReportEntries() Timeline {
out := Timeline{}
for _, event := range t {
if reportEntry, isReportEntry := event.(ReportEntry); isReportEntry && reportEntry.Visibility == ReportEntryVisibilityNever {
continue
}
out = append(out, event)
}
return out
}
func (t Timeline) WithoutVeryVerboseSpecEvents() Timeline {
out := Timeline{}
for _, event := range t {
if specEvent, isSpecEvent := event.(SpecEvent); isSpecEvent && specEvent.IsOnlyVisibleAtVeryVerbose() {
continue
}
out = append(out, event)
}
return out
}
// Failure captures failure information for an individual test // Failure captures failure information for an individual test
type Failure struct { type Failure struct {
// Message - the failure message passed into Fail(...). When using a matcher library // Message - the failure message passed into Fail(...). When using a matcher library
@ -379,6 +493,8 @@ type Failure struct {
// This CodeLocation will include a fully-populated StackTrace // This CodeLocation will include a fully-populated StackTrace
Location CodeLocation Location CodeLocation
TimelineLocation TimelineLocation
// ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked) // ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked)
// then ForwardedPanic will be populated with a string representation of the captured panic. // then ForwardedPanic will be populated with a string representation of the captured panic.
ForwardedPanic string `json:",omitempty"` ForwardedPanic string `json:",omitempty"`
@ -391,19 +507,29 @@ type Failure struct {
// FailureNodeType will contain the NodeType of the node in which the failure occurred. // FailureNodeType will contain the NodeType of the node in which the failure occurred.
// FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred. // FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred.
// If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred. // If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred.
FailureNodeContext FailureNodeContext FailureNodeContext FailureNodeContext `json:",omitempty"`
FailureNodeType NodeType
FailureNodeLocation CodeLocation FailureNodeType NodeType `json:",omitempty"`
FailureNodeContainerIndex int
FailureNodeLocation CodeLocation `json:",omitempty"`
FailureNodeContainerIndex int `json:",omitempty"`
//ProgressReport is populated if the spec was interrupted or timed out //ProgressReport is populated if the spec was interrupted or timed out
ProgressReport ProgressReport ProgressReport ProgressReport `json:",omitempty"`
//AdditionalFailure is non-nil if a follow-on failure occurred within the same node after the primary failure. This only happens when a node has timed out or been interrupted. In such cases the AdditionalFailure can include information about where/why the spec was stuck.
AdditionalFailure *AdditionalFailure `json:",omitempty"`
} }
func (f Failure) IsZero() bool { func (f Failure) IsZero() bool {
return f.Message == "" && (f.Location == CodeLocation{}) return f.Message == "" && (f.Location == CodeLocation{})
} }
func (f Failure) GetTimelineLocation() TimelineLocation {
return f.TimelineLocation
}
// FailureNodeContext captures the location context for the node containing the failing line of code // FailureNodeContext captures the location context for the node containing the failing line of code
type FailureNodeContext uint type FailureNodeContext uint
@ -434,6 +560,18 @@ func (fnc FailureNodeContext) MarshalJSON() ([]byte, error) {
return fncEnumSupport.MarshJSON(uint(fnc)) return fncEnumSupport.MarshJSON(uint(fnc))
} }
// AdditionalFailure capturs any additional failures that occur after the initial failure of a psec
// these typically occur in clean up nodes after the spec has failed.
// We can't simply use Failure as we want to track the SpecState to know what kind of failure this is
type AdditionalFailure struct {
State SpecState
Failure Failure
}
func (f AdditionalFailure) GetTimelineLocation() TimelineLocation {
return f.Failure.TimelineLocation
}
// SpecState captures the state of a spec // SpecState captures the state of a spec
// To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)` // To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)`
type SpecState uint type SpecState uint
@ -448,6 +586,7 @@ const (
SpecStateAborted SpecStateAborted
SpecStatePanicked SpecStatePanicked
SpecStateInterrupted SpecStateInterrupted
SpecStateTimedout
) )
var ssEnumSupport = NewEnumSupport(map[uint]string{ var ssEnumSupport = NewEnumSupport(map[uint]string{
@ -459,11 +598,15 @@ var ssEnumSupport = NewEnumSupport(map[uint]string{
uint(SpecStateAborted): "aborted", uint(SpecStateAborted): "aborted",
uint(SpecStatePanicked): "panicked", uint(SpecStatePanicked): "panicked",
uint(SpecStateInterrupted): "interrupted", uint(SpecStateInterrupted): "interrupted",
uint(SpecStateTimedout): "timedout",
}) })
func (ss SpecState) String() string { func (ss SpecState) String() string {
return ssEnumSupport.String(uint(ss)) return ssEnumSupport.String(uint(ss))
} }
func (ss SpecState) GomegaString() string {
return ssEnumSupport.String(uint(ss))
}
func (ss *SpecState) UnmarshalJSON(b []byte) error { func (ss *SpecState) UnmarshalJSON(b []byte) error {
out, err := ssEnumSupport.UnmarshJSON(b) out, err := ssEnumSupport.UnmarshJSON(b)
*ss = SpecState(out) *ss = SpecState(out)
@ -473,7 +616,7 @@ func (ss SpecState) MarshalJSON() ([]byte, error) {
return ssEnumSupport.MarshJSON(uint(ss)) return ssEnumSupport.MarshJSON(uint(ss))
} }
var SpecStateFailureStates = SpecStateFailed | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted var SpecStateFailureStates = SpecStateFailed | SpecStateTimedout | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted
func (ss SpecState) Is(states SpecState) bool { func (ss SpecState) Is(states SpecState) bool {
return ss&states != 0 return ss&states != 0
@ -481,35 +624,40 @@ func (ss SpecState) Is(states SpecState) bool {
// ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace // ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace
type ProgressReport struct { type ProgressReport struct {
ParallelProcess int Message string `json:",omitempty"`
RunningInParallel bool ParallelProcess int `json:",omitempty"`
RunningInParallel bool `json:",omitempty"`
Time time.Time ContainerHierarchyTexts []string `json:",omitempty"`
LeafNodeText string `json:",omitempty"`
LeafNodeLocation CodeLocation `json:",omitempty"`
SpecStartTime time.Time `json:",omitempty"`
ContainerHierarchyTexts []string CurrentNodeType NodeType `json:",omitempty"`
LeafNodeText string CurrentNodeText string `json:",omitempty"`
LeafNodeLocation CodeLocation CurrentNodeLocation CodeLocation `json:",omitempty"`
SpecStartTime time.Time CurrentNodeStartTime time.Time `json:",omitempty"`
CurrentNodeType NodeType CurrentStepText string `json:",omitempty"`
CurrentNodeText string CurrentStepLocation CodeLocation `json:",omitempty"`
CurrentNodeLocation CodeLocation CurrentStepStartTime time.Time `json:",omitempty"`
CurrentNodeStartTime time.Time
CurrentStepText string AdditionalReports []string `json:",omitempty"`
CurrentStepLocation CodeLocation
CurrentStepStartTime time.Time
CapturedGinkgoWriterOutput string `json:",omitempty"` CapturedGinkgoWriterOutput string `json:",omitempty"`
GinkgoWriterOffset int TimelineLocation TimelineLocation `json:",omitempty"`
Goroutines []Goroutine Goroutines []Goroutine `json:",omitempty"`
} }
func (pr ProgressReport) IsZero() bool { func (pr ProgressReport) IsZero() bool {
return pr.CurrentNodeType == NodeTypeInvalid return pr.CurrentNodeType == NodeTypeInvalid
} }
func (pr ProgressReport) Time() time.Time {
return pr.TimelineLocation.Time
}
func (pr ProgressReport) SpecGoroutine() Goroutine { func (pr ProgressReport) SpecGoroutine() Goroutine {
for _, goroutine := range pr.Goroutines { for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine { if goroutine.IsSpecGoroutine {
@ -547,6 +695,22 @@ func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport {
return out return out
} }
func (pr ProgressReport) WithoutOtherGoroutines() ProgressReport {
out := pr
filteredGoroutines := []Goroutine{}
for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine || goroutine.HasHighlights() {
filteredGoroutines = append(filteredGoroutines, goroutine)
}
}
out.Goroutines = filteredGoroutines
return out
}
func (pr ProgressReport) GetTimelineLocation() TimelineLocation {
return pr.TimelineLocation
}
type Goroutine struct { type Goroutine struct {
ID uint64 ID uint64
State string State string
@ -601,6 +765,7 @@ const (
NodeTypeReportBeforeEach NodeTypeReportBeforeEach
NodeTypeReportAfterEach NodeTypeReportAfterEach
NodeTypeReportBeforeSuite
NodeTypeReportAfterSuite NodeTypeReportAfterSuite
NodeTypeCleanupInvalid NodeTypeCleanupInvalid
@ -610,7 +775,9 @@ const (
) )
var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite
var ntEnumSupport = NewEnumSupport(map[uint]string{ var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeInvalid): "INVALID NODE TYPE", uint(NodeTypeInvalid): "INVALID NODE TYPE",
@ -628,9 +795,10 @@ var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite", uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite",
uint(NodeTypeReportBeforeEach): "ReportBeforeEach", uint(NodeTypeReportBeforeEach): "ReportBeforeEach",
uint(NodeTypeReportAfterEach): "ReportAfterEach", uint(NodeTypeReportAfterEach): "ReportAfterEach",
uint(NodeTypeReportBeforeSuite): "ReportBeforeSuite",
uint(NodeTypeReportAfterSuite): "ReportAfterSuite", uint(NodeTypeReportAfterSuite): "ReportAfterSuite",
uint(NodeTypeCleanupInvalid): "INVALID CLEANUP NODE", uint(NodeTypeCleanupInvalid): "DeferCleanup",
uint(NodeTypeCleanupAfterEach): "DeferCleanup", uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)",
uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)", uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)",
uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)", uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)",
}) })
@ -650,3 +818,99 @@ func (nt NodeType) MarshalJSON() ([]byte, error) {
func (nt NodeType) Is(nodeTypes NodeType) bool { func (nt NodeType) Is(nodeTypes NodeType) bool {
return nt&nodeTypes != 0 return nt&nodeTypes != 0
} }
/*
SpecEvent captures a variety of events that can occur when specs run. See SpecEventType for the list of available events.
*/
type SpecEvent struct {
SpecEventType SpecEventType
CodeLocation CodeLocation
TimelineLocation TimelineLocation
Message string `json:",omitempty"`
Duration time.Duration `json:",omitempty"`
NodeType NodeType `json:",omitempty"`
Attempt int `json:",omitempty"`
}
func (se SpecEvent) GetTimelineLocation() TimelineLocation {
return se.TimelineLocation
}
func (se SpecEvent) IsOnlyVisibleAtVeryVerbose() bool {
return se.SpecEventType.Is(SpecEventByEnd | SpecEventNodeStart | SpecEventNodeEnd)
}
func (se SpecEvent) GomegaString() string {
out := &strings.Builder{}
out.WriteString("[" + se.SpecEventType.String() + " SpecEvent] ")
if se.Message != "" {
out.WriteString("Message=")
out.WriteString(`"` + se.Message + `",`)
}
if se.Duration != 0 {
out.WriteString("Duration=" + se.Duration.String() + ",")
}
if se.NodeType != NodeTypeInvalid {
out.WriteString("NodeType=" + se.NodeType.String() + ",")
}
if se.Attempt != 0 {
out.WriteString(fmt.Sprintf("Attempt=%d", se.Attempt) + ",")
}
out.WriteString("CL=" + se.CodeLocation.String() + ",")
out.WriteString(fmt.Sprintf("TL.Offset=%d", se.TimelineLocation.Offset))
return out.String()
}
type SpecEvents []SpecEvent
func (se SpecEvents) WithType(seType SpecEventType) SpecEvents {
out := SpecEvents{}
for _, event := range se {
if event.SpecEventType.Is(seType) {
out = append(out, event)
}
}
return out
}
type SpecEventType uint
const (
SpecEventInvalid SpecEventType = 0
SpecEventByStart SpecEventType = 1 << iota
SpecEventByEnd
SpecEventNodeStart
SpecEventNodeEnd
SpecEventSpecRepeat
SpecEventSpecRetry
)
var seEnumSupport = NewEnumSupport(map[uint]string{
uint(SpecEventInvalid): "INVALID SPEC EVENT",
uint(SpecEventByStart): "By",
uint(SpecEventByEnd): "By (End)",
uint(SpecEventNodeStart): "Node",
uint(SpecEventNodeEnd): "Node (End)",
uint(SpecEventSpecRepeat): "Repeat",
uint(SpecEventSpecRetry): "Retry",
})
func (se SpecEventType) String() string {
return seEnumSupport.String(uint(se))
}
func (se *SpecEventType) UnmarshalJSON(b []byte) error {
out, err := seEnumSupport.UnmarshJSON(b)
*se = SpecEventType(out)
return err
}
func (se SpecEventType) MarshalJSON() ([]byte, error) {
return seEnumSupport.MarshJSON(uint(se))
}
func (se SpecEventType) Is(specEventTypes SpecEventType) bool {
return se&specEventTypes != 0
}

View file

@ -1,3 +1,3 @@
package types package types
const VERSION = "2.2.0" const VERSION = "2.9.5"

View file

@ -5,28 +5,185 @@
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go) [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
## Guides This repository provides both a QUIC implementation, located in the `quic` package, as well as an HTTP/3 implementation, located in the `http3` package.
*We currently support Go 1.19.x and Go 1.20.x* ## Using QUIC
Running tests: ### Running a Server
go test ./... The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
### QUIC without HTTP/3 ```go
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
// ... error handling
tr := quic.Transport{
Conn: udpConn,
}
ln, err := tr.Listen(tlsConf, quicConf)
// ... error handling
go func() {
for {
conn, err := ln.Accept()
// ... error handling
// handle the connection, usually in a new Go routine
}
}
```
Take a look at [this echo example](example/echo/echo.go). The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
## Usage As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
```
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
```
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
### Running a Client
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
// ... error handling
```
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explicitly initializing a `quic.Transport`:
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
```
Just as we saw before when using a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
### Using a QUIC Connection
#### Accepting Streams
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most use cases, it makes sense to call these functions in a loop:
```go
for {
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
// ... error handling
// handle the stream, usually in a new Go routine
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Opening Streams
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
```
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
```go
str, err := conn.OpenStream()
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// It's currently not possible to open another stream,
// but it might be possible later, once the peer allowed us to do so.
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Using Streams
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream`, it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this is surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the sender side, this is surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelRead` _only_ resets the receive side of the stream. It is still possible to write to the stream.
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
### Configuring QUIC
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommended to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
### Closing a Connection
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
* `quic.TransportError`: Happens when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
#### Initiated by the Application
A `quic.Connection` can be closed using `CloseWithError`:
```go
conn.CloseWithError(0x42, "error 0x42 occurred")
```
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
On the receiver side, this is surfaced as a `quic.ApplicationError`.
### QUIC Datagrams
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted.
Datagrams are sent using the `SendMessage` method on the `quic.Connection`:
```go
conn.SendMessage([]byte("foobar"))
```
And received using `ReceiveMessage`:
```go
msg, err := conn.ReceiveMessage()
```
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
## Using HTTP/3
### As a server ### As a server
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go: See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
```go ```go
http.Handle("/", http.FileServer(http.Dir(wwwDir))) http.Handle("/", http.FileServer(http.Dir(wwwDir)))
@ -59,6 +216,16 @@ http.Client{
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | | [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) |
If you'd like to see your project added to this list, please send us a PR.
## Release Policy
quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This has led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing ## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

View file

@ -51,18 +51,22 @@ func (b *packetBuffer) Release() {
} }
// Len returns the length of Data // Len returns the length of Data
func (b *packetBuffer) Len() protocol.ByteCount { func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
return protocol.ByteCount(len(b.Data)) func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
}
func (b *packetBuffer) putBack() { func (b *packetBuffer) putBack() {
if cap(b.Data) != int(protocol.MaxPacketBufferSize) { if cap(b.Data) == protocol.MaxPacketBufferSize {
panic("putPacketBuffer called with packet of wrong size!") bufferPool.Put(b)
return
} }
bufferPool.Put(b) if cap(b.Data) == protocol.MaxLargePacketBufferSize {
largeBufferPool.Put(b)
return
}
panic("putPacketBuffer called with packet of wrong size!")
} }
var bufferPool sync.Pool var bufferPool, largeBufferPool sync.Pool
func getPacketBuffer() *packetBuffer { func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer) buf := bufferPool.Get().(*packetBuffer)
@ -71,10 +75,18 @@ func getPacketBuffer() *packetBuffer {
return buf return buf
} }
func getLargePacketBuffer() *packetBuffer {
buf := largeBufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() any {
return &packetBuffer{ return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
Data: make([]byte, 0, protocol.MaxPacketBufferSize), }
} largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
} }
} }

View file

@ -43,7 +43,9 @@ type client struct {
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed. // It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// See Dial for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) { func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
@ -61,7 +63,7 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi
} }
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. // DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed. // See DialAddr for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
@ -83,8 +85,8 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *
return conn, nil return conn, nil
} }
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. // DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See DialEarly for details. // See Dial for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false) dl, err := setupTransport(c, tlsConf, false)
if err != nil { if err != nil {
@ -98,12 +100,15 @@ func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tl
return conn, nil return conn, nil
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write // will be used instead of ReadFrom and WriteTo to read/write packets.
// packets.
// The tls.Config must define an application protocol (using NextProtos). // The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false) dl, err := setupTransport(c, tlsConf, false)
if err != nil { if err != nil {

View file

@ -16,13 +16,13 @@ type closedLocalConn struct {
perspective protocol.Perspective perspective protocol.Perspective
logger utils.Logger logger utils.Logger
sendPacket func(net.Addr, *packetInfo) sendPacket func(net.Addr, packetInfo)
} }
var _ packetHandler = &closedLocalConn{} var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it. // newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{ return &closedLocalConn{
sendPacket: sendPacket, sendPacket: sendPacket,
perspective: pers, perspective: pers,
@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe
} }
} }
func (c *closedLocalConn) handlePacket(p *receivedPacket) { func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++ c.counter++
// exponential backoff // exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers} return &closedRemoteConn{perspective: pers}
} }
func (s *closedRemoteConn) handlePacket(*receivedPacket) {} func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {} func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {} func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

View file

@ -1,13 +1,13 @@
package quic package quic
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
) )
// Clone clones a Config // Clone clones a Config
@ -24,11 +24,18 @@ func validateConfig(config *Config) error {
if config == nil { if config == nil {
return nil return nil
} }
if config.MaxIncomingStreams > 1<<60 { const maxStreams = 1 << 60
return errors.New("invalid value for Config.MaxIncomingStreams") if config.MaxIncomingStreams > maxStreams {
config.MaxIncomingStreams = maxStreams
} }
if config.MaxIncomingUniStreams > 1<<60 { if config.MaxIncomingUniStreams > maxStreams {
return errors.New("invalid value for Config.MaxIncomingUniStreams") config.MaxIncomingUniStreams = maxStreams
}
if config.MaxStreamReceiveWindow > quicvarint.Max {
config.MaxStreamReceiveWindow = quicvarint.Max
}
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
} }
// check that all QUIC versions are actually supported // check that all QUIC versions are actually supported
for _, v := range config.Versions { for _, v := range config.Versions {

View file

@ -61,11 +61,6 @@ type cryptoStreamHandler interface {
ConnectionState() handshake.ConnectionState ConnectionState() handshake.ConnectionState
} }
type packetInfo struct {
addr net.IP
ifIndex uint32
}
type receivedPacket struct { type receivedPacket struct {
buffer *packetBuffer buffer *packetBuffer
@ -75,7 +70,7 @@ type receivedPacket struct {
ecn protocol.ECN ecn protocol.ECN
info *packetInfo info packetInfo // only valid if the contained IP address is valid
} }
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }
@ -173,7 +168,7 @@ type connection struct {
oneRTTStream cryptoStream // only set for the server oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
sendingScheduled chan struct{} sendingScheduled chan struct{}
closeOnce sync.Once closeOnce sync.Once
@ -185,8 +180,8 @@ type connection struct {
handshakeCtx context.Context handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc handshakeCtxCancel context.CancelFunc
undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []*receivedPacket undecryptablePacketsToProcess []receivedPacket
clientHelloWritten <-chan *wire.TransportParameters clientHelloWritten <-chan *wire.TransportParameters
earlyConnReadyChan chan struct{} earlyConnReadyChan chan struct{}
@ -199,6 +194,7 @@ type connection struct {
versionNegotiated bool versionNegotiated bool
receivedFirstPacket bool receivedFirstPacket bool
// the minimum of the max_idle_timeout values advertised by both endpoints
idleTimeout time.Duration idleTimeout time.Duration
creationTime time.Time creationTime time.Time
// The idle timeout is set based on the max of the time we received the last packet... // The idle timeout is set based on the max of the time we received the last packet...
@ -297,6 +293,7 @@ var newConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
@ -353,7 +350,7 @@ var newConnection = func(
s.version, s.version,
) )
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s return s
@ -418,6 +415,7 @@ var newClientConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
@ -471,7 +469,7 @@ var newClientConnection = func(
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 { if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName s.tokenStoreKey = tlsConf.ServerName
} else { } else {
@ -512,7 +510,7 @@ func (s *connection) preSetup() {
s.perspective, s.perspective,
) )
s.framer = newFramer(s.streamsMap) s.framer = newFramer(s.streamsMap)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
@ -665,7 +663,7 @@ runLoop:
} else { } else {
idleTimeoutStartTime := s.idleTimeoutStartTime() idleTimeoutStartTime := s.idleTimeoutStartTime()
if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) ||
(s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { (s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) {
s.destroyImpl(qerr.ErrIdleTimeout) s.destroyImpl(qerr.ErrIdleTimeout)
continue continue
} }
@ -677,7 +675,7 @@ runLoop:
sendQueueAvailable = s.sendQueue.Available() sendQueueAvailable = s.sendQueue.Available()
continue continue
} }
if err := s.sendPackets(); err != nil { if err := s.triggerSending(); err != nil {
s.closeLocal(err) s.closeLocal(err)
} }
if s.sendQueue.WouldBlock() { if s.sendQueue.WouldBlock() {
@ -689,12 +687,12 @@ runLoop:
s.cryptoStreamHandler.Close() s.cryptoStreamHandler.Close()
<-handshaking <-handshaking
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr) s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
s.tracer.Close() s.tracer.Close()
} }
s.logger.Infof("Connection %s closed.", s.logID) s.logger.Infof("Connection %s closed.", s.logID)
s.sendQueue.Close()
s.timer.Stop() s.timer.Stop()
return closeErr.err return closeErr.err
} }
@ -723,13 +721,20 @@ func (s *connection) ConnectionState() ConnectionState {
return s.connState return s.connState
} }
// Time when the connection should time out
func (s *connection) nextIdleTimeoutTime() time.Time {
idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3)
return s.idleTimeoutStartTime().Add(idleTimeout)
}
// Time when the next keep-alive packet should be sent. // Time when the next keep-alive packet should be sent.
// It returns a zero time if no keep-alive should be sent. // It returns a zero time if no keep-alive should be sent.
func (s *connection) nextKeepAliveTime() time.Time { func (s *connection) nextKeepAliveTime() time.Time {
if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() {
return time.Time{} return time.Time{}
} }
return s.lastPacketReceivedTime.Add(s.keepAliveInterval) keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2)
return s.lastPacketReceivedTime.Add(keepAliveInterval)
} }
func (s *connection) maybeResetTimer() { func (s *connection) maybeResetTimer() {
@ -743,7 +748,7 @@ func (s *connection) maybeResetTimer() {
if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() {
deadline = keepAliveTime deadline = keepAliveTime
} else { } else {
deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) deadline = s.nextIdleTimeoutTime()
} }
} }
@ -800,25 +805,16 @@ func (s *connection) handleHandshakeConfirmed() {
s.sentPacketHandler.SetHandshakeConfirmed() s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed()
if !s.config.DisablePathMTUDiscovery { if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 { if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount maxPacketSize = protocol.MaxByteCount
} }
maxPacketSize = utils.Min(maxPacketSize, protocol.MaxPacketBufferSize) s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
getMaxPacketSize(s.conn.RemoteAddr()),
maxPacketSize,
func(size protocol.ByteCount) {
s.sentPacketHandler.SetMaxDatagramSize(size)
s.packer.SetMaxPacketSize(size)
},
)
} }
} }
func (s *connection) handlePacketImpl(rp *receivedPacket) bool { func (s *connection) handlePacketImpl(rp receivedPacket) bool {
s.sentPacketHandler.ReceivedBytes(rp.Size()) s.sentPacketHandler.ReceivedBytes(rp.Size())
if wire.IsVersionNegotiationPacket(rp.data) { if wire.IsVersionNegotiationPacket(rp.data) {
@ -834,7 +830,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
for len(data) > 0 { for len(data) > 0 {
var destConnID protocol.ConnectionID var destConnID protocol.ConnectionID
if counter > 0 { if counter > 0 {
p = p.Clone() p = *(p.Clone())
p.data = data p.data = data
var err error var err error
@ -907,7 +903,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
return processed return processed
} }
func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -958,7 +954,7 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID proto
return true return true
} }
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -1015,7 +1011,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
return true return true
} }
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err { switch err {
case handshake.ErrKeysDropped: case handshake.ErrKeysDropped:
if s.tracer != nil { if s.tracer != nil {
@ -1117,7 +1113,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
return true return true
} }
func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.tracer != nil { if s.tracer != nil {
@ -1261,7 +1257,11 @@ func (s *connection) handleFrames(
) (isAckEliciting bool, _ error) { ) (isAckEliciting bool, _ error) {
// Only used for tracing. // Only used for tracing.
// If we're not tracing, this slice will always remain empty. // If we're not tracing, this slice will always remain empty.
var frames []wire.Frame var frames []logging.Frame
if log != nil {
frames = make([]logging.Frame, 0, 4)
}
var handleErr error
for len(data) > 0 { for len(data) > 0 {
l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
if err != nil { if err != nil {
@ -1274,27 +1274,27 @@ func (s *connection) handleFrames(
if ackhandler.IsFrameAckEliciting(frame) { if ackhandler.IsFrameAckEliciting(frame) {
isAckEliciting = true isAckEliciting = true
} }
// Only process frames now if we're not logging. if log != nil {
// If we're logging, we need to make sure that the packet_received event is logged first. frames = append(frames, logutils.ConvertFrame(frame))
if log == nil { }
if err := s.handleFrame(frame, encLevel, destConnID); err != nil { // An error occurred handling a previous frame.
// Don't handle the current frame.
if handleErr != nil {
continue
}
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
if log == nil {
return false, err return false, err
} }
} else { // If we're logging, we need to keep parsing (but not handling) all frames.
frames = append(frames, frame) handleErr = err
} }
} }
if log != nil { if log != nil {
fs := make([]logging.Frame, len(frames)) log(frames)
for i, frame := range frames { if handleErr != nil {
fs[i] = logutils.ConvertFrame(frame) return false, handleErr
}
log(fs)
for _, frame := range frames {
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return false, err
}
} }
} }
return return
@ -1310,7 +1310,6 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel
err = s.handleStreamFrame(frame) err = s.handleStreamFrame(frame)
case *wire.AckFrame: case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel) err = s.handleAckFrame(frame, encLevel)
wire.PutAckFrame(frame)
case *wire.ConnectionCloseFrame: case *wire.ConnectionCloseFrame:
s.handleConnectionCloseFrame(frame) s.handleConnectionCloseFrame(frame)
case *wire.ResetStreamFrame: case *wire.ResetStreamFrame:
@ -1349,7 +1348,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel
} }
// handlePacket is called by the server with a new packet // handlePacket is called by the server with a new packet
func (s *connection) handlePacket(p *receivedPacket) { func (s *connection) handlePacket(p receivedPacket) {
// Discard packets once the amount of queued packets is larger than // Discard packets once the amount of queued packets is larger than
// the channel size, protocol.MaxConnUnprocessedPackets // the channel size, protocol.MaxConnUnprocessedPackets
select { select {
@ -1723,7 +1722,6 @@ func (s *connection) applyTransportParameters() {
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
s.streamsMap.UpdateLimits(params) s.streamsMap.UpdateLimits(params)
s.packer.HandleTransportParameters(params)
s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.connFlowController.UpdateSendWindow(params.InitialMaxData)
s.rttStats.SetMaxAckDelay(params.MaxAckDelay) s.rttStats.SetMaxAckDelay(params.MaxAckDelay)
@ -1738,75 +1736,208 @@ func (s *connection) applyTransportParameters() {
} }
} }
func (s *connection) sendPackets() error { func (s *connection) triggerSending() error {
s.pacingDeadline = time.Time{} s.pacingDeadline = time.Time{}
now := time.Now()
var sentPacket bool // only used in for packets sent in send mode SendAny sendMode := s.sentPacketHandler.SendMode(now)
for { //nolint:exhaustive // No need to handle pacing limited here.
sendMode := s.sentPacketHandler.SendMode() switch sendMode {
if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { case ackhandler.SendAny:
deadline := s.sentPacketHandler.TimeUntilSend() return s.sendPackets(now)
if deadline.IsZero() { case ackhandler.SendNone:
deadline = deadlineSendImmediately return nil
} case ackhandler.SendPacingLimited:
s.pacingDeadline = deadline deadline := s.sentPacketHandler.TimeUntilSend()
// Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). if deadline.IsZero() {
// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) deadline = deadlineSendImmediately
// sends enough ACKs to allow its peer to utilize the bandwidth.
if sentPacket {
return nil
}
sendMode = ackhandler.SendAck
} }
switch sendMode { s.pacingDeadline = deadline
case ackhandler.SendNone: // Allow sending of an ACK if we're pacing limit.
// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate)
// sends enough ACKs to allow its peer to utilize the bandwidth.
fallthrough
case ackhandler.SendAck:
// We can at most send a single ACK only packet.
// There will only be a new ACK after receiving new packets.
// SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer.
return s.maybeSendAckOnlyPacket(now)
case ackhandler.SendPTOInitial:
if err := s.sendProbePacket(protocol.EncryptionInitial, now); err != nil {
return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil return nil
case ackhandler.SendAck: }
// If we already sent packets, and the send mode switches to SendAck, return s.triggerSending()
// as we've just become congestion limited. case ackhandler.SendPTOHandshake:
// There's no need to try to send an ACK at this moment. if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil {
if sentPacket { return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil
}
return s.triggerSending()
case ackhandler.SendPTOAppData:
if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil {
return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil
}
return s.triggerSending()
default:
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
}
}
func (s *connection) sendPackets(now time.Time) error {
// Path MTU Discovery
// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
// MTU probe packets per connection.
if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := s.mtuDiscoverer.GetPing()
p, buf, err := s.packer.PackMTUProbePacket(ping, size, s.version)
if err != nil {
return err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, buf.Len())
// This is kind of a hack. We need to trigger sending again somehow.
s.pacingDeadline = deadlineSendImmediately
return nil
}
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
}
s.windowUpdateQueue.QueueAll()
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil || packet == nil {
return err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
} else if sendMode == ackhandler.SendAny {
s.pacingDeadline = deadlineSendImmediately
}
return nil
}
if s.conn.capabilities().GSO {
return s.sendPacketsWithGSO(now)
}
return s.sendPacketsWithoutGSO(now)
}
func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil return nil
} }
// We can at most send a single ACK only packet. return err
// There will only be a new ACK after receiving new packets. }
// SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer.
return s.maybeSendAckOnlyPacket() s.sendQueue.Send(buf, buf.Len())
case ackhandler.SendPTOInitial:
if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { if s.sendQueue.WouldBlock() {
return err return nil
} }
case ackhandler.SendPTOHandshake: sendMode := s.sentPacketHandler.SendMode(now)
if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { if sendMode == ackhandler.SendPacingLimited {
return err s.resetPacingDeadline()
} return nil
case ackhandler.SendPTOAppData: }
if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { if sendMode != ackhandler.SendAny {
return err return nil
}
case ackhandler.SendAny:
sent, err := s.sendPacket()
if err != nil || !sent {
return err
}
sentPacket = true
default:
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
} }
// Prioritize receiving of packets over sending out more packets. // Prioritize receiving of packets over sending out more packets.
if len(s.receivedPackets) > 0 { if len(s.receivedPackets) > 0 {
s.pacingDeadline = deadlineSendImmediately s.pacingDeadline = deadlineSendImmediately
return nil return nil
} }
if s.sendQueue.WouldBlock() {
return nil
}
} }
} }
func (s *connection) maybeSendAckOnlyPacket() error { func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
for {
var dontSendMore bool
size, err := s.appendPacket(buf, maxSize, now)
if err != nil {
if err != errNothingToPack {
return err
}
if buf.Len() == 0 {
buf.Release()
return nil
}
dontSendMore = true
}
if !dontSendMore {
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
}
if sendMode != ackhandler.SendAny {
dontSendMore = true
}
}
// Append another packet if
// 1. The congestion controller and pacer allow sending more
// 2. The last packet appended was a full-size packet
// 3. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() {
continue
}
s.sendQueue.Send(buf, maxSize)
if dontSendMore {
return nil
}
if s.sendQueue.WouldBlock() {
return nil
}
// Prioritize receiving of packets over sending out more packets.
if len(s.receivedPackets) > 0 {
s.pacingDeadline = deadlineSendImmediately
return nil
}
buf = getLargePacketBuffer()
}
}
func (s *connection) resetPacingDeadline() {
deadline := s.sentPacketHandler.TimeUntilSend()
if deadline.IsZero() {
deadline = deadlineSendImmediately
}
s.pacingDeadline = deadline
}
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed { if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(true, s.version) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1817,20 +1948,20 @@ func (s *connection) maybeSendAckOnlyPacket() error {
return nil return nil
} }
now := time.Now() p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buffer, err := s.packer.PackPacket(true, now, s.version)
if err != nil { if err != nil {
if err == errNothingToPack { if err == errNothingToPack {
return nil return nil
} }
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now) s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, buf.Len())
return nil return nil
} }
func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time.Time) error {
// Queue probe packets until we actually send out a packet, // Queue probe packets until we actually send out a packet,
// or until there are no more packets to queue. // or until there are no more packets to queue.
var packet *coalescedPacket var packet *coalescedPacket
@ -1839,7 +1970,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
break break
} }
var err error var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1848,19 +1979,9 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
} }
} }
if packet == nil { if packet == nil {
//nolint:exhaustive // Cannot send probe packets for 0-RTT. s.retransmissionQueue.AddPing(encLevel)
switch encLevel {
case protocol.EncryptionInitial:
s.retransmissionQueue.AddInitial(&wire.PingFrame{})
case protocol.EncryptionHandshake:
s.retransmissionQueue.AddHandshake(&wire.PingFrame{})
case protocol.Encryption1RTT:
s.retransmissionQueue.AddAppData(&wire.PingFrame{})
default:
panic("unexpected encryption level")
}
var err error var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1868,55 +1989,35 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
} }
s.sendPackedCoalescedPacket(packet, time.Now()) s.sendPackedCoalescedPacket(packet, now)
return nil return nil
} }
func (s *connection) sendPacket() (bool, error) { // appendPacket appends a new packet to the given packetBuffer.
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { // If there was nothing to pack, the returned size is 0.
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) {
} startLen := buf.Len()
s.windowUpdateQueue.QueueAll() p, err := s.packer.AppendPacket(buf, maxSize, s.version)
now := time.Now()
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.version)
if err != nil || packet == nil {
return false, err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
return true, nil
} else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := s.mtuDiscoverer.GetPing()
p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now, s.version)
if err != nil {
return false, err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return true, nil
}
p, buffer, err := s.packer.PackPacket(false, now, s.version)
if err != nil { if err != nil {
if err == errNothingToPack { return 0, err
return false, nil
}
return false, err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) size := buf.Len() - startLen
s.sendPackedShortHeaderPacket(buffer, p.Packet, now) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false)
return true, nil s.registerPackedShortHeaderPacket(p, now)
return size, nil
} }
func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p) largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(buffer)
} }
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) {
@ -1925,16 +2026,24 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p.Packet) largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
} }
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer) s.sendQueue.Send(packet.buffer, packet.buffer.Len())
} }
func (s *connection) sendConnectionClose(e error) ([]byte, error) { func (s *connection) sendConnectionClose(e error) ([]byte, error) {
@ -1943,20 +2052,20 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) { if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.version) packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
} else if errors.As(e, &applicationErr) { } else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.version) packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
} else { } else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError, ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.version) }, s.mtuDiscoverer.CurrentSize(), s.version)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logCoalescedPacket(packet) s.logCoalescedPacket(packet)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data) return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len())
} }
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
@ -1988,7 +2097,8 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
func (s *connection) logShortHeaderPacket( func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame, ackFrame *wire.AckFrame,
frames []*ackhandler.Frame, frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber, pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
@ -2004,17 +2114,23 @@ func (s *connection) logShortHeaderPacket(
if ackFrame != nil { if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true) wire.LogFrame(s.logger, ackFrame, true)
} }
for _, frame := range frames { for _, f := range frames {
wire.LogFrame(s.logger, frame.Frame, true) wire.LogFrame(s.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(s.logger, f.Frame, true)
} }
} }
// tracing // tracing
if s.tracer != nil { if s.tracer != nil {
fs := make([]logging.Frame, 0, len(frames)) fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames { for _, f := range frames {
fs = append(fs, logutils.ConvertFrame(f.Frame)) fs = append(fs, logutils.ConvertFrame(f.Frame))
} }
for _, f := range streamFrames {
fs = append(fs, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame var ack *logging.AckFrame
if ackFrame != nil { if ackFrame != nil {
ack = logutils.ConvertAckFrame(ackFrame) ack = logutils.ConvertAckFrame(ackFrame)
@ -2042,6 +2158,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
packet.shortHdrPacket.DestConnID, packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack, packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames, packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase, packet.shortHdrPacket.KeyPhase,
@ -2060,7 +2177,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
s.logLongHeaderPacket(p) s.logLongHeaderPacket(p)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true)
} }
} }
@ -2121,7 +2238,7 @@ func (s *connection) scheduleSending() {
// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys.
// The logging.PacketType is only used for logging purposes. // The logging.PacketType is only used for logging purposes.
func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) {
if s.handshakeComplete { if s.handshakeComplete {
panic("shouldn't queue undecryptable packets after handshake completion") panic("shouldn't queue undecryptable packets after handshake completion")
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
) )
@ -14,10 +15,10 @@ type framer interface {
HasData() bool HasData() bool
QueueControlFrame(wire.Frame) QueueControlFrame(wire.Frame)
AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID) AddActiveStream(protocol.StreamID)
AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error Handle0RTTRejection() error
} }
@ -28,7 +29,7 @@ type framerI struct {
streamGetter streamGetter streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{} activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID streamQueue ringbuffer.RingBuffer[protocol.StreamID]
controlFrameMutex sync.Mutex controlFrameMutex sync.Mutex
controlFrames []wire.Frame controlFrames []wire.Frame
@ -45,7 +46,7 @@ func newFramer(streamGetter streamGetter) framer {
func (f *framerI) HasData() bool { func (f *framerI) HasData() bool {
f.mutex.Lock() f.mutex.Lock()
hasData := len(f.streamQueue) > 0 hasData := !f.streamQueue.Empty()
f.mutex.Unlock() f.mutex.Unlock()
if hasData { if hasData {
return true return true
@ -62,7 +63,7 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Unlock() f.controlFrameMutex.Unlock()
} }
func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount var length protocol.ByteCount
f.controlFrameMutex.Lock() f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 { for len(f.controlFrames) > 0 {
@ -71,9 +72,7 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco
if length+frameLen > maxLen { if length+frameLen > maxLen {
break break
} }
af := ackhandler.GetFrame() frames = append(frames, ackhandler.Frame{Frame: frame})
af.Frame = frame
frames = append(frames, af)
length += frameLen length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
} }
@ -84,24 +83,23 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco
func (f *framerI) AddActiveStream(id protocol.StreamID) { func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock() f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok { if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id) f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{} f.activeStreams[id] = struct{}{}
} }
f.mutex.Unlock() f.mutex.Unlock()
} }
func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount var length protocol.ByteCount
var lastFrame *ackhandler.Frame
f.mutex.Lock() f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue) numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ { for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen { if protocol.MinStreamFrameSize+length > maxLen {
break break
} }
id := f.streamQueue[0] id := f.streamQueue.PopFront()
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway. // This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there. // The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id) str, err := f.streamGetter.GetOrOpenSendStream(id)
@ -115,28 +113,27 @@ func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol
// Therefore, we can pretend to have more bytes available when popping // Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set). // the STREAM frame (which will always have the DataLen set).
remainingLen += quicvarint.Len(uint64(remainingLen)) remainingLen += quicvarint.Len(uint64(remainingLen))
frame, hasMoreData := str.popStreamFrame(remainingLen, v) frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
if hasMoreData { // put the stream back in the queue (at the end) if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id) f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active any more } else { // no more data to send. Stream is not active
delete(f.activeStreams, id) delete(f.activeStreams, id)
} }
// The frame can be nil // The frame can be "nil"
// * if the receiveStream was canceled after it said it had data // * if the receiveStream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame // * the remaining size doesn't allow us to add another STREAM frame
if frame == nil { if !ok {
continue continue
} }
frames = append(frames, frame) frames = append(frames, frame)
length += frame.Length(v) length += frame.Frame.Length(v)
lastFrame = frame
} }
f.mutex.Unlock() f.mutex.Unlock()
if lastFrame != nil { if len(frames) > startLen {
lastFrameLen := lastFrame.Length(v) l := frames[len(frames)-1].Frame.Length(v)
// account for the smaller size of the last STREAM frame // account for the smaller size of the last STREAM frame
lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false frames[len(frames)-1].Frame.DataLenPresent = false
length += lastFrame.Length(v) - lastFrameLen length += frames[len(frames)-1].Frame.Length(v) - l
} }
return frames, length return frames, length
} }
@ -146,7 +143,7 @@ func (f *framerI) Handle0RTTRejection() error {
defer f.mutex.Unlock() defer f.mutex.Unlock()
f.controlFrameMutex.Lock() f.controlFrameMutex.Lock()
f.streamQueue = f.streamQueue[:0] f.streamQueue.Clear()
for id := range f.activeStreams { for id := range f.activeStreams {
delete(f.activeStreams, id) delete(f.activeStreams, id)
} }

View file

@ -262,13 +262,26 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error {
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener. // and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener. // Closing the server does close the listener.
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln QUICEarlyListener) error { func (s *Server) ServeListener(ln QUICEarlyListener) error {
if err := s.addListener(&ln); err != nil { if err := s.addListener(&ln); err != nil {
return err return err
} }
err := s.serveListener(ln) defer s.removeListener(&ln)
s.removeListener(&ln) for {
return err conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
return http.ErrServerClosed
}
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
s.logger.Debugf(err.Error())
}
}()
}
} }
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
@ -310,26 +323,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
if err != nil { if err != nil {
return err return err
} }
if err := s.addListener(&ln); err != nil { return s.ServeListener(ln)
return err
}
err = s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveListener(ln QUICEarlyListener) error {
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
s.logger.Debugf(err.Error())
}
}()
}
} }
func extractPort(addr string) (int, error) { func extractPort(addr string) (int, error) {

View file

@ -276,17 +276,21 @@ type Config struct {
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm // If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxStreamReceiveWindow. // will increase the window up to MaxStreamReceiveWindow.
// If this value is zero, it will default to 512 KB. // If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialStreamReceiveWindow uint64 InitialStreamReceiveWindow uint64
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 6 MB. // If this value is zero, it will default to 6 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxStreamReceiveWindow uint64 MaxStreamReceiveWindow uint64
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm // If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxConnectionReceiveWindow. // will increase the window up to MaxConnectionReceiveWindow.
// If this value is zero, it will default to 512 KB. // If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialConnectionReceiveWindow uint64 InitialConnectionReceiveWindow uint64
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 15 MB. // If this value is zero, it will default to 15 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxConnectionReceiveWindow uint64 MaxConnectionReceiveWindow uint64
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts // AllowConnectionWindowIncrease is called every time the connection flow controller attempts
// to increase the connection flow control window. // to increase the connection flow control window.
@ -296,22 +300,23 @@ type Config struct {
// in this callback. // in this callback.
AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// Values above 2^60 are invalid.
// If not set, it will default to 100. // If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams. // If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingStreams int64 MaxIncomingStreams int64
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// Values above 2^60 are invalid.
// If not set, it will default to 100. // If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams. // If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingUniStreams int64 MaxIncomingUniStreams int64
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller). // every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration KeepAlivePeriod time.Duration
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Note that if Path MTU discovery is causing issues on your system, please open a new issue // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
DisablePathMTUDiscovery bool DisablePathMTUDiscovery bool
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
// This can be useful if version information is exchanged out-of-band. // This can be useful if version information is exchanged out-of-band.
@ -331,7 +336,13 @@ type ClientHelloInfo struct {
// ConnectionState records basic details about a QUIC connection // ConnectionState records basic details about a QUIC connection
type ConnectionState struct { type ConnectionState struct {
TLS handshake.ConnectionState // TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS handshake.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendMessage and ReceiveMessage methods on the Connection.
SupportsDatagrams bool SupportsDatagrams bool
Version VersionNumber // Version is the QUIC version of the QUIC connection.
Version VersionNumber
} }

View file

@ -10,7 +10,7 @@ func IsFrameAckEliciting(f wire.Frame) bool {
} }
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting. // HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
func HasAckElicitingFrames(fs []*Frame) bool { func HasAckElicitingFrames(fs []Frame) bool {
for _, f := range fs { for _, f := range fs {
if IsFrameAckEliciting(f.Frame) { if IsFrameAckEliciting(f.Frame) {
return true return true

View file

@ -1,29 +1,21 @@
package ackhandler package ackhandler
import ( import (
"sync"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
// FrameHandler handles the acknowledgement and the loss of a frame.
type FrameHandler interface {
OnAcked(wire.Frame)
OnLost(wire.Frame)
}
type Frame struct { type Frame struct {
wire.Frame // nil if the frame has already been acknowledged in another packet Frame wire.Frame // nil if the frame has already been acknowledged in another packet
OnLost func(wire.Frame) Handler FrameHandler
OnAcked func(wire.Frame)
} }
var framePool = sync.Pool{New: func() any { return &Frame{} }} type StreamFrame struct {
Frame *wire.StreamFrame
func GetFrame() *Frame { Handler FrameHandler
f := framePool.Get().(*Frame)
f.OnLost = nil
f.OnAcked = nil
return f
}
func putFrame(f *Frame) {
f.Frame = nil
f.OnLost = nil
f.OnAcked = nil
framePool.Put(f)
} }

View file

@ -10,20 +10,20 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(packet *Packet) SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool)
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) // ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount) ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
ResetForRetry() error ResetForRetry() error
SetHandshakeConfirmed() SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent. // The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode SendMode(now time.Time) SendMode
// TimeUntilSend is the time when the next packet should be sent. // TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets. // It is used for pacing packets.
TimeUntilSend() time.Time TimeUntilSend() time.Time
// HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment.
HasPacingBudget() bool
SetMaxDatagramSize(count protocol.ByteCount) SetMaxDatagramSize(count protocol.ByteCount)
// only to be called once the handshake is complete // only to be called once the handshake is complete

View file

@ -8,10 +8,11 @@ import (
) )
// A Packet is a packet // A Packet is a packet
type Packet struct { type packet struct {
SendTime time.Time SendTime time.Time
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
Frames []*Frame StreamFrames []StreamFrame
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
@ -23,15 +24,16 @@ type Packet struct {
skippedPacket bool skippedPacket bool
} }
func (p *Packet) outstanding() bool { func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
} }
var packetPool = sync.Pool{New: func() any { return &Packet{} }} var packetPool = sync.Pool{New: func() any { return &packet{} }}
func GetPacket() *Packet { func getPacket() *packet {
p := packetPool.Get().(*Packet) p := packetPool.Get().(*packet)
p.PacketNumber = 0 p.PacketNumber = 0
p.StreamFrames = nil
p.Frames = nil p.Frames = nil
p.LargestAcked = 0 p.LargestAcked = 0
p.Length = 0 p.Length = 0
@ -46,10 +48,8 @@ func GetPacket() *Packet {
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost). // We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. // This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
func putPacket(p *Packet) { func putPacket(p *packet) {
for _, f := range p.Frames {
putFrame(f)
}
p.Frames = nil p.Frames = nil
p.StreamFrames = nil
packetPool.Put(p) packetPool.Put(p)
} }

View file

@ -7,7 +7,10 @@ import (
type packetNumberGenerator interface { type packetNumberGenerator interface {
Peek() protocol.PacketNumber Peek() protocol.PacketNumber
Pop() protocol.PacketNumber // Pop pops the packet number.
// It reports if the packet number (before the one just popped) was skipped.
// It never skips more than one packet number in a row.
Pop() (skipped bool, _ protocol.PacketNumber)
} }
type sequentialPacketNumberGenerator struct { type sequentialPacketNumberGenerator struct {
@ -24,10 +27,10 @@ func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next return p.next
} }
func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next next := p.next
p.next++ p.next++
return next return false, next
} }
// The skippingPacketNumberGenerator generates the packet number for the next packet // The skippingPacketNumberGenerator generates the packet number for the next packet
@ -56,21 +59,26 @@ func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol
} }
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
if p.next == p.nextToSkip {
return p.next + 1
}
return p.next return p.next
} }
func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next next := p.next
p.next++ // generate a new packet number for the next packet
if p.next == p.nextToSkip { if p.next == p.nextToSkip {
p.next++ next++
p.next += 2
p.generateNewSkip() p.generateNewSkip()
return true, next
} }
return next p.next++ // generate a new packet number for the next packet
return false, next
} }
func (p *skippingPacketNumberGenerator) generateNewSkip() { func (p *skippingPacketNumberGenerator) generateNewSkip() {
// make sure that there are never two consecutive packet numbers that are skipped // make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
p.period = utils.Min(2*p.period, p.maxPeriod) p.period = utils.Min(2*p.period, p.maxPeriod)
} }

View file

@ -169,16 +169,18 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
} }
} }
ack := wire.GetAckFrame() // This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime))
ack.ECT0 = h.ect0 ack.ECT0 = h.ect0
ack.ECT1 = h.ect1 ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
if h.lastAck != nil {
wire.PutAckFrame(h.lastAck)
}
h.lastAck = ack h.lastAck = ack
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
h.ackQueued = false h.ackQueued = false

View file

@ -16,6 +16,10 @@ const (
SendPTOHandshake SendPTOHandshake
// SendPTOAppData means that an Application data probe packet should be sent // SendPTOAppData means that an Application data probe packet should be sent
SendPTOAppData SendPTOAppData
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
// but will do in a little while.
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
SendPacingLimited
// SendAny means that any packet should be sent // SendAny means that any packet should be sent
SendAny SendAny
) )
@ -34,6 +38,8 @@ func (s SendMode) String() string {
return "pto (Application Data)" return "pto (Application Data)"
case SendAny: case SendAny:
return "any" return "any"
case SendPacingLimited:
return "pacing limited"
default: default:
return fmt.Sprintf("invalid send mode: %d", s) return fmt.Sprintf("invalid send mode: %d", s)
} }

View file

@ -38,7 +38,7 @@ type packetNumberSpace struct {
largestSent protocol.PacketNumber largestSent protocol.PacketNumber
} }
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
var pns packetNumberGenerator var pns packetNumberGenerator
if skipPNs { if skipPNs {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
@ -46,7 +46,7 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStat
pns = newSequentialPacketNumberGenerator(initialPN) pns = newSequentialPacketNumberGenerator(initialPN)
} }
return &packetNumberSpace{ return &packetNumberSpace{
history: newSentPacketHistory(rttStats), history: newSentPacketHistory(),
pns: pns, pns: pns,
largestSent: protocol.InvalidPacketNumber, largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
@ -75,7 +75,7 @@ type sentPacketHandler struct {
// Only applies to the application-data packet number space. // Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber lowestNotConfirmedAcked protocol.PacketNumber
ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
bytesInFlight protocol.ByteCount bytesInFlight protocol.ByteCount
@ -125,9 +125,9 @@ func newSentPacketHandler(
return &sentPacketHandler{ return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false, rttStats), initialPackets: newPacketNumberSpace(initialPN, false),
handshakePackets: newPacketNumberSpace(0, false, rttStats), handshakePackets: newPacketNumberSpace(0, false),
appDataPackets: newPacketNumberSpace(0, true, rttStats), appDataPackets: newPacketNumberSpace(0, true),
rttStats: rttStats, rttStats: rttStats,
congestion: congestion, congestion: congestion,
perspective: pers, perspective: pers,
@ -146,7 +146,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
h.dropPackets(encLevel) h.dropPackets(encLevel)
} }
func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight { if p.Length > h.bytesInFlight {
panic("negative bytes_in_flight") panic("negative bytes_in_flight")
@ -165,7 +165,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight // remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.history.Iterate(func(p *Packet) (bool, error) { pnSpace.history.Iterate(func(p *packet) (bool, error) {
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
return true, nil return true, nil
}) })
@ -182,8 +182,8 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// and not when the client drops 0-RTT keys when the handshake completes. // and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid. // When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight. // Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption0RTT { if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false, nil return false, nil
} }
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
@ -228,26 +228,64 @@ func (h *sentPacketHandler) packetsInFlight() int {
return packetsInFlight return packetsInFlight
} }
func (h *sentPacketHandler) SentPacket(p *Packet) { func (h *sentPacketHandler) SentPacket(
h.bytesSent += p.Length t time.Time,
pn, largestAcked protocol.PacketNumber,
streamFrames []StreamFrame,
frames []Frame,
encLevel protocol.EncryptionLevel,
size protocol.ByteCount,
isPathMTUProbePacket bool,
) {
h.bytesSent += size
// For the client, drop the Initial packet number space when the first Handshake packet is sent. // For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && p.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial) h.dropPackets(protocol.EncryptionInitial)
} }
isAckEliciting := h.sentPacketImpl(p)
if isAckEliciting { pnSpace := h.getPacketNumberSpace(encLevel)
h.getPacketNumberSpace(p.EncryptionLevel).history.SentAckElicitingPacket(p) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
} else { for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ {
h.getPacketNumberSpace(p.EncryptionLevel).history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) h.logger.Debugf("Skipping packet number %d", p)
putPacket(p) }
p = nil //nolint:ineffassign // This is just to be on the safe side.
} }
if h.tracer != nil && isAckEliciting {
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
h.setLossDetectionTimer()
}
return
}
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.LargestAcked = largestAcked
p.StreamFrames = streamFrames
p.Frames = frames
p.IsPathMTUProbePacket = isPathMTUProbePacket
p.includedInBytesInFlight = true
pnSpace.history.SentAckElicitingPacket(p)
if h.tracer != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
if isAckEliciting || !h.peerCompletedAddressValidation { h.setLossDetectionTimer()
h.setLossDetectionTimer()
}
} }
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
@ -263,31 +301,6 @@ func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLev
} }
} }
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ {
pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
for p := utils.Max(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
}
pnSpace.largestSent = packet.PacketNumber
isAckEliciting := len(packet.Frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting)
return isAckEliciting
}
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
@ -361,7 +374,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
pnSpace.history.DeleteOldPackets(rcvTime)
h.setLossDetectionTimer() h.setLossDetectionTimer()
return acked1RTTPacket, nil return acked1RTTPacket, nil
} }
@ -371,13 +383,13 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
} }
// Packets are returned in ascending packet number order. // Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
h.ackedPackets = h.ackedPackets[:0] h.ackedPackets = h.ackedPackets[:0]
ackRangeIndex := 0 ackRangeIndex := 0
lowestAcked := ack.LowestAcked() lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked() largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
// Ignore packets below the lowest acked // Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked { if p.PacketNumber < lowestAcked {
return true, nil return true, nil
@ -425,8 +437,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
} }
for _, f := range p.Frames { for _, f := range p.Frames {
if f.OnAcked != nil { if f.Handler != nil {
f.OnAcked(f.Frame) f.Handler.OnAcked(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
} }
} }
if err := pnSpace.history.Remove(p.PacketNumber); err != nil { if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
@ -587,30 +604,31 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
lostSendTime := now.Add(-lossDelay) lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *Packet) (bool, error) { return pnSpace.history.Iterate(func(p *packet) (bool, error) {
if p.PacketNumber > pnSpace.largestAcked { if p.PacketNumber > pnSpace.largestAcked {
return false, nil return false, nil
} }
if p.declaredLost || p.skippedPacket {
return true, nil
}
var packetLost bool var packetLost bool
if p.SendTime.Before(lostSendTime) { if p.SendTime.Before(lostSendTime) {
packetLost = true packetLost = true
if h.logger.Debug() { if !p.skippedPacket {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) if h.logger.Debug() {
} h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
if h.tracer != nil { }
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
}
} }
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true packetLost = true
if h.logger.Debug() { if !p.skippedPacket {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) if h.logger.Debug() {
} h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
if h.tracer != nil { }
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
}
} }
} else if pnSpace.lossTime.IsZero() { } else if pnSpace.lossTime.IsZero() {
// Note: This conditional is only entered once per call // Note: This conditional is only entered once per call
@ -621,12 +639,14 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
pnSpace.lossTime = lossTime pnSpace.lossTime = lossTime
} }
if packetLost { if packetLost {
p = pnSpace.history.DeclareLost(p) pnSpace.history.DeclareLost(p.PacketNumber)
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted if !p.skippedPacket {
h.removeFromBytesInFlight(p) // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.queueFramesForRetransmission(p) h.removeFromBytesInFlight(p)
if !p.IsPathMTUProbePacket { h.queueFramesForRetransmission(p)
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) if !p.IsPathMTUProbePacket {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
} }
} }
return true, nil return true, nil
@ -689,7 +709,8 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
h.ptoMode = SendPTOHandshake h.ptoMode = SendPTOHandshake
case protocol.Encryption1RTT: case protocol.Encryption1RTT:
// skip a packet number in order to elicit an immediate ACK // skip a packet number in order to elicit an immediate ACK
_ = h.PopPacketNumber(protocol.Encryption1RTT) pn := h.PopPacketNumber(protocol.Encryption1RTT)
h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn)
h.ptoMode = SendPTOAppData h.ptoMode = SendPTOAppData
default: default:
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
@ -703,23 +724,25 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
var lowestUnacked protocol.PacketNumber
if p := pnSpace.history.FirstOutstanding(); p != nil {
lowestUnacked = p.PacketNumber
} else {
lowestUnacked = pnSpace.largestAcked + 1
}
pn := pnSpace.pns.Peek() pn := pnSpace.pns.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) // See section 17.1 of RFC 9000.
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
} }
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
return h.getPacketNumberSpace(encLevel).pns.Pop() pnSpace := h.getPacketNumberSpace(encLevel)
skipped, pn := pnSpace.pns.Pop()
if skipped {
skippedPN := pn - 1
pnSpace.history.SkippedPacket(skippedPN)
if h.logger.Debug() {
h.logger.Debugf("Skipping packet number %d", skippedPN)
}
}
return pn
} }
func (h *sentPacketHandler) SendMode() SendMode { func (h *sentPacketHandler) SendMode(now time.Time) SendMode {
numTrackedPackets := h.appDataPackets.history.Len() numTrackedPackets := h.appDataPackets.history.Len()
if h.initialPackets != nil { if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len() numTrackedPackets += h.initialPackets.history.Len()
@ -758,6 +781,9 @@ func (h *sentPacketHandler) SendMode() SendMode {
} }
return SendAck return SendAck
} }
if !h.congestion.HasPacingBudget(now) {
return SendPacingLimited
}
return SendAny return SendAny
} }
@ -765,10 +791,6 @@ func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.congestion.TimeUntilSend(h.bytesInFlight) return h.congestion.TimeUntilSend(h.bytesInFlight)
} }
func (h *sentPacketHandler) HasPacingBudget() bool {
return h.congestion.HasPacingBudget()
}
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
h.congestion.SetMaxDatagramSize(s) h.congestion.SetMaxDatagramSize(s)
} }
@ -790,24 +812,32 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel)
// TODO: don't declare the packet lost here. // TODO: don't declare the packet lost here.
// Keep track of acknowledged frames instead. // Keep track of acknowledged frames instead.
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
pnSpace.history.DeclareLost(p) pnSpace.history.DeclareLost(p.PacketNumber)
return true return true
} }
func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
if len(p.Frames) == 0 { if len(p.Frames) == 0 && len(p.StreamFrames) == 0 {
panic("no frames") panic("no frames")
} }
for _, f := range p.Frames { for _, f := range p.Frames {
f.OnLost(f.Frame) if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
} }
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
p.StreamFrames = nil
p.Frames = nil p.Frames = nil
} }
func (h *sentPacketHandler) ResetForRetry() error { func (h *sentPacketHandler) ResetForRetry() error {
h.bytesInFlight = 0 h.bytesInFlight = 0
var firstPacketSendTime time.Time var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
if firstPacketSendTime.IsZero() { if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime firstPacketSendTime = p.SendTime
} }
@ -819,7 +849,7 @@ func (h *sentPacketHandler) ResetForRetry() error {
}) })
// All application data packets sent at this point are 0-RTT packets. // All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them. // In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket { if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p) h.queueFramesForRetransmission(p)
} }
@ -839,8 +869,8 @@ func (h *sentPacketHandler) ResetForRetry() error {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
} }
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
oldAlarm := h.alarm oldAlarm := h.alarm
h.alarm = time.Time{} h.alarm = time.Time{}
if h.tracer != nil { if h.tracer != nil {

View file

@ -2,162 +2,176 @@ package ackhandler
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
) )
type sentPacketHistory struct { type sentPacketHistory struct {
rttStats *utils.RTTStats packets []*packet
outstandingPacketList *list.List[*Packet]
etcPacketList *list.List[*Packet] numOutstanding int
packetMap map[protocol.PacketNumber]*list.Element[*Packet]
highestSent protocol.PacketNumber highestPacketNumber protocol.PacketNumber
} }
var packetElementPool sync.Pool func newSentPacketHistory() *sentPacketHistory {
func init() {
packetElementPool = *list.NewPool[*Packet]()
}
func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory {
return &sentPacketHistory{ return &sentPacketHistory{
rttStats: rttStats, packets: make([]*packet, 0, 32),
outstandingPacketList: list.NewWithPool[*Packet](&packetElementPool), highestPacketNumber: protocol.InvalidPacketNumber,
etcPacketList: list.NewWithPool[*Packet](&packetElementPool),
packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]),
highestSent: protocol.InvalidPacketNumber,
} }
} }
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
h.registerSentPacket(pn, encLevel, t) if h.highestPacketNumber != protocol.InvalidPacketNumber {
if pn != h.highestPacketNumber+1 {
panic("non-sequential packet number use")
}
}
} }
func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) { func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.registerSentPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
})
}
var el *list.Element[*Packet] func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
}
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.highestPacketNumber = p.PacketNumber
h.packets = append(h.packets, p)
if p.outstanding() { if p.outstanding() {
el = h.outstandingPacketList.PushBack(p) h.numOutstanding++
} else {
el = h.etcPacketList.PushBack(p)
} }
h.packetMap[p.PacketNumber] = el
}
func (h *sentPacketHistory) registerSentPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) {
if pn <= h.highestSent {
panic("non-sequential packet number use")
}
// Skipped packet numbers.
for p := h.highestSent + 1; p < pn; p++ {
el := h.etcPacketList.PushBack(&Packet{
PacketNumber: p,
EncryptionLevel: encLevel,
SendTime: t,
skippedPacket: true,
})
h.packetMap[p] = el
}
h.highestSent = pn
} }
// Iterate iterates through all packets. // Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
cont := true for _, p := range h.packets {
outstandingEl := h.outstandingPacketList.Front() if p == nil {
etcEl := h.etcPacketList.Front() continue
var el *list.Element[*Packet]
// whichever has the next packet number is returned first
for cont {
if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) {
el = etcEl
} else {
el = outstandingEl
} }
if el == nil { cont, err := cb(p)
return nil
}
if el == outstandingEl {
outstandingEl = outstandingEl.Next()
} else {
etcEl = etcEl.Next()
}
var err error
cont, err = cb(el.Value)
if err != nil { if err != nil {
return err return err
} }
if !cont {
return nil
}
} }
return nil return nil
} }
// FirstOutstanding returns the first outstanding packet. // FirstOutstanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *Packet { func (h *sentPacketHistory) FirstOutstanding() *packet {
el := h.outstandingPacketList.Front() if !h.HasOutstandingPackets() {
if el == nil {
return nil return nil
} }
return el.Value for _, p := range h.packets {
} if p != nil && p.outstanding() {
return p
func (h *sentPacketHistory) Len() int { }
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
} }
el.List().Remove(el)
delete(h.packetMap, p)
return nil return nil
} }
func (h *sentPacketHistory) HasOutstandingPackets() bool { func (h *sentPacketHistory) Len() int {
return h.outstandingPacketList.Len() > 0 return len(h.packets)
} }
func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
maxAge := 3 * h.rttStats.PTO(false) idx, ok := h.getIndex(pn)
var nextEl *list.Element[*Packet]
// we don't iterate outstandingPacketList, as we should not delete outstanding packets.
// being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes.
for el := h.etcPacketList.Front(); el != nil; el = nextEl {
nextEl = el.Next()
p := el.Value
if p.SendTime.After(now.Add(-maxAge)) {
break
}
delete(h.packetMap, p.PacketNumber)
h.etcPacketList.Remove(el)
}
}
func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet {
el, ok := h.packetMap[p.PacketNumber]
if !ok { if !ok {
return nil return fmt.Errorf("packet %d not found in sent packet history", pn)
} }
el.List().Remove(el) p := h.packets[idx]
p.declaredLost = true if p.outstanding() {
// move it to the correct position in the etc list (based on the packet number) h.numOutstanding--
for el = h.etcPacketList.Back(); el != nil; el = el.Prev() { if h.numOutstanding < 0 {
if el.Value.PacketNumber < p.PacketNumber { panic("negative number of outstanding packets")
break
} }
} }
if el == nil { h.packets[idx] = nil
el = h.etcPacketList.PushFront(p) // clean up all skipped packets directly before this packet number
} else { for idx > 0 {
el = h.etcPacketList.InsertAfter(p, el) idx--
p := h.packets[idx]
if p == nil || !p.skippedPacket {
break
}
h.packets[idx] = nil
}
if idx == 0 {
h.cleanupStart()
}
if len(h.packets) > 0 && h.packets[0] == nil {
panic("remove failed")
}
return nil
}
// getIndex gets the index of packet p in the packets slice.
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
return 0, false
}
first := h.packets[0].PacketNumber
if p < first {
return 0, false
}
index := int(p - first)
if index > len(h.packets)-1 {
return 0, false
}
return index, true
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {
if p != nil {
h.packets = h.packets[i:]
return
}
}
h.packets = h.packets[:0]
}
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
if len(h.packets) == 0 {
return protocol.InvalidPacketNumber
}
return h.packets[0].PacketNumber
}
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
idx, ok := h.getIndex(pn)
if !ok {
return
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
if idx == 0 {
h.cleanupStart()
} }
h.packetMap[p.PacketNumber] = el
return el.Value
} }

View file

@ -120,8 +120,8 @@ func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
return c.pacer.TimeUntilSend() return c.pacer.TimeUntilSend()
} }
func (c *cubicSender) HasPacingBudget() bool { func (c *cubicSender) HasPacingBudget(now time.Time) bool {
return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize return c.pacer.Budget(now) >= c.maxDatagramSize
} }
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {

View file

@ -9,7 +9,7 @@ import (
// A SendAlgorithm performs congestion control // A SendAlgorithm performs congestion control
type SendAlgorithm interface { type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
HasPacingBudget() bool HasPacingBudget(now time.Time) bool
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
CanSend(bytesInFlight protocol.ByteCount) bool CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart() MaybeExitSlowStart()

View file

@ -59,7 +59,10 @@ type StatelessResetToken [16]byte
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. // Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxPacketBufferSize ByteCount = 1452 const MaxPacketBufferSize = 1452
// MaxLargePacketBufferSize is used when using GSO
const MaxLargePacketBufferSize = 20 * 1024
// MinInitialPacketSize is the minimum size an Initial packet is required to have. // MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200 const MinInitialPacketSize = 1200

View file

@ -0,0 +1,86 @@
package ringbuffer
// A RingBuffer is a ring buffer.
// It acts as a heap that doesn't cause any allocations.
type RingBuffer[T any] struct {
ring []T
headPos, tailPos int
full bool
}
// Init preallocs a buffer with a certain size.
func (r *RingBuffer[T]) Init(size int) {
r.ring = make([]T, size)
}
// Len returns the number of elements in the ring buffer.
func (r *RingBuffer[T]) Len() int {
if r.full {
return len(r.ring)
}
if r.tailPos >= r.headPos {
return r.tailPos - r.headPos
}
return r.tailPos - r.headPos + len(r.ring)
}
// Empty says if the ring buffer is empty.
func (r *RingBuffer[T]) Empty() bool {
return !r.full && r.headPos == r.tailPos
}
// PushBack adds a new element.
// If the ring buffer is full, its capacity is increased first.
func (r *RingBuffer[T]) PushBack(t T) {
if r.full || len(r.ring) == 0 {
r.grow()
}
r.ring[r.tailPos] = t
r.tailPos++
if r.tailPos == len(r.ring) {
r.tailPos = 0
}
if r.tailPos == r.headPos {
r.full = true
}
}
// PopFront returns the next element.
// It must not be called when the buffer is empty, that means that
// callers might need to check if there are elements in the buffer first.
func (r *RingBuffer[T]) PopFront() T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
}
r.full = false
t := r.ring[r.headPos]
r.ring[r.headPos] = *new(T)
r.headPos++
if r.headPos == len(r.ring) {
r.headPos = 0
}
return t
}
// Grow the maximum size of the queue.
// This method assume the queue is full.
func (r *RingBuffer[T]) grow() {
oldRing := r.ring
newSize := len(oldRing) * 2
if newSize == 0 {
newSize = 1
}
r.ring = make([]T, newSize)
headLen := copy(r.ring, oldRing[r.headPos:])
copy(r.ring[headLen:], oldRing[:r.headPos])
r.headPos, r.tailPos, r.full = 0, len(oldRing), false
}
// Clear removes all elements.
func (r *RingBuffer[T]) Clear() {
var zeroValue T
for i := range r.ring {
r.ring[i] = zeroValue
}
r.headPos, r.tailPos, r.full = 0, 0, false
}

View file

@ -22,19 +22,17 @@ type AckFrame struct {
} }
// parseAckFrame reads an ACK frame // parseAckFrame reads an ACK frame
func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error {
ecn := typ == ackECNFrameType ecn := typ == ackECNFrameType
frame := GetAckFrame()
la, err := quicvarint.Read(r) la, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
largestAcked := protocol.PacketNumber(la) largestAcked := protocol.PacketNumber(la)
delay, err := quicvarint.Read(r) delay, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
@ -46,17 +44,17 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc
numBlocks, err := quicvarint.Read(r) numBlocks, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
// read the first ACK range // read the first ACK range
ab, err := quicvarint.Read(r) ab, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
ackBlock := protocol.PacketNumber(ab) ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked { if ackBlock > largestAcked {
return nil, errors.New("invalid first ACK range") return errors.New("invalid first ACK range")
} }
smallest := largestAcked - ackBlock smallest := largestAcked - ackBlock
@ -65,50 +63,50 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc
for i := uint64(0); i < numBlocks; i++ { for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r) g, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
gap := protocol.PacketNumber(g) gap := protocol.PacketNumber(g)
if smallest < gap+2 { if smallest < gap+2 {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
largest := smallest - gap - 2 largest := smallest - gap - 2
ab, err := quicvarint.Read(r) ab, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
ackBlock := protocol.PacketNumber(ab) ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest { if ackBlock > largest {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
smallest = largest - ackBlock smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
} }
if !frame.validateAckRanges() { if !frame.validateAckRanges() {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
if ecn { if ecn {
ect0, err := quicvarint.Read(r) ect0, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
frame.ECT0 = ect0 frame.ECT0 = ect0
ect1, err := quicvarint.Read(r) ect1, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
frame.ECT1 = ect1 frame.ECT1 = ect1
ecnce, err := quicvarint.Read(r) ecnce, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
frame.ECNCE = ecnce frame.ECNCE = ecnce
} }
return frame, nil return nil
} }
// Append appends an ACK frame. // Append appends an ACK frame.
@ -251,6 +249,18 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
return p <= f.AckRanges[i].Largest return p <= f.AckRanges[i].Largest
} }
// Reset clears all fields of the ACK frame so the struct can be reused
// (e.g. via a pool) without leaking state from a previous use.
func (f *AckFrame) Reset() {
	f.DelayTime = 0
	f.ECT0 = 0
	f.ECT1 = 0
	f.ECNCE = 0
	// Zero the retained backing array by index. Ranging by value
	// (for _, r := range f.AckRanges) would only zero a copy of each
	// element, leaving the stale data in place.
	for i := range f.AckRanges {
		f.AckRanges[i] = AckRange{}
	}
	f.AckRanges = f.AckRanges[:0]
}
func encodeAckDelay(delay time.Duration) uint64 { func encodeAckDelay(delay time.Duration) uint64 {
return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
} }

View file

@ -1,24 +0,0 @@
package wire
import "sync"
// ackFramePool recycles AckFrame structs, avoiding a fresh allocation
// for every ACK frame that is parsed or queued.
var ackFramePool = sync.Pool{New: func() any {
	return &AckFrame{}
}}
// GetAckFrame returns an AckFrame from the pool, with all fields reset
// to their zero values (the AckRanges slice keeps its capacity).
func GetAckFrame() *AckFrame {
	frame := ackFramePool.Get().(*AckFrame)
	frame.DelayTime = 0
	frame.ECT0 = 0
	frame.ECT1 = 0
	frame.ECNCE = 0
	frame.AckRanges = frame.AckRanges[:0]
	return frame
}
// PutAckFrame returns an AckFrame to the pool. Frames whose AckRanges
// slice has grown a backing array of more than 4 entries are dropped
// instead, so the pool doesn't pin unusually large allocations.
func PutAckFrame(f *AckFrame) {
	if cap(f.AckRanges) <= 4 {
		ackFramePool.Put(f)
	}
}

View file

@ -39,9 +39,12 @@ const (
type frameParser struct { type frameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8 ackDelayExponent uint8
supportsDatagrams bool supportsDatagrams bool
// To avoid allocating when parsing, keep a single ACK frame struct.
// It is used over and over again.
ackFrame *AckFrame
} }
var _ FrameParser = &frameParser{} var _ FrameParser = &frameParser{}
@ -51,6 +54,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
return &frameParser{ return &frameParser{
r: *bytes.NewReader(nil), r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams, supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
} }
} }
@ -105,7 +109,9 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
if encLevel != protocol.Encryption1RTT { if encLevel != protocol.Encryption1RTT {
ackDelayExponent = protocol.DefaultAckDelayExponent ackDelayExponent = protocol.DefaultAckDelayExponent
} }
frame, err = parseAckFrame(r, typ, ackDelayExponent, v) p.ackFrame.Reset()
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
frame = p.ackFrame
case resetStreamFrameType: case resetStreamFrameType:
frame, err = parseResetStreamFrame(r, v) frame, err = parseResetStreamFrame(r, v)
case stopSendingFrameType: case stopSendingFrameType:

View file

@ -2,14 +2,13 @@ package wire
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"sort" "sort"
"sync"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -26,15 +25,6 @@ var AdditionalTransportParametersClient map[uint64][]byte
const transportParameterMarshalingVersion = 1 const transportParameterMarshalingVersion = 1
var (
randomMutex sync.Mutex
random rand.Rand
)
func init() {
random = *rand.New(rand.NewSource(time.Now().UnixNano()))
}
type transportParameterID uint64 type transportParameterID uint64
const ( const (
@ -341,13 +331,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b := make([]byte, 0, 256) b := make([]byte, 0, 256)
// add a greased value // add a greased value
b = quicvarint.Append(b, uint64(27+31*rand.Intn(100))) random := make([]byte, 18)
randomMutex.Lock() rand.Read(random)
length := random.Intn(16) b = quicvarint.Append(b, 27+31*uint64(random[0]))
length := random[1] % 16
b = quicvarint.Append(b, uint64(length)) b = quicvarint.Append(b, uint64(length))
b = b[:len(b)+length] b = append(b, random[2:2+length]...)
random.Read(b[len(b)-length:])
randomMutex.Unlock()
// initial_max_stream_data_bidi_local // initial_max_stream_data_bidi_local
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"net"
"time" "time"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
@ -10,7 +11,11 @@ import (
) )
type mtuDiscoverer interface { type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
ShouldSendProbe(now time.Time) bool ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
} }
@ -22,25 +27,38 @@ const (
mtuProbeDelay = 5 mtuProbeDelay = 5
) )
// getMaxPacketSize returns the maximum packet size to use for the
// given remote address, based on the IP version of a UDP address.
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		// Not a UDP address, so we don't know anything about the MTU.
		// Use the minimum size of an Initial packet as the max packet size.
		return protocol.ByteCount(protocol.MinInitialPacketSize)
	}
	if utils.IsIPv4(udpAddr.IP) {
		return protocol.InitialPacketSizeIPv4
	}
	return protocol.InitialPacketSizeIPv6
}
type mtuFinder struct { type mtuFinder struct {
lastProbeTime time.Time lastProbeTime time.Time
probeInFlight bool
mtuIncreased func(protocol.ByteCount) mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
} }
var _ mtuDiscoverer = &mtuFinder{} var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
return &mtuFinder{ return &mtuFinder{
current: start, inFlight: protocol.InvalidByteCount,
rttStats: rttStats, current: start,
lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately rttStats: rttStats,
mtuIncreased: mtuIncreased, mtuIncreased: mtuIncreased,
max: max,
} }
} }
@ -48,8 +66,16 @@ func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1 return f.max-f.current <= maxMTUDiff+1
} }
// Start starts MTU discovery, probing towards maxPacketSize as the
// upper bound advertised by the peer.
func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
	// Recording the current time ensures the first probe packet
	// is not sent immediately.
	f.lastProbeTime = time.Now()
	f.max = maxPacketSize
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.probeInFlight || f.done() { if f.max == 0 || f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
return false return false
} }
return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT()))
@ -58,17 +84,36 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
size := (f.max + f.current) / 2 size := (f.max + f.current) / 2
f.lastProbeTime = time.Now() f.lastProbeTime = time.Now()
f.probeInFlight = true f.inFlight = size
return ackhandler.Frame{ return ackhandler.Frame{
Frame: &wire.PingFrame{}, Frame: &wire.PingFrame{},
OnLost: func(wire.Frame) { Handler: (*mtuFinderAckHandler)(f),
f.probeInFlight = false
f.max = size
},
OnAcked: func(wire.Frame) {
f.probeInFlight = false
f.current = size
f.mtuIncreased(size)
},
}, size }, size
} }
// CurrentSize returns the packet size that MTU discovery has currently
// validated for this connection.
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
	return f.current
}
// mtuFinderAckHandler adapts the mtuFinder to the ackhandler.FrameHandler
// interface for MTU probe packets. Declaring it as a separate named type
// keeps the OnAcked/OnLost callbacks off the mtuFinder's own method set.
type mtuFinderAckHandler mtuFinder

var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
// OnAcked is called when the MTU probe packet is acknowledged: the
// probed size is known to work, so it becomes the new current size.
func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
	probed := h.inFlight
	if probed == protocol.InvalidByteCount {
		panic("OnAcked callback called although there's no MTU probe packet in flight")
	}
	h.inFlight = protocol.InvalidByteCount
	h.current = probed
	h.mtuIncreased(probed)
}
// OnLost is called when the MTU probe packet is declared lost: the
// probed size is taken as the new upper bound for future probes.
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
	probed := h.inFlight
	if probed == protocol.InvalidByteCount {
		panic("OnLost callback called although there's no MTU probe packet in flight")
	}
	h.max = probed
	h.inFlight = protocol.InvalidByteCount
}

View file

@ -15,23 +15,35 @@ import (
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
) )
// connCapabilities describes the capabilities of a rawConn.
type connCapabilities struct {
	// This connection has the Don't Fragment (DF) bit set.
	// This means it makes sense to run DPLPMTUD.
	DF bool
	// GSO (Generic Segmentation Offload) supported
	GSO bool
}
// rawConn is a connection that allow reading of a receivedPackeh. // rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface { type rawConn interface {
ReadPacket() (*receivedPacket, error) ReadPacket() (receivedPacket, error)
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) // The size parameter is used for GSO.
// If GSO is not support, len(b) must be equal to size.
WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error)
LocalAddr() net.Addr LocalAddr() net.Addr
SetReadDeadline(time.Time) error SetReadDeadline(time.Time) error
io.Closer io.Closer
capabilities() connCapabilities
} }
type closePacket struct { type closePacket struct {
payload []byte payload []byte
addr net.Addr addr net.Addr
info *packetInfo info packetInfo
} }
type unknownPacketHandler interface { type unknownPacketHandler interface {
handlePacket(*receivedPacket) handlePacket(receivedPacket)
setCloseError(error) setCloseError(error)
} }
@ -165,7 +177,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
var handler packetHandler var handler packetHandler
if connClosePacket != nil { if connClosePacket != nil {
handler = newClosedLocalConn( handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) { func(addr net.Addr, info packetInfo) {
h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
}, },
pers, pers,

View file

@ -3,30 +3,25 @@ package quic
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
var errNothingToPack = errors.New("nothing to pack") var errNothingToPack = errors.New("nothing to pack")
type packer interface { type packer interface {
PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.VersionNumber) (*coalescedPacket, error) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.VersionNumber) (*coalescedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.VersionNumber) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
SetMaxPacketSize(protocol.ByteCount)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
HandleTransportParameters(*wire.TransportParameters)
SetToken([]byte) SetToken([]byte)
} }
@ -35,26 +30,31 @@ type sealer interface {
} }
type payload struct { type payload struct {
frames []*ackhandler.Frame streamFrames []ackhandler.StreamFrame
ack *wire.AckFrame frames []ackhandler.Frame
length protocol.ByteCount ack *wire.AckFrame
length protocol.ByteCount
} }
type longHeaderPacket struct { type longHeaderPacket struct {
header *wire.ExtendedHeader header *wire.ExtendedHeader
ack *wire.AckFrame ack *wire.AckFrame
frames []*ackhandler.Frame frames []ackhandler.Frame
streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets
length protocol.ByteCount length protocol.ByteCount
isMTUProbePacket bool
} }
type shortHeaderPacket struct { type shortHeaderPacket struct {
*ackhandler.Packet PacketNumber protocol.PacketNumber
Frames []ackhandler.Frame
StreamFrames []ackhandler.StreamFrame
Ack *wire.AckFrame
Length protocol.ByteCount
IsPathMTUProbePacket bool
// used for logging // used for logging
DestConnID protocol.ConnectionID DestConnID protocol.ConnectionID
Ack *wire.AckFrame
PacketNumberLen protocol.PacketNumberLen PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit KeyPhase protocol.KeyPhaseBit
} }
@ -83,52 +83,6 @@ func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }
func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet {
largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
encLevel := p.EncryptionLevel()
for i := range p.frames {
if p.frames[i].OnLost != nil {
continue
}
//nolint:exhaustive // Short header packets are handled separately.
switch encLevel {
case protocol.EncryptionInitial:
p.frames[i].OnLost = q.AddInitial
case protocol.EncryptionHandshake:
p.frames[i].OnLost = q.AddHandshake
case protocol.Encryption0RTT:
p.frames[i].OnLost = q.AddAppData
}
}
ap := ackhandler.GetPacket()
ap.PacketNumber = p.header.PacketNumber
ap.LargestAcked = largestAcked
ap.Frames = p.frames
ap.Length = p.length
ap.EncryptionLevel = encLevel
ap.SendTime = now
ap.IsPathMTUProbePacket = p.isMTUProbePacket
return ap
}
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if utils.IsIPv4(udpAddr.IP) {
maxSize = protocol.InitialPacketSizeIPv4
} else {
maxSize = protocol.InitialPacketSizeIPv6
}
}
return maxSize
}
type packetNumberManager interface { type packetNumberManager interface {
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
@ -143,8 +97,8 @@ type sealingManager interface {
type frameSource interface { type frameSource interface {
HasData() bool HasData() bool
AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
} }
type ackFrameSource interface { type ackFrameSource interface {
@ -169,13 +123,23 @@ type packetPacker struct {
datagramQueue *datagramQueue datagramQueue *datagramQueue
retransmissionQueue *retransmissionQueue retransmissionQueue *retransmissionQueue
maxPacketSize protocol.ByteCount
numNonAckElicitingAcks int numNonAckElicitingAcks int
} }
var _ packer = &packetPacker{} var _ packer = &packetPacker{}
func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, datagramQueue *datagramQueue, perspective protocol.Perspective) *packetPacker { func newPacketPacker(
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream, handshakeStream cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
cryptoSetup sealingManager,
framer frameSource,
acks ackFrameSource,
datagramQueue *datagramQueue,
perspective protocol.Perspective,
) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
getDestConnID: getDestConnID, getDestConnID: getDestConnID,
@ -188,23 +152,22 @@ func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() proto
framer: framer, framer: framer,
acks: acks, acks: acks,
pnManager: packetNumberManager, pnManager: packetNumberManager,
maxPacketSize: getMaxPacketSize(remoteAddr),
} }
} }
// PackConnectionClose packs a packet that closes the connection with a transport error. // PackConnectionClose packs a packet that closes the connection with a transport error.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
var reason string var reason string
// don't send details of crypto errors // don't send details of crypto errors
if !e.ErrorCode.IsCryptoError() { if !e.ErrorCode.IsCryptoError() {
reason = e.ErrorMessage reason = e.ErrorMessage
} }
return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, v) return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v)
} }
// PackApplicationClose packs a packet that closes the connection with an application error. // PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, v) return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
} }
func (p *packetPacker) packConnectionClose( func (p *packetPacker) packConnectionClose(
@ -212,6 +175,7 @@ func (p *packetPacker) packConnectionClose(
errorCode uint64, errorCode uint64,
frameType uint64, frameType uint64,
reason string, reason string,
maxPacketSize protocol.ByteCount,
v protocol.VersionNumber, v protocol.VersionNumber,
) (*coalescedPacket, error) { ) (*coalescedPacket, error) {
var sealers [4]sealer var sealers [4]sealer
@ -241,7 +205,7 @@ func (p *packetPacker) packConnectionClose(
ccf.ReasonPhrase = "" ccf.ReasonPhrase = ""
} }
pl := payload{ pl := payload{
frames: []*ackhandler.Frame{{Frame: ccf}}, frames: []ackhandler.Frame{{Frame: ccf}},
length: ccf.Length(v), length: ccf.Length(v),
} }
@ -293,20 +257,14 @@ func (p *packetPacker) packConnectionClose(
} }
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial { if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size) paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
} }
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: oneRTTPacketNumberLen,
KeyPhase: keyPhase,
}
} else { } else {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
if err != nil { if err != nil {
@ -342,25 +300,21 @@ func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnL
} }
// size is the expected size of the packet, if no padding was applied. // size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) initialPaddingLen(frames []*ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded. // For the server, only ack-eliciting Initial packets need to be padded.
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
return 0 return 0
} }
if size >= p.maxPacketSize { if currentSize >= maxPacketSize {
return 0 return 0
} }
return p.maxPacketSize - size return maxPacketSize - currentSize
} }
// PackCoalescedPacket packs a new packet. // PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces. // It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed. // It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
maxPacketSize := p.maxPacketSize
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
var ( var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
@ -442,7 +396,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe
longHdrPackets: make([]*longHeaderPacket, 0, 3), longHdrPackets: make([]*longHeaderPacket, 0, 3),
} }
if initialPayload.length > 0 { if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size) padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil { if err != nil {
return nil, err return nil, err
@ -463,48 +417,44 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe
} }
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload.length > 0 { } else if oneRTTPayload.length > 0 {
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: oneRTTPacketNumberLen,
KeyPhase: kp,
}
} }
return packet, nil return packet, nil
} }
// PackPacket packs a packet in the application data packet number space. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed. // It should be called after the handshake is confirmed.
func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
return packet, buf, err
}
// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxPacketSize, v)
}
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer() sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
return shortHeaderPacket{}, nil, err return shortHeaderPacket{}, err
} }
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID() connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true, v) pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v)
if pl.length == 0 { if pl.length == 0 {
return shortHeaderPacket{}, nil, errNothingToPack return shortHeaderPacket{}, errNothingToPack
} }
kp := sealer.KeyPhase() kp := sealer.KeyPhase()
buffer := getPacketBuffer()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false, v) return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
if err != nil {
return shortHeaderPacket{}, nil, err
}
return shortHeaderPacket{
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, buffer, nil
} }
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
@ -519,14 +469,17 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
} }
var s cryptoStream var s cryptoStream
var handler ackhandler.FrameHandler
var hasRetransmission bool var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here. //nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
s = p.initialStream s = p.initialStream
handler = p.retransmissionQueue.InitialAckHandler()
hasRetransmission = p.retransmissionQueue.HasInitialData() hasRetransmission = p.retransmissionQueue.HasInitialData()
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
s = p.handshakeStream s = p.handshakeStream
handler = p.retransmissionQueue.HandshakeAckHandler()
hasRetransmission = p.retransmissionQueue.HasHandshakeData() hasRetransmission = p.retransmissionQueue.HasHandshakeData()
} }
@ -550,27 +503,27 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
maxPacketSize -= hdr.GetLength(v) maxPacketSize -= hdr.GetLength(v)
if hasRetransmission { if hasRetransmission {
for { for {
var f wire.Frame var f ackhandler.Frame
//nolint:exhaustive // 0-RTT packets can't contain any retransmission.s //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.InitialAckHandler()
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v) f.Frame = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.HandshakeAckHandler()
} }
if f == nil { if f.Frame == nil {
break break
} }
af := ackhandler.GetFrame() pl.frames = append(pl.frames, f)
af.Frame = f frameLen := f.Frame.Length(v)
pl.frames = append(pl.frames, af)
frameLen := f.Length(v)
pl.length += frameLen pl.length += frameLen
maxPacketSize -= frameLen maxPacketSize -= frameLen
} }
} else if s.HasData() { } else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize) cf := s.PopCryptoFrame(maxPacketSize)
pl.frames = []*ackhandler.Frame{{Frame: cf}} pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
pl.length += cf.Length(v) pl.length += cf.Length(v)
} }
return hdr, pl return hdr, pl
@ -595,18 +548,14 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
// check if we have anything to send // check if we have anything to send
if len(pl.frames) == 0 { if len(pl.frames) == 0 && len(pl.streamFrames) == 0 {
if pl.ack == nil { if pl.ack == nil {
return payload{} return payload{}
} }
// the packet only contains an ACK // the packet only contains an ACK
if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
ping := &wire.PingFrame{} ping := &wire.PingFrame{}
// don't retransmit the PING frame when it is lost pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping})
af := ackhandler.GetFrame()
af.Frame = ping
af.OnLost = func(wire.Frame) {}
pl.frames = append(pl.frames, af)
pl.length += ping.Length(v) pl.length += ping.Length(v)
p.numNonAckElicitingAcks = 0 p.numNonAckElicitingAcks = 0
} else { } else {
@ -621,15 +570,12 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
return payload{ return payload{ack: ack, length: ack.Length(v)}
ack: ack,
length: ack.Length(v),
}
} }
return payload{} return payload{}
} }
pl := payload{frames: make([]*ackhandler.Frame, 0, 1)} pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
hasData := p.framer.HasData() hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData() hasRetransmission := p.retransmissionQueue.HasAppData()
@ -647,11 +593,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if f := p.datagramQueue.Peek(); f != nil { if f := p.datagramQueue.Peek(); f != nil {
size := f.Length(v) size := f.Length(v)
if size <= maxFrameSize-pl.length { if size <= maxFrameSize-pl.length {
af := ackhandler.GetFrame() pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
af.Frame = f
// set it to a no-op. Then we won't set the default callback, which would retransmit the frame.
af.OnLost = func(wire.Frame) {}
pl.frames = append(pl.frames, af)
pl.length += size pl.length += size
p.datagramQueue.Pop() p.datagramQueue.Pop()
} }
@ -672,25 +614,28 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if f == nil { if f == nil {
break break
} }
af := ackhandler.GetFrame() pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AppDataAckHandler()})
af.Frame = f
pl.frames = append(pl.frames, af)
pl.length += f.Length(v) pl.length += f.Length(v)
} }
} }
if hasData { if hasData {
var lengthAdded protocol.ByteCount var lengthAdded protocol.ByteCount
startLen := len(pl.frames)
pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v)
pl.length += lengthAdded pl.length += lengthAdded
// add handlers for the control frames that were added
for i := startLen; i < len(pl.frames); i++ {
pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
}
pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length, v) pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
pl.length += lengthAdded pl.length += lengthAdded
} }
return pl return pl
} }
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer() s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
@ -700,23 +645,17 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
connID := p.getDestConnID() connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v)
if pl.length == 0 { if pl.length == 0 {
return nil, nil return nil, nil
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer} packet := &coalescedPacket{buffer: buffer}
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}
return packet, nil return packet, nil
} }
@ -731,14 +670,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
var err error var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer() sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v)
default: default:
panic("unknown encryption level") panic("unknown encryption level")
} }
@ -751,7 +690,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial { if encLevel == protocol.EncryptionInitial {
padding = p.initialPaddingLen(pl.frames, size) padding = p.initialPaddingLen(pl.frames, size, maxPacketSize)
} }
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v)
@ -762,10 +701,10 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
return packet, nil return packet, nil
} }
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{ pl := payload{
frames: []*ackhandler.Frame{&ping}, frames: []ackhandler.Frame{ping},
length: ping.Length(v), length: ping.Frame.Length(v),
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer() s, err := p.cryptoSetup.Get1RTTSealer()
@ -776,17 +715,8 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead())
kp := s.KeyPhase() kp := s.KeyPhase()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true, v) packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v)
if err != nil { return packet, buffer, err
return shortHeaderPacket{}, nil, err
}
return shortHeaderPacket{
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, buffer, nil
} }
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
@ -829,23 +759,22 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
} }
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(encLevel)
if pn != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{ return &longHeaderPacket{
header: header, header: header,
ack: pl.ack, ack: pl.ack,
frames: pl.frames, frames: pl.frames,
length: protocol.ByteCount(len(raw)), streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil }, nil
} }
@ -856,11 +785,11 @@ func (p *packetPacker) appendShortHeaderPacket(
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
pl payload, pl payload,
padding protocol.ByteCount, padding, maxPacketSize protocol.ByteCount,
sealer sealer, sealer sealer,
isMTUProbePacket bool, isMTUProbePacket bool,
v protocol.VersionNumber, v protocol.VersionNumber,
) (*ackhandler.Packet, *wire.AckFrame, error) { ) (shortHeaderPacket, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) { if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
@ -871,48 +800,36 @@ func (p *packetPacker) appendShortHeaderPacket(
raw := buffer.Data[startLen:] raw := buffer.Data[startLen:]
raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
if err != nil { if err != nil {
return nil, nil, err return shortHeaderPacket{}, err
} }
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil { if err != nil {
return nil, nil, err return shortHeaderPacket{}, err
} }
if !isMTUProbePacket { if !isMTUProbePacket {
if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize {
return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize)
} }
} }
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen))
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// create the ackhandler.Packet if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn {
largestAcked := protocol.InvalidPacketNumber return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN)
if pl.ack != nil {
largestAcked = pl.ack.LargestAcked()
} }
for i := range pl.frames { return shortHeaderPacket{
if pl.frames[i].OnLost != nil { PacketNumber: pn,
continue PacketNumberLen: pnLen,
} KeyPhase: kp,
pl.frames[i].OnLost = p.retransmissionQueue.AddAppData StreamFrames: pl.streamFrames,
} Frames: pl.frames,
Ack: pl.ack,
ap := ackhandler.GetPacket() Length: protocol.ByteCount(len(raw)),
ap.PacketNumber = pn DestConnID: connID,
ap.LargestAcked = largestAcked IsPathMTUProbePacket: isMTUProbePacket,
ap.Frames = pl.frames }, nil
ap.Length = protocol.ByteCount(len(raw))
ap.EncryptionLevel = protocol.Encryption1RTT
ap.SendTime = time.Now()
ap.IsPathMTUProbePacket = isMTUProbePacket
return ap, pl.ack, nil
} }
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
@ -927,9 +844,16 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr
if paddingLen > 0 { if paddingLen > 0 {
raw = append(raw, make([]byte, paddingLen)...) raw = append(raw, make([]byte, paddingLen)...)
} }
for _, frame := range pl.frames { for _, f := range pl.frames {
var err error var err error
raw, err = frame.Append(raw, v) raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
}
for _, f := range pl.streamFrames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -953,16 +877,3 @@ func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.Pack
func (p *packetPacker) SetToken(token []byte) { func (p *packetPacker) SetToken(token []byte) {
p.token = token p.token = token
} }
// When a higher MTU is discovered, use it.
func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) {
p.maxPacketSize = s
}
// If the peer sets a max_packet_size that's smaller than the size we're currently using,
// we need to reduce the size of packets we send.
func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) {
if params.MaxUDPPayloadSize != 0 {
p.maxPacketSize = utils.Min(p.maxPacketSize, params.MaxUDPPayloadSize)
}
}

View file

@ -179,6 +179,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF return true, bytesRead, io.EOF
} }
} }

View file

@ -3,6 +3,8 @@ package quic
import ( import (
"fmt" "fmt"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
@ -21,7 +23,23 @@ func newRetransmissionQueue() *retransmissionQueue {
return &retransmissionQueue{} return &retransmissionQueue{}
} }
func (q *retransmissionQueue) AddInitial(f wire.Frame) { // AddPing queues a ping.
// It is used when a probe packet needs to be sent
func (q *retransmissionQueue) AddPing(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // Cannot send probe packets for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
q.addInitial(&wire.PingFrame{})
case protocol.EncryptionHandshake:
q.addHandshake(&wire.PingFrame{})
case protocol.Encryption1RTT:
q.addAppData(&wire.PingFrame{})
default:
panic("unexpected encryption level")
}
}
func (q *retransmissionQueue) addInitial(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok { if cf, ok := f.(*wire.CryptoFrame); ok {
q.initialCryptoData = append(q.initialCryptoData, cf) q.initialCryptoData = append(q.initialCryptoData, cf)
return return
@ -29,7 +47,7 @@ func (q *retransmissionQueue) AddInitial(f wire.Frame) {
q.initial = append(q.initial, f) q.initial = append(q.initial, f)
} }
func (q *retransmissionQueue) AddHandshake(f wire.Frame) { func (q *retransmissionQueue) addHandshake(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok { if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshakeCryptoData = append(q.handshakeCryptoData, cf) q.handshakeCryptoData = append(q.handshakeCryptoData, cf)
return return
@ -49,7 +67,7 @@ func (q *retransmissionQueue) HasAppData() bool {
return len(q.appData) > 0 return len(q.appData) > 0
} }
func (q *retransmissionQueue) AddAppData(f wire.Frame) { func (q *retransmissionQueue) addAppData(f wire.Frame) {
if _, ok := f.(*wire.StreamFrame); ok { if _, ok := f.(*wire.StreamFrame); ok {
panic("STREAM frames are handled with their respective streams.") panic("STREAM frames are handled with their respective streams.")
} }
@ -127,3 +145,36 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) {
panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) panic(fmt.Sprintf("unexpected encryption level: %s", encLevel))
} }
} }
func (q *retransmissionQueue) InitialAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueInitialAckHandler)(q)
}
func (q *retransmissionQueue) HandshakeAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueHandshakeAckHandler)(q)
}
func (q *retransmissionQueue) AppDataAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueAppDataAckHandler)(q)
}
type retransmissionQueueInitialAckHandler retransmissionQueue
func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addInitial(f)
}
type retransmissionQueueHandshakeAckHandler retransmissionQueue
func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addHandshake(f)
}
type retransmissionQueueAppDataAckHandler retransmissionQueue
func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addAppData(f)
}

View file

@ -1,38 +1,65 @@
package quic package quic
import ( import (
"math"
"net" "net"
"github.com/quic-go/quic-go/internal/protocol"
) )
// A sendConn allows sending using a simple Write() on a non-connected packet conn. // A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface { type sendConn interface {
Write([]byte) error Write(b []byte, size protocol.ByteCount) error
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
RemoteAddr() net.Addr RemoteAddr() net.Addr
capabilities() connCapabilities
} }
type sconn struct { type sconn struct {
rawConn rawConn
remoteAddr net.Addr remoteAddr net.Addr
info *packetInfo info packetInfo
oob []byte oob []byte
} }
var _ sendConn = &sconn{} var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { func newSendConn(c rawConn, remote net.Addr) *sconn {
sc := &sconn{
rawConn: c,
remoteAddr: remote,
}
if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg
sc.oob = make([]byte, 0, 32)
}
return sc
}
func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn {
oob := info.OOB()
if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg
l := len(oob)
oob = append(oob, make([]byte, 32)...)
oob = oob[:l]
}
return &sconn{ return &sconn{
rawConn: c, rawConn: c,
remoteAddr: remote, remoteAddr: remote,
info: info, info: info,
oob: info.OOB(), oob: oob,
} }
} }
func (c *sconn) Write(p []byte) error { func (c *sconn) Write(p []byte, size protocol.ByteCount) error {
_, err := c.WritePacket(p, c.remoteAddr, c.oob) if size > math.MaxUint16 {
panic("size overflow")
}
_, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob)
return err return err
} }
@ -42,10 +69,10 @@ func (c *sconn) RemoteAddr() net.Addr {
func (c *sconn) LocalAddr() net.Addr { func (c *sconn) LocalAddr() net.Addr {
addr := c.rawConn.LocalAddr() addr := c.rawConn.LocalAddr()
if c.info != nil { if c.info.addr.IsValid() {
if udpAddr, ok := addr.(*net.UDPAddr); ok { if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrCopy := *udpAddr addrCopy := *udpAddr
addrCopy.IP = c.info.addr addrCopy.IP = c.info.addr.AsSlice()
addr = &addrCopy addr = &addrCopy
} }
} }

View file

@ -1,15 +1,22 @@
package quic package quic
import "github.com/quic-go/quic-go/internal/protocol"
type sender interface { type sender interface {
Send(p *packetBuffer) Send(p *packetBuffer, packetSize protocol.ByteCount)
Run() error Run() error
WouldBlock() bool WouldBlock() bool
Available() <-chan struct{} Available() <-chan struct{}
Close() Close()
} }
type queueEntry struct {
buf *packetBuffer
size protocol.ByteCount
}
type sendQueue struct { type sendQueue struct {
queue chan *packetBuffer queue chan queueEntry
closeCalled chan struct{} // runStopped when Close() is called closeCalled chan struct{} // runStopped when Close() is called
runStopped chan struct{} // runStopped when the run loop returns runStopped chan struct{} // runStopped when the run loop returns
available chan struct{} available chan struct{}
@ -26,16 +33,16 @@ func newSendQueue(conn sendConn) sender {
runStopped: make(chan struct{}), runStopped: make(chan struct{}),
closeCalled: make(chan struct{}), closeCalled: make(chan struct{}),
available: make(chan struct{}, 1), available: make(chan struct{}, 1),
queue: make(chan *packetBuffer, sendQueueCapacity), queue: make(chan queueEntry, sendQueueCapacity),
} }
} }
// Send sends out a packet. It's guaranteed to not block. // Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic. // Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer) { func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) {
select { select {
case h.queue <- p: case h.queue <- queueEntry{buf: p, size: size}:
// clear available channel if we've reached capacity // clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity { if len(h.queue) == sendQueueCapacity {
select { select {
@ -69,8 +76,8 @@ func (h *sendQueue) Run() error {
h.closeCalled = nil // prevent this case from being selected again h.closeCalled = nil // prevent this case from being selected again
// make sure that all queued packets are actually sent out // make sure that all queued packets are actually sent out
shouldClose = true shouldClose = true
case p := <-h.queue: case e := <-h.queue:
if err := h.conn.Write(p.Data); err != nil { if err := h.conn.Write(e.buf.Data, e.size); err != nil {
// This additional check enables: // This additional check enables:
// 1. Checking for "datagram too large" message from the kernel, as such, // 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and // 2. Path MTU discovery,and
@ -79,7 +86,7 @@ func (h *sendQueue) Run() error {
return err return err
} }
} }
p.Release() e.buf.Release()
select { select {
case h.available <- struct{}{}: case h.available <- struct{}{}:
default: default:

View file

@ -18,7 +18,7 @@ type sendStreamI interface {
SendStream SendStream
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
hasData() bool hasData() bool
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
closeForShutdown(error) closeForShutdown(error)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have. // maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool /* has more data to send */) { func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
s.mutex.Lock() s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil { if f != nil {
@ -207,13 +207,12 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
s.mutex.Unlock() s.mutex.Unlock()
if f == nil { if f == nil {
return nil, hasMoreData return ackhandler.StreamFrame{}, false, hasMoreData
} }
af := ackhandler.GetFrame() return ackhandler.StreamFrame{
af.Frame = f Frame: f,
af.OnLost = s.queueRetransmission Handler: (*sendStreamAckHandler)(s),
af.OnAcked = s.frameAcked }, true, hasMoreData
return af, hasMoreData
} }
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
@ -348,26 +347,6 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By
} }
} }
func (s *sendStream) frameAcked(f wire.Frame) {
f.(*wire.StreamFrame).PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := s.isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStream) isNewlyCompleted() bool { func (s *sendStream) isNewlyCompleted() bool {
completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
if completed && !s.completed { if completed && !s.completed {
@ -377,24 +356,6 @@ func (s *sendStream) isNewlyCompleted() bool {
return false return false
} }
func (s *sendStream) queueRetransmission(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.DataLenPresent = true
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.retransmissionQueue = append(s.retransmissionQueue, sf)
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
}
func (s *sendStream) Close() error { func (s *sendStream) Close() error {
s.mutex.Lock() s.mutex.Lock()
if s.closeForShutdownErr != nil { if s.closeForShutdownErr != nil {
@ -487,3 +448,45 @@ func (s *sendStream) signalWrite() {
default: default:
} }
} }
type sendStreamAckHandler sendStream
var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := (*sendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame)
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
sf.DataLenPresent = true
s.retransmissionQueue = append(s.retransmissionQueue, sf)
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
}

View file

@ -24,7 +24,7 @@ var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(receivedPacket)
shutdown() shutdown()
destroy(error) destroy(error)
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
@ -42,7 +42,7 @@ type packetHandlerManager interface {
type quicConn interface { type quicConn interface {
EarlyConnection EarlyConnection
earlyConnReady() <-chan struct{} earlyConnReady() <-chan struct{}
handlePacket(*receivedPacket) handlePacket(receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
run() error run() error
@ -51,7 +51,7 @@ type quicConn interface {
} }
type zeroRTTQueue struct { type zeroRTTQueue struct {
packets []*receivedPacket packets []receivedPacket
expiration time.Time expiration time.Time
} }
@ -72,7 +72,7 @@ type baseServer struct {
connHandler packetHandlerManager connHandler packetHandlerManager
onClose func() onClose func()
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
nextZeroRTTCleanup time.Time nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
@ -102,8 +102,8 @@ type baseServer struct {
errorChan chan struct{} errorChan chan struct{}
closed bool closed bool
running chan struct{} // closed as soon as run() returns running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan *receivedPacket versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan *receivedPacket invalidTokenQueue chan receivedPacket
connQueue chan quicConn connQueue chan quicConn
connQueueLen int32 // to be used as an atomic connQueueLen int32 // to be used as an atomic
@ -160,8 +160,7 @@ func (l *EarlyListener) Addr() net.Addr {
} }
// ListenAddr creates a QUIC server listening on a given address. // ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil and must contain a certificate configuration. // See Listen for more details.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
conn, err := listenUDP(addr) conn, err := listenUDP(addr)
if err != nil { if err != nil {
@ -195,16 +194,19 @@ func listenUDP(addr string) (*net.UDPConn, error) {
return net.ListenUDP("udp", udpAddr) return net.ListenUDP("udp", udpAddr)
} }
// Listen listens for QUIC connections on a given net.PacketConn. If the // Listen listens for QUIC connections on a given net.PacketConn.
// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write // will be used instead of ReadFrom and WriteTo to read/write packets.
// packets. A single net.PacketConn only be used for a single call to Listen. // A single net.PacketConn can only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial. QUIC connection //
// IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration. // The tls.Config must not be nil and must contain a certificate configuration.
// Furthermore, it must define an application control (using NextProtos). // Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used. // The quic.Config may be nil, in that case the default values will be used.
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for outgoing QUIC connections.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
tr := &Transport{Conn: conn, isSingleUse: true} tr := &Transport{Conn: conn, isSingleUse: true}
return tr.Listen(tlsConf, config) return tr.Listen(tlsConf, config)
@ -240,9 +242,9 @@ func newServer(
connQueue: make(chan quicConn), connQueue: make(chan quicConn),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
running: make(chan struct{}), running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan *receivedPacket, 4), versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan *receivedPacket, 4), invalidTokenQueue: make(chan receivedPacket, 4),
newConn: newConnection, newConn: newConnection,
tracer: tracer, tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"), logger: utils.DefaultLogger.WithPrefix("server"),
@ -343,7 +345,7 @@ func (s *baseServer) Addr() net.Addr {
return s.conn.LocalAddr() return s.conn.LocalAddr()
} }
func (s *baseServer) handlePacket(p *receivedPacket) { func (s *baseServer) handlePacket(p receivedPacket) {
select { select {
case s.receivedPackets <- p: case s.receivedPackets <- p:
default: default:
@ -354,7 +356,7 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
} }
} }
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime) defer s.cleanupZeroRTTQueues(p.rcvTime)
} }
@ -444,7 +446,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true return true
} }
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0) connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
@ -476,7 +478,7 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
} }
return false return false
} }
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)}
queue.packets[0] = p queue.packets[0] = p
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
queue.expiration = expiration queue.expiration = expiration
@ -532,7 +534,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
return true return true
} }
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release() p.buffer.Release()
if s.tracer != nil { if s.tracer != nil {
@ -630,7 +632,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
} }
conn = s.newConn( conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info), newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info),
s.connHandler, s.connHandler,
origDestConnID, origDestConnID,
retrySrcConnID, retrySrcConnID,
@ -704,7 +706,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
} }
} }
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection. // If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
@ -740,11 +742,11 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
} }
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB())
return err return err
} }
func (s *baseServer) enqueueInvalidToken(p *receivedPacket) { func (s *baseServer) enqueueInvalidToken(p receivedPacket) {
select { select {
case s.invalidTokenQueue <- p: case s.invalidTokenQueue <- p:
default: default:
@ -753,7 +755,7 @@ func (s *baseServer) enqueueInvalidToken(p *receivedPacket) {
} }
} }
func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) { func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
hdr, _, _, err := wire.ParsePacket(p.data) hdr, _, _, err := wire.ParsePacket(p.data)
@ -770,6 +772,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) {
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length] data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
@ -791,13 +795,13 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) {
} }
} }
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info)
} }
// sendError sends the error as a response to the packet received with header hdr // sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error {
b := getPacketBuffer() b := getPacketBuffer()
defer b.Release() defer b.Release()
@ -837,11 +841,11 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
} }
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB())
return err return err
} }
func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferInUse bool) { func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) {
select { select {
case s.versionNegotiationQueue <- p: case s.versionNegotiationQueue <- p:
return true return true
@ -851,7 +855,7 @@ func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferI
return false return false
} }
func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) { func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
v, err := wire.ParseVersion(p.data) v, err := wire.ParseVersion(p.data)
@ -875,7 +879,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) {
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
} }
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err) s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
} }

View file

@ -60,7 +60,7 @@ type streamI interface {
// for sending // for sending
hasData() bool hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }

View file

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"fmt"
"net" "net"
"syscall" "syscall"
"time" "time"
@ -15,16 +16,38 @@ import (
type OOBCapablePacketConn interface { type OOBCapablePacketConn interface {
net.PacketConn net.PacketConn
SyscallConn() (syscall.RawConn, error) SyscallConn() (syscall.RawConn, error)
SetReadBuffer(int) error
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
} }
var _ OOBCapablePacketConn = &net.UDPConn{} var _ OOBCapablePacketConn = &net.UDPConn{}
func wrapConn(pc net.PacketConn) (rawConn, error) { // OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance:
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does).
//
// It's only necessary to call this function explicitly if the application calls WriteTo
// after passing the connection to the Transport.
func OptimizeConn(c net.PacketConn) (net.PacketConn, error) {
return wrapConn(c)
}
func wrapConn(pc net.PacketConn) (interface {
net.PacketConn
rawConn
}, error,
) {
conn, ok := pc.(interface { conn, ok := pc.(interface {
SyscallConn() (syscall.RawConn, error) SyscallConn() (syscall.RawConn, error)
}) })
var supportsDF bool
if ok { if ok {
rawConn, err := conn.SyscallConn() rawConn, err := conn.SyscallConn()
if err != nil { if err != nil {
@ -33,7 +56,8 @@ func wrapConn(pc net.PacketConn) (rawConn, error) {
if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { if _, ok := pc.LocalAddr().(*net.UDPAddr); ok {
// Only set DF on sockets that we expect to be able to handle that configuration. // Only set DF on sockets that we expect to be able to handle that configuration.
err = setDF(rawConn) var err error
supportsDF, err = setDF(rawConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,32 +66,33 @@ func wrapConn(pc net.PacketConn) (rawConn, error) {
c, ok := pc.(OOBCapablePacketConn) c, ok := pc.(OOBCapablePacketConn)
if !ok { if !ok {
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
return &basicConn{PacketConn: pc}, nil return &basicConn{PacketConn: pc, supportsDF: supportsDF}, nil
} }
return newConn(c) return newConn(c, supportsDF)
} }
// The basicConn is the most trivial implementation of a connection. // The basicConn is the most trivial implementation of a rawConn.
// It reads a single packet from the underlying net.PacketConn. // It reads a single packet from the underlying net.PacketConn.
// It is used when // It is used when
// * the net.PacketConn is not a OOBCapablePacketConn, and // * the net.PacketConn is not a OOBCapablePacketConn, and
// * when the OS doesn't support OOB. // * when the OS doesn't support OOB.
type basicConn struct { type basicConn struct {
net.PacketConn net.PacketConn
supportsDF bool
} }
var _ rawConn = &basicConn{} var _ rawConn = &basicConn{}
func (c *basicConn) ReadPacket() (*receivedPacket, error) { func (c *basicConn) ReadPacket() (receivedPacket, error) {
buffer := getPacketBuffer() buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes // The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
n, addr, err := c.PacketConn.ReadFrom(buffer.Data) n, addr, err := c.PacketConn.ReadFrom(buffer.Data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
return &receivedPacket{ return receivedPacket{
remoteAddr: addr, remoteAddr: addr,
rcvTime: time.Now(), rcvTime: time.Now(),
data: buffer.Data[:n], data: buffer.Data[:n],
@ -75,6 +100,11 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) {
}, nil }, nil
} }
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) {
if uint16(len(b)) != packetSize {
panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b)))
}
return c.PacketConn.WriteTo(b, addr) return c.PacketConn.WriteTo(b, addr)
} }
func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} }

View file

@ -2,11 +2,13 @@
package quic package quic
import "syscall" import (
"syscall"
)
func setDF(rawConn syscall.RawConn) error { func setDF(syscall.RawConn) (bool, error) {
// no-op on unsupported platforms // no-op on unsupported platforms
return nil return false, nil
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {

View file

@ -4,14 +4,23 @@ package quic
import ( import (
"errors" "errors"
"log"
"os"
"strconv"
"syscall" "syscall"
"unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
) )
func setDF(rawConn syscall.RawConn) error { // UDP_SEGMENT controls GSO (Generic Segmentation Offload)
//
//nolint:stylecheck
const UDP_SEGMENT = 103
func setDF(rawConn syscall.RawConn) (bool, error) {
// Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented // and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error var errDFIPv4, errDFIPv6 error
@ -19,7 +28,7 @@ func setDF(rawConn syscall.RawConn) error {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO)
}); err != nil { }); err != nil {
return err return false, err
} }
switch { switch {
case errDFIPv4 == nil && errDFIPv6 == nil: case errDFIPv4 == nil && errDFIPv6 == nil:
@ -29,12 +38,46 @@ func setDF(rawConn syscall.RawConn) error {
case errDFIPv4 != nil && errDFIPv6 == nil: case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.") utils.DefaultLogger.Debugf("Setting DF for IPv6.")
case errDFIPv4 != nil && errDFIPv6 != nil: case errDFIPv4 != nil && errDFIPv6 != nil:
return errors.New("setting DF failed for both IPv4 and IPv6") return false, errors.New("setting DF failed for both IPv4 and IPv6")
} }
return nil return true, nil
}
func maybeSetGSO(rawConn syscall.RawConn) bool {
disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO"))
if disable {
return false
}
var setErr error
if err := rawConn.Control(func(fd uintptr) {
setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, UDP_SEGMENT, 1)
}); err != nil {
setErr = err
}
if setErr != nil {
log.Println("failed to enable GSO")
return false
}
return true
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {
// https://man7.org/linux/man-pages/man7/udp.7.html // https://man7.org/linux/man-pages/man7/udp.7.html
return errors.Is(err, unix.EMSGSIZE) return errors.Is(err, unix.EMSGSIZE)
} }
func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte {
startLen := len(b)
const dataLen = 2 // payload is a uint16
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_UDP
h.Type = UDP_SEGMENT
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
*(*uint16)(unsafe.Pointer(&b[offset])) = size
return b
}

View file

@ -12,20 +12,23 @@ import (
) )
const ( const (
// same for both IPv4 and IPv6 on Windows // IP_DONTFRAGMENT controls the Don't Fragment (DF) bit.
//
// It's the same code point for both IPv4 and IPv6 on Windows.
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html
//
//nolint:stylecheck
IP_DONTFRAGMENT = 14 IP_DONTFRAGMENT = 14
IPV6_DONTFRAG = 14
) )
func setDF(rawConn syscall.RawConn) error { func setDF(rawConn syscall.RawConn) (bool, error) {
var errDFIPv4, errDFIPv6 error var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) { if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1)
errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IP_DONTFRAGMENT, 1)
}); err != nil { }); err != nil {
return err return false, err
} }
switch { switch {
case errDFIPv4 == nil && errDFIPv6 == nil: case errDFIPv4 == nil && errDFIPv6 == nil:
@ -35,9 +38,9 @@ func setDF(rawConn syscall.RawConn) error {
case errDFIPv4 != nil && errDFIPv6 == nil: case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.") utils.DefaultLogger.Debugf("Setting DF for IPv6.")
case errDFIPv4 != nil && errDFIPv6 != nil: case errDFIPv4 != nil && errDFIPv6 != nil:
return errors.New("setting DF failed for both IPv4 and IPv6") return false, errors.New("setting DF failed for both IPv4 and IPv6")
} }
return nil return true, nil
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {

8
vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go generated vendored Normal file
View file

@ -0,0 +1,8 @@
//go:build darwin || freebsd
package quic
import "syscall"
func maybeSetGSO(_ syscall.RawConn) bool { return false }
func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil }

View file

@ -2,13 +2,20 @@
package quic package quic
import "net" import (
"net"
"net/netip"
)
func newConn(c net.PacketConn) (rawConn, error) { func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) {
return &basicConn{PacketConn: c}, nil return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil
} }
func inspectReadBuffer(any) (int, error) { return 0, nil } func inspectReadBuffer(any) (int, error) { return 0, nil }
func inspectWriteBuffer(any) (int, error) { return 0, nil } func inspectWriteBuffer(any) (int, error) { return 0, nil }
type packetInfo struct {
addr netip.Addr
}
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

View file

@ -5,7 +5,9 @@ package quic
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@ -61,11 +63,13 @@ type oobConn struct {
// Packets received from the kernel, but not yet returned by ReadPacket(). // Packets received from the kernel, but not yet returned by ReadPacket().
messages []ipv4.Message messages []ipv4.Message
buffers [batchSize]*packetBuffer buffers [batchSize]*packetBuffer
cap connCapabilities
} }
var _ rawConn = &oobConn{} var _ rawConn = &oobConn{}
func newConn(c OOBCapablePacketConn) (*oobConn, error) { func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
rawConn, err := c.SyscallConn() rawConn, err := c.SyscallConn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -122,6 +126,10 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) {
bc = ipv4.NewPacketConn(c) bc = ipv4.NewPacketConn(c)
} }
// Try enabling GSO.
// This will only succeed on Linux, and only for kernels > 4.18.
supportsGSO := maybeSetGSO(rawConn)
msgs := make([]ipv4.Message, batchSize) msgs := make([]ipv4.Message, batchSize)
for i := range msgs { for i := range msgs {
// preallocate the [][]byte // preallocate the [][]byte
@ -133,13 +141,15 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) {
messages: msgs, messages: msgs,
readPos: batchSize, readPos: batchSize,
} }
oobConn.cap.DF = supportsDF
oobConn.cap.GSO = supportsGSO
for i := 0; i < batchSize; i++ { for i := 0; i < batchSize; i++ {
oobConn.messages[i].OOB = make([]byte, oobBufferSize) oobConn.messages[i].OOB = make([]byte, oobBufferSize)
} }
return oobConn, nil return oobConn, nil
} }
func (c *oobConn) ReadPacket() (*receivedPacket, error) { func (c *oobConn) ReadPacket() (receivedPacket, error) {
if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
c.messages = c.messages[:batchSize] c.messages = c.messages[:batchSize]
// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
@ -153,7 +163,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
n, err := c.batchConn.ReadBatch(c.messages, 0) n, err := c.batchConn.ReadBatch(c.messages, 0)
if n == 0 || err != nil { if n == 0 || err != nil {
return nil, err return receivedPacket{}, err
} }
c.messages = c.messages[:n] c.messages = c.messages[:n]
} }
@ -163,18 +173,21 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
c.readPos++ c.readPos++
data := msg.OOB[:msg.NN] data := msg.OOB[:msg.NN]
var ecn protocol.ECN p := receivedPacket{
var destIP net.IP remoteAddr: msg.Addr,
var ifIndex uint32 rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
buffer: buffer,
}
for len(data) > 0 { for len(data) > 0 {
hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
if hdr.Level == unix.IPPROTO_IP { if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type { switch hdr.Type {
case msgTypeIPTOS: case msgTypeIPTOS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv4PKTINFO: case msgTypeIPv4PKTINFO:
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
@ -182,80 +195,94 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
// struct in_addr ipi_addr; /* Header Destination // struct in_addr ipi_addr; /* Header Destination
// address */ // address */
// }; // };
ip := make([]byte, 4) var ip [4]byte
if len(body) == 12 { if len(body) == 12 {
ifIndex = binary.LittleEndian.Uint32(body) copy(ip[:], body[8:12])
copy(ip, body[8:12]) p.info.ifIndex = binary.LittleEndian.Uint32(body)
} else if len(body) == 4 { } else if len(body) == 4 {
// FreeBSD // FreeBSD
copy(ip, body) copy(ip[:], body)
} }
destIP = net.IP(ip) p.info.addr = netip.AddrFrom4(ip)
} }
} }
if hdr.Level == unix.IPPROTO_IPV6 { if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type { switch hdr.Type {
case unix.IPV6_TCLASS: case unix.IPV6_TCLASS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv6PKTINFO: case msgTypeIPv6PKTINFO:
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
if len(body) == 20 { if len(body) == 20 {
ip := make([]byte, 16) var ip [16]byte
copy(ip, body[:16]) copy(ip[:], body[:16])
destIP = net.IP(ip) p.info.addr = netip.AddrFrom16(ip)
ifIndex = binary.LittleEndian.Uint32(body[16:]) p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} }
} }
} }
data = remainder data = remainder
} }
var info *packetInfo return p, nil
if destIP != nil {
info = &packetInfo{
addr: destIP,
ifIndex: ifIndex,
}
}
return &receivedPacket{
remoteAddr: msg.Addr,
rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
ecn: ecn,
info: info,
buffer: buffer,
}, nil
} }
func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { // WriteTo (re)implements the net.PacketConn method.
// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return c.WritePacket(p, uint16(len(p)), addr, nil)
}
// WritePacket writes a new packet.
// If the connection supports GSO (and we activated GSO support before),
// it appends the UDP_SEGMENT size message to oob.
// Callers are advised to make sure that oob has a sufficient capacity,
// such that appending the UDP_SEGMENT size message doesn't cause an allocation.
func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) {
if c.cap.GSO {
oob = appendUDPSegmentSizeMsg(oob, packetSize)
} else if uint16(len(b)) != packetSize {
panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b)))
}
n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err return n, err
} }
func (c *oobConn) capabilities() connCapabilities {
return c.cap
}
type packetInfo struct {
addr netip.Addr
ifIndex uint32
}
func (info *packetInfo) OOB() []byte { func (info *packetInfo) OOB() []byte {
if info == nil { if info == nil {
return nil return nil
} }
if ip4 := info.addr.To4(); ip4 != nil { if info.addr.Is4() {
ip := info.addr.As4()
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */ // struct in_addr ipi_addr; /* Header Destination address */
// }; // };
cm := ipv4.ControlMessage{ cm := ipv4.ControlMessage{
Src: ip4, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()
} else if len(info.addr) == 16 { } else if info.addr.Is6() {
ip := info.addr.As16()
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
cm := ipv6.ControlMessage{ cm := ipv6.ControlMessage{
Src: info.addr, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()

View file

@ -3,13 +3,14 @@
package quic package quic
import ( import (
"net/netip"
"syscall" "syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func newConn(c OOBCapablePacketConn) (rawConn, error) { func newConn(c OOBCapablePacketConn, supportsDF bool) (*basicConn, error) {
return &basicConn{PacketConn: c}, nil return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil
} }
func inspectReadBuffer(c syscall.RawConn) (int, error) { func inspectReadBuffer(c syscall.RawConn) (int, error) {
@ -34,4 +35,8 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) {
return size, serr return size, serr
} }
type packetInfo struct {
addr netip.Addr
}
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

View file

@ -20,14 +20,19 @@ import (
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
) )
// The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as
// for dialing an arbitrary number of outgoing connections.
// A Transport handles a single net.PacketConn, and offers a range of configuration options
// compared to the simple helper functions like Listen and Dial that this package provides.
type Transport struct { type Transport struct {
// A single net.PacketConn can only be handled by one Transport. // A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports. // Bad things will happen if passed to multiple Transports.
// //
// If the connection satisfies the OOBCapablePacketConn interface // If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
// (as a net.UDPConn does), ECN and packet info support will be enabled. // After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
// In this case, optimized syscalls might be used, skipping the // Calling WriteTo is only valid on the connection returned by OptimizeConn.
// ReadFrom and WriteTo calls to read / write packets.
Conn net.PacketConn Conn net.PacketConn
// The length of the connection ID in bytes. // The length of the connection ID in bytes.
@ -44,6 +49,9 @@ type Transport struct {
// The StatelessResetKey is used to generate stateless reset tokens. // The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled. // If no key is configured, sending of stateless resets is disabled.
// It is highly recommended to configure a stateless reset key, as stateless resets
// allow the peer to quickly recover from crashes and reboots of this node.
// See section 10.3 of RFC 9000 for details.
StatelessResetKey *StatelessResetKey StatelessResetKey *StatelessResetKey
// A Tracer traces events that don't belong to a single QUIC connection. // A Tracer traces events that don't belong to a single QUIC connection.
@ -67,7 +75,7 @@ type Transport struct {
conn rawConn conn rawConn
closeQueue chan closePacket closeQueue chan closePacket
statelessResetQueue chan *receivedPacket statelessResetQueue chan receivedPacket
listening chan struct{} // is closed when listen returns listening chan struct{} // is closed when listen returns
closed bool closed bool
@ -148,7 +156,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
if t.isSingleUse { if t.isSingleUse {
onClose = func() { t.Close() } onClose = func() { t.Close() }
} }
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
} }
// DialEarly dials a new connection, attempting to use 0-RTT if possible. // DialEarly dials a new connection, attempting to use 0-RTT if possible.
@ -164,18 +172,25 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
if t.isSingleUse { if t.isSingleUse {
onClose = func() { t.Close() } onClose = func() { t.Close() }
} }
return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
} }
func (t *Transport) init(isServer bool) error { func (t *Transport) init(isServer bool) error {
t.initOnce.Do(func() { t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn) getMultiplexer().AddConn(t.Conn)
conn, err := wrapConn(t.Conn) var conn rawConn
if err != nil { if c, ok := t.Conn.(rawConn); ok {
t.initErr = err conn = c
return } else {
var err error
conn, err = wrapConn(t.Conn)
if err != nil {
t.initErr = err
return
}
} }
t.conn = conn
t.logger = utils.DefaultLogger // TODO: make this configurable t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn t.conn = conn
@ -183,7 +198,7 @@ func (t *Transport) init(isServer bool) error {
t.listening = make(chan struct{}) t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4) t.closeQueue = make(chan closePacket, 4)
t.statelessResetQueue = make(chan *receivedPacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4)
if t.ConnectionIDGenerator != nil { if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator t.connIDGenerator = t.ConnectionIDGenerator
@ -218,7 +233,7 @@ func (t *Transport) runSendQueue() {
case <-t.listening: case <-t.listening:
return return
case p := <-t.closeQueue: case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB())
case p := <-t.statelessResetQueue: case p := <-t.statelessResetQueue:
t.sendStatelessReset(p) t.sendStatelessReset(p)
} }
@ -230,14 +245,16 @@ func (t *Transport) runSendQueue() {
func (t *Transport) Close() error { func (t *Transport) Close() error {
t.close(errors.New("closing")) t.close(errors.New("closing"))
if t.createdConn { if t.createdConn {
if err := t.conn.Close(); err != nil { if err := t.Conn.Close(); err != nil {
return err return err
} }
} else { } else if t.conn != nil {
t.conn.SetReadDeadline(time.Now()) t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }() defer func() { t.conn.SetReadDeadline(time.Time{}) }()
} }
<-t.listening // wait until listening returns if t.listening != nil {
<-t.listening // wait until listening returns
}
return nil return nil
} }
@ -266,7 +283,9 @@ func (t *Transport) close(e error) {
return return
} }
t.handlerMap.Close(e) if t.handlerMap != nil {
t.handlerMap.Close(e)
}
if t.server != nil { if t.server != nil {
t.server.setCloseError(e) t.server.setCloseError(e)
} }
@ -325,7 +344,7 @@ func (t *Transport) listen(conn rawConn) {
} }
} }
func (t *Transport) handlePacket(p *receivedPacket) { func (t *Transport) handlePacket(p receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen) connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil { if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
@ -357,7 +376,7 @@ func (t *Transport) handlePacket(p *receivedPacket) {
t.server.handlePacket(p) t.server.handlePacket(p)
} }
func (t *Transport) maybeSendStatelessReset(p *receivedPacket) { func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
if t.StatelessResetKey == nil { if t.StatelessResetKey == nil {
p.buffer.Release() p.buffer.Release()
return return
@ -378,7 +397,7 @@ func (t *Transport) maybeSendStatelessReset(p *receivedPacket) {
} }
} }
func (t *Transport) sendStatelessReset(p *receivedPacket) { func (t *Transport) sendStatelessReset(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
connID, err := wire.ParseConnectionID(p.data, t.connIDLen) connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
@ -392,7 +411,7 @@ func (t *Transport) sendStatelessReset(p *receivedPacket) {
rand.Read(data) rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40 data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...) data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
} }
} }

View file

@ -0,0 +1,764 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package objectpath defines a naming scheme for types.Objects
// (that is, named entities in Go programs) relative to their enclosing
// package.
//
// Type-checker objects are canonical, so they are usually identified by
// their address in memory (a pointer), but a pointer has meaning only
// within one address space. By contrast, objectpath names allow the
// identity of an object to be sent from one program to another,
// establishing a correspondence between types.Object variables that are
// distinct but logically equivalent.
//
// A single object may have multiple paths. In this example,
//
// type A struct{ X int }
// type B A
//
// the field X has two paths due to its membership of both A and B.
// The For(obj) function always returns one of these paths, arbitrarily
// but consistently.
package objectpath
import (
"fmt"
"go/types"
"sort"
"strconv"
"strings"
"golang.org/x/tools/internal/typeparams"
_ "unsafe" // for go:linkname
)
// A Path is an opaque name that identifies a types.Object
// relative to its package. Conceptually, the name consists of a
// sequence of destructuring operations applied to the package scope
// to obtain the original object.
// The name does not include the package itself.
type Path string
// Encoding
//
// An object path is a textual and (with training) human-readable encoding
// of a sequence of destructuring operators, starting from a types.Package.
// The sequences represent a path through the package/object/type graph.
// We classify these operators by their type:
//
// PO package->object Package.Scope.Lookup
// OT object->type Object.Type
// TT type->type Type.{Elem,Key,Params,Results,Underlying} [EKPRU]
// TO type->object Type.{At,Field,Method,Obj} [AFMO]
//
// All valid paths start with a package and end at an object
// and thus may be defined by the regular language:
//
// objectpath = PO (OT TT* TO)*
//
// The concrete encoding follows directly:
// - The only PO operator is Package.Scope.Lookup, which requires an identifier.
// - The only OT operator is Object.Type,
// which we encode as '.' because dot cannot appear in an identifier.
// - The TT operators are encoded as [EKPRUTC];
// one of these (TypeParam) requires an integer operand,
// which is encoded as a string of decimal digits.
// - The TO operators are encoded as [AFMO];
// three of these (At,Field,Method) require an integer operand,
// which is encoded as a string of decimal digits.
// These indices are stable across different representations
// of the same package, even source and export data.
// The indices used are implementation specific and may not correspond to
// the argument to the go/types function.
//
// In the example below,
//
// package p
//
// type T interface {
// f() (a string, b struct{ X int })
// }
//
// field X has the path "T.UM0.RA1.F0",
// representing the following sequence of operations:
//
// p.Lookup("T") T
// .Type().Underlying().Method(0). f
// .Type().Results().At(1) b
// .Type().Field(0) X
//
// The encoding is not maximally compact---every R or P is
// followed by an A, for example---but this simplifies the
// encoder and decoder.
const (
// object->type operators
opType = '.' // .Type() (Object)
// type->type operators
opElem = 'E' // .Elem() (Pointer, Slice, Array, Chan, Map)
opKey = 'K' // .Key() (Map)
opParams = 'P' // .Params() (Signature)
opResults = 'R' // .Results() (Signature)
opUnderlying = 'U' // .Underlying() (Named)
opTypeParam = 'T' // .TypeParams.At(i) (Named, Signature)
opConstraint = 'C' // .Constraint() (TypeParam)
// type->object operators
opAt = 'A' // .At(i) (Tuple)
opField = 'F' // .Field(i) (Struct)
opMethod = 'M' // .Method(i) (Named or Interface; not Struct: "promoted" names are ignored)
opObj = 'O' // .Obj() (Named, TypeParam)
)
// For is equivalent to new(Encoder).For(obj).
//
// It may be more efficient to reuse a single Encoder across several calls.
func For(obj types.Object) (Path, error) {
return new(Encoder).For(obj)
}
// An Encoder amortizes the cost of encoding the paths of multiple objects.
// The zero value of an Encoder is ready to use.
type Encoder struct {
scopeNamesMemo map[*types.Scope][]string // memoization of Scope.Names()
namedMethodsMemo map[*types.Named][]*types.Func // memoization of namedMethods()
}
// For returns the path to an object relative to its package,
// or an error if the object is not accessible from the package's Scope.
//
// The For function guarantees to return a path only for the following objects:
// - package-level types
// - exported package-level non-types
// - methods
// - parameter and result variables
// - struct fields
// These objects are sufficient to define the API of their package.
// The objects described by a package's export data are drawn from this set.
//
// For does not return a path for predeclared names, imported package
// names, local names, and unexported package-level names (except
// types).
//
// Example: given this definition,
//
// package p
//
// type T interface {
// f() (a string, b struct{ X int })
// }
//
// For(X) would return a path that denotes the following sequence of operations:
//
// p.Scope().Lookup("T") (TypeName T)
// .Type().Underlying().Method(0). (method Func f)
// .Type().Results().At(1) (field Var b)
// .Type().Field(0) (field Var X)
//
// where p is the package (*types.Package) to which X belongs.
func (enc *Encoder) For(obj types.Object) (Path, error) {
pkg := obj.Pkg()
// This table lists the cases of interest.
//
// Object Action
// ------ ------
// nil reject
// builtin reject
// pkgname reject
// label reject
// var
// package-level accept
// func param/result accept
// local reject
// struct field accept
// const
// package-level accept
// local reject
// func
// package-level accept
// init functions reject
// concrete method accept
// interface method accept
// type
// package-level accept
// local reject
//
// The only accessible package-level objects are members of pkg itself.
//
// The cases are handled in four steps:
//
// 1. reject nil and builtin
// 2. accept package-level objects
// 3. reject obviously invalid objects
// 4. search the API for the path to the param/result/field/method.
// 1. reference to nil or builtin?
if pkg == nil {
return "", fmt.Errorf("predeclared %s has no path", obj)
}
scope := pkg.Scope()
// 2. package-level object?
if scope.Lookup(obj.Name()) == obj {
// Only exported objects (and non-exported types) have a path.
// Non-exported types may be referenced by other objects.
if _, ok := obj.(*types.TypeName); !ok && !obj.Exported() {
return "", fmt.Errorf("no path for non-exported %v", obj)
}
return Path(obj.Name()), nil
}
// 3. Not a package-level object.
// Reject obviously non-viable cases.
switch obj := obj.(type) {
case *types.TypeName:
if _, ok := obj.Type().(*typeparams.TypeParam); !ok {
// With the exception of type parameters, only package-level type names
// have a path.
return "", fmt.Errorf("no path for %v", obj)
}
case *types.Const, // Only package-level constants have a path.
*types.Label, // Labels are function-local.
*types.PkgName: // PkgNames are file-local.
return "", fmt.Errorf("no path for %v", obj)
case *types.Var:
// Could be:
// - a field (obj.IsField())
// - a func parameter or result
// - a local var.
// Sadly there is no way to distinguish
// a param/result from a local
// so we must proceed to the find.
case *types.Func:
// A func, if not package-level, must be a method.
if recv := obj.Type().(*types.Signature).Recv(); recv == nil {
return "", fmt.Errorf("func is not a method: %v", obj)
}
if path, ok := enc.concreteMethod(obj); ok {
// Fast path for concrete methods that avoids looping over scope.
return path, nil
}
default:
panic(obj)
}
// 4. Search the API for the path to the var (field/param/result) or method.
// First inspect package-level named types.
// In the presence of path aliases, these give
// the best paths because non-types may
// refer to types, but not the reverse.
empty := make([]byte, 0, 48) // initial space
names := enc.scopeNames(scope)
for _, name := range names {
o := scope.Lookup(name)
tname, ok := o.(*types.TypeName)
if !ok {
continue // handle non-types in second pass
}
path := append(empty, name...)
path = append(path, opType)
T := o.Type()
if tname.IsAlias() {
// type alias
if r := find(obj, T, path, nil); r != nil {
return Path(r), nil
}
} else {
if named, _ := T.(*types.Named); named != nil {
if r := findTypeParam(obj, typeparams.ForNamed(named), path, nil); r != nil {
// generic named type
return Path(r), nil
}
}
// defined (named) type
if r := find(obj, T.Underlying(), append(path, opUnderlying), nil); r != nil {
return Path(r), nil
}
}
}
// Then inspect everything else:
// non-types, and declared methods of defined types.
for _, name := range names {
o := scope.Lookup(name)
path := append(empty, name...)
if _, ok := o.(*types.TypeName); !ok {
if o.Exported() {
// exported non-type (const, var, func)
if r := find(obj, o.Type(), append(path, opType), nil); r != nil {
return Path(r), nil
}
}
continue
}
// Inspect declared methods of defined types.
if T, ok := o.Type().(*types.Named); ok {
path = append(path, opType)
// Note that method index here is always with respect
// to canonical ordering of methods, regardless of how
// they appear in the underlying type.
for i, m := range enc.namedMethods(T) {
path2 := appendOpArg(path, opMethod, i)
if m == obj {
return Path(path2), nil // found declared method
}
if r := find(obj, m.Type(), append(path2, opType), nil); r != nil {
return Path(r), nil
}
}
}
}
return "", fmt.Errorf("can't find path for %v in %s", obj, pkg.Path())
}
func appendOpArg(path []byte, op byte, arg int) []byte {
path = append(path, op)
path = strconv.AppendInt(path, int64(arg), 10)
return path
}
// concreteMethod returns the path for meth, which must have a non-nil receiver.
// The second return value indicates success and may be false if the method is
// an interface method or if it is an instantiated method.
//
// This function is just an optimization that avoids the general scope walking
// approach. You are expected to fall back to the general approach if this
// function fails.
func (enc *Encoder) concreteMethod(meth *types.Func) (Path, bool) {
// Concrete methods can only be declared on package-scoped named types. For
// that reason we can skip the expensive walk over the package scope: the
// path will always be package -> named type -> method. We can trivially get
// the type name from the receiver, and only have to look over the type's
// methods to find the method index.
//
// Methods on generic types require special consideration, however. Consider
// the following package:
//
// L1: type S[T any] struct{}
// L2: func (recv S[A]) Foo() { recv.Bar() }
// L3: func (recv S[B]) Bar() { }
// L4: type Alias = S[int]
// L5: func _[T any]() { var s S[int]; s.Foo() }
//
// The receivers of methods on generic types are instantiations. L2 and L3
// instantiate S with the type-parameters A and B, which are scoped to the
// respective methods. L4 and L5 each instantiate S with int. Each of these
// instantiations has its own method set, full of methods (and thus objects)
// with receivers whose types are the respective instantiations. In other
// words, we have
//
// S[A].Foo, S[A].Bar
// S[B].Foo, S[B].Bar
// S[int].Foo, S[int].Bar
//
// We may thus be trying to produce object paths for any of these objects.
//
// S[A].Foo and S[B].Bar are the origin methods, and their paths are S.Foo
// and S.Bar, which are the paths that this function naturally produces.
//
// S[A].Bar, S[B].Foo, and both methods on S[int] are instantiations that
// don't correspond to the origin methods. For S[int], this is significant.
// The most precise object path for S[int].Foo, for example, is Alias.Foo,
// not S.Foo. Our function, however, would produce S.Foo, which would
// resolve to a different object.
//
// For S[A].Bar and S[B].Foo it could be argued that S.Bar and S.Foo are
// still the correct paths, since only the origin methods have meaningful
// paths. But this is likely only true for trivial cases and has edge cases.
// Since this function is only an optimization, we err on the side of giving
// up, deferring to the slower but definitely correct algorithm. Most users
// of objectpath will only be giving us origin methods, anyway, as referring
// to instantiated methods is usually not useful.
if typeparams.OriginMethod(meth) != meth {
return "", false
}
recvT := meth.Type().(*types.Signature).Recv().Type()
if ptr, ok := recvT.(*types.Pointer); ok {
recvT = ptr.Elem()
}
named, ok := recvT.(*types.Named)
if !ok {
return "", false
}
if types.IsInterface(named) {
// Named interfaces don't have to be package-scoped
//
// TODO(dominikh): opt: if scope.Lookup(name) == named, then we can apply this optimization to interface
// methods, too, I think.
return "", false
}
// Preallocate space for the name, opType, opMethod, and some digits.
name := named.Obj().Name()
path := make([]byte, 0, len(name)+8)
path = append(path, name...)
path = append(path, opType)
for i, m := range enc.namedMethods(named) {
if m == meth {
path = appendOpArg(path, opMethod, i)
return Path(path), true
}
}
// Due to golang/go#59944, go/types fails to associate the receiver with
// certain methods on cgo types.
//
// TODO(rfindley): replace this panic once golang/go#59944 is fixed in all Go
// versions gopls supports.
return "", false
// panic(fmt.Sprintf("couldn't find method %s on type %s; methods: %#v", meth, named, enc.namedMethods(named)))
}
// find finds obj within type T, returning the path to it, or nil if not found.
//
// The seen map is used to short circuit cycles through type parameters. If
// nil, it will be allocated as necessary.
func find(obj types.Object, T types.Type, path []byte, seen map[*types.TypeName]bool) []byte {
switch T := T.(type) {
case *types.Basic, *types.Named:
// Named types belonging to pkg were handled already,
// so T must belong to another package. No path.
return nil
case *types.Pointer:
return find(obj, T.Elem(), append(path, opElem), seen)
case *types.Slice:
return find(obj, T.Elem(), append(path, opElem), seen)
case *types.Array:
return find(obj, T.Elem(), append(path, opElem), seen)
case *types.Chan:
return find(obj, T.Elem(), append(path, opElem), seen)
case *types.Map:
if r := find(obj, T.Key(), append(path, opKey), seen); r != nil {
return r
}
return find(obj, T.Elem(), append(path, opElem), seen)
case *types.Signature:
if r := findTypeParam(obj, typeparams.ForSignature(T), path, seen); r != nil {
return r
}
if r := find(obj, T.Params(), append(path, opParams), seen); r != nil {
return r
}
return find(obj, T.Results(), append(path, opResults), seen)
case *types.Struct:
for i := 0; i < T.NumFields(); i++ {
fld := T.Field(i)
path2 := appendOpArg(path, opField, i)
if fld == obj {
return path2 // found field var
}
if r := find(obj, fld.Type(), append(path2, opType), seen); r != nil {
return r
}
}
return nil
case *types.Tuple:
for i := 0; i < T.Len(); i++ {
v := T.At(i)
path2 := appendOpArg(path, opAt, i)
if v == obj {
return path2 // found param/result var
}
if r := find(obj, v.Type(), append(path2, opType), seen); r != nil {
return r
}
}
return nil
case *types.Interface:
for i := 0; i < T.NumMethods(); i++ {
m := T.Method(i)
path2 := appendOpArg(path, opMethod, i)
if m == obj {
return path2 // found interface method
}
if r := find(obj, m.Type(), append(path2, opType), seen); r != nil {
return r
}
}
return nil
case *typeparams.TypeParam:
name := T.Obj()
if name == obj {
return append(path, opObj)
}
if seen[name] {
return nil
}
if seen == nil {
seen = make(map[*types.TypeName]bool)
}
seen[name] = true
if r := find(obj, T.Constraint(), append(path, opConstraint), seen); r != nil {
return r
}
return nil
}
panic(T)
}
func findTypeParam(obj types.Object, list *typeparams.TypeParamList, path []byte, seen map[*types.TypeName]bool) []byte {
for i := 0; i < list.Len(); i++ {
tparam := list.At(i)
path2 := appendOpArg(path, opTypeParam, i)
if r := find(obj, tparam, path2, seen); r != nil {
return r
}
}
return nil
}
// Object returns the object denoted by path p within the package pkg.
func Object(pkg *types.Package, p Path) (types.Object, error) {
if p == "" {
return nil, fmt.Errorf("empty path")
}
pathstr := string(p)
var pkgobj, suffix string
if dot := strings.IndexByte(pathstr, opType); dot < 0 {
pkgobj = pathstr
} else {
pkgobj = pathstr[:dot]
suffix = pathstr[dot:] // suffix starts with "."
}
obj := pkg.Scope().Lookup(pkgobj)
if obj == nil {
return nil, fmt.Errorf("package %s does not contain %q", pkg.Path(), pkgobj)
}
// abstraction of *types.{Pointer,Slice,Array,Chan,Map}
type hasElem interface {
Elem() types.Type
}
// abstraction of *types.{Named,Signature}
type hasTypeParams interface {
TypeParams() *typeparams.TypeParamList
}
// abstraction of *types.{Named,TypeParam}
type hasObj interface {
Obj() *types.TypeName
}
// The loop state is the pair (t, obj),
// exactly one of which is non-nil, initially obj.
// All suffixes start with '.' (the only object->type operation),
// followed by optional type->type operations,
// then a type->object operation.
// The cycle then repeats.
var t types.Type
for suffix != "" {
code := suffix[0]
suffix = suffix[1:]
// Codes [AFM] have an integer operand.
var index int
switch code {
case opAt, opField, opMethod, opTypeParam:
rest := strings.TrimLeft(suffix, "0123456789")
numerals := suffix[:len(suffix)-len(rest)]
suffix = rest
i, err := strconv.Atoi(numerals)
if err != nil {
return nil, fmt.Errorf("invalid path: bad numeric operand %q for code %q", numerals, code)
}
index = int(i)
case opObj:
// no operand
default:
// The suffix must end with a type->object operation.
if suffix == "" {
return nil, fmt.Errorf("invalid path: ends with %q, want [AFMO]", code)
}
}
if code == opType {
if t != nil {
return nil, fmt.Errorf("invalid path: unexpected %q in type context", opType)
}
t = obj.Type()
obj = nil
continue
}
if t == nil {
return nil, fmt.Errorf("invalid path: code %q in object context", code)
}
// Inv: t != nil, obj == nil
switch code {
case opElem:
hasElem, ok := t.(hasElem) // Pointer, Slice, Array, Chan, Map
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want pointer, slice, array, chan or map)", code, t, t)
}
t = hasElem.Elem()
case opKey:
mapType, ok := t.(*types.Map)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want map)", code, t, t)
}
t = mapType.Key()
case opParams:
sig, ok := t.(*types.Signature)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want signature)", code, t, t)
}
t = sig.Params()
case opResults:
sig, ok := t.(*types.Signature)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want signature)", code, t, t)
}
t = sig.Results()
case opUnderlying:
named, ok := t.(*types.Named)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named)", code, t, t)
}
t = named.Underlying()
case opTypeParam:
hasTypeParams, ok := t.(hasTypeParams) // Named, Signature
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named or signature)", code, t, t)
}
tparams := hasTypeParams.TypeParams()
if n := tparams.Len(); index >= n {
return nil, fmt.Errorf("tuple index %d out of range [0-%d)", index, n)
}
t = tparams.At(index)
case opConstraint:
tparam, ok := t.(*typeparams.TypeParam)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want type parameter)", code, t, t)
}
t = tparam.Constraint()
case opAt:
tuple, ok := t.(*types.Tuple)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want tuple)", code, t, t)
}
if n := tuple.Len(); index >= n {
return nil, fmt.Errorf("tuple index %d out of range [0-%d)", index, n)
}
obj = tuple.At(index)
t = nil
case opField:
structType, ok := t.(*types.Struct)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want struct)", code, t, t)
}
if n := structType.NumFields(); index >= n {
return nil, fmt.Errorf("field index %d out of range [0-%d)", index, n)
}
obj = structType.Field(index)
t = nil
case opMethod:
switch t := t.(type) {
case *types.Interface:
if index >= t.NumMethods() {
return nil, fmt.Errorf("method index %d out of range [0-%d)", index, t.NumMethods())
}
obj = t.Method(index) // Id-ordered
case *types.Named:
methods := namedMethods(t) // (unmemoized)
if index >= len(methods) {
return nil, fmt.Errorf("method index %d out of range [0-%d)", index, len(methods))
}
obj = methods[index] // Id-ordered
default:
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want interface or named)", code, t, t)
}
t = nil
case opObj:
hasObj, ok := t.(hasObj)
if !ok {
return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named or type param)", code, t, t)
}
obj = hasObj.Obj()
t = nil
default:
return nil, fmt.Errorf("invalid path: unknown code %q", code)
}
}
if obj.Pkg() != pkg {
return nil, fmt.Errorf("path denotes %s, which belongs to a different package", obj)
}
return obj, nil // success
}
// namedMethods returns the methods of a Named type in ascending Id order.
func namedMethods(named *types.Named) []*types.Func {
methods := make([]*types.Func, named.NumMethods())
for i := range methods {
methods[i] = named.Method(i)
}
sort.Slice(methods, func(i, j int) bool {
return methods[i].Id() < methods[j].Id()
})
return methods
}
// namedMethods is a memoization of the namedMethods function. Callers must not modify the result.
func (enc *Encoder) namedMethods(named *types.Named) []*types.Func {
m := enc.namedMethodsMemo
if m == nil {
m = make(map[*types.Named][]*types.Func)
enc.namedMethodsMemo = m
}
methods, ok := m[named]
if !ok {
methods = namedMethods(named) // allocates and sorts
m[named] = methods
}
return methods
}
// scopeNames is a memoization of scope.Names. Callers must not modify the result.
func (enc *Encoder) scopeNames(scope *types.Scope) []string {
	if enc.scopeNamesMemo == nil {
		enc.scopeNamesMemo = make(map[*types.Scope][]string)
	}
	if cached, hit := enc.scopeNamesMemo[scope]; hit {
		return cached
	}
	names := scope.Names() // allocates and sorts
	enc.scopeNamesMemo[scope] = names
	return names
}

View file

@ -7,6 +7,18 @@
// Package gcimporter provides various functions for reading // Package gcimporter provides various functions for reading
// gc-generated object files that can be used to implement the // gc-generated object files that can be used to implement the
// Importer interface defined by the Go 1.5 standard library package. // Importer interface defined by the Go 1.5 standard library package.
//
// The encoding is deterministic: if the encoder is applied twice to
// the same types.Package data structure, both encodings are equal.
// This property may be important to avoid spurious changes in
// applications such as build systems.
//
// However, the encoder is not necessarily idempotent. Importing an
// exported package may yield a types.Package that, while it
// represents the same set of Go types as the original, may differ in
// the details of its internal representation. Because of these
// differences, re-encoding the imported package may yield a
// different, but equally valid, encoding of the package.
package gcimporter // import "golang.org/x/tools/internal/gcimporter" package gcimporter // import "golang.org/x/tools/internal/gcimporter"
import ( import (

View file

@ -44,12 +44,12 @@ func IExportShallow(fset *token.FileSet, pkg *types.Package) ([]byte, error) {
return out.Bytes(), err return out.Bytes(), err
} }
// IImportShallow decodes "shallow" types.Package data encoded by IExportShallow // IImportShallow decodes "shallow" types.Package data encoded by
// in the same executable. This function cannot import data from // IExportShallow in the same executable. This function cannot import data from
// cmd/compile or gcexportdata.Write. // cmd/compile or gcexportdata.Write.
func IImportShallow(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string, insert InsertType) (*types.Package, error) { func IImportShallow(fset *token.FileSet, getPackage GetPackageFunc, data []byte, path string, insert InsertType) (*types.Package, error) {
const bundle = false const bundle = false
pkgs, err := iimportCommon(fset, imports, data, bundle, path, insert) pkgs, err := iimportCommon(fset, getPackage, data, bundle, path, insert)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -85,7 +85,7 @@ const (
// If the export data version is not recognized or the format is otherwise // If the export data version is not recognized or the format is otherwise
// compromised, an error is returned. // compromised, an error is returned.
func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string) (int, *types.Package, error) { func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string) (int, *types.Package, error) {
pkgs, err := iimportCommon(fset, imports, data, false, path, nil) pkgs, err := iimportCommon(fset, GetPackageFromMap(imports), data, false, path, nil)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
@ -94,10 +94,33 @@ func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []
// IImportBundle imports a set of packages from the serialized package bundle. // IImportBundle imports a set of packages from the serialized package bundle.
func IImportBundle(fset *token.FileSet, imports map[string]*types.Package, data []byte) ([]*types.Package, error) { func IImportBundle(fset *token.FileSet, imports map[string]*types.Package, data []byte) ([]*types.Package, error) {
return iimportCommon(fset, imports, data, true, "", nil) return iimportCommon(fset, GetPackageFromMap(imports), data, true, "", nil)
} }
func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data []byte, bundle bool, path string, insert InsertType) (pkgs []*types.Package, err error) { // A GetPackageFunc is a function that gets the package with the given path
// from the importer state, creating it (with the specified name) if necessary.
// It is an abstraction of the map historically used to memoize package creation.
//
// Two calls with the same path must return the same package.
//
// If the given getPackage func returns nil, the import will fail.
type GetPackageFunc = func(path, name string) *types.Package
// GetPackageFromMap returns a GetPackageFunc that retrieves packages from the
// given map of package path -> package.
//
// The resulting func may mutate m: if a requested package is not found, a new
// package will be inserted into m.
func GetPackageFromMap(m map[string]*types.Package) GetPackageFunc {
return func(path, name string) *types.Package {
if _, ok := m[path]; !ok {
m[path] = types.NewPackage(path, name)
}
return m[path]
}
}
func iimportCommon(fset *token.FileSet, getPackage GetPackageFunc, data []byte, bundle bool, path string, insert InsertType) (pkgs []*types.Package, err error) {
const currentVersion = iexportVersionCurrent const currentVersion = iexportVersionCurrent
version := int64(-1) version := int64(-1)
if !debug { if !debug {
@ -195,10 +218,9 @@ func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data
if pkgPath == "" { if pkgPath == "" {
pkgPath = path pkgPath = path
} }
pkg := imports[pkgPath] pkg := getPackage(pkgPath, pkgName)
if pkg == nil { if pkg == nil {
pkg = types.NewPackage(pkgPath, pkgName) errorf("internal error: getPackage returned nil package for %s", pkgPath)
imports[pkgPath] = pkg
} else if pkg.Name() != pkgName { } else if pkg.Name() != pkgName {
errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path) errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path)
} }

View file

@ -12,6 +12,7 @@ package gcimporter
import ( import (
"go/token" "go/token"
"go/types" "go/types"
"sort"
"strings" "strings"
"golang.org/x/tools/internal/pkgbits" "golang.org/x/tools/internal/pkgbits"
@ -121,6 +122,16 @@ func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[st
iface.Complete() iface.Complete()
} }
// Imports() of pkg are all of the transitive packages that were loaded.
var imps []*types.Package
for _, imp := range pr.pkgs {
if imp != nil && imp != pkg {
imps = append(imps, imp)
}
}
sort.Sort(byPath(imps))
pkg.SetImports(imps)
pkg.MarkComplete() pkg.MarkComplete()
return pkg return pkg
} }
@ -260,39 +271,9 @@ func (r *reader) doPkg() *types.Package {
pkg := types.NewPackage(path, name) pkg := types.NewPackage(path, name)
r.p.imports[path] = pkg r.p.imports[path] = pkg
imports := make([]*types.Package, r.Len())
for i := range imports {
imports[i] = r.pkg()
}
pkg.SetImports(flattenImports(imports))
return pkg return pkg
} }
// flattenImports returns the transitive closure of all imported
// packages rooted from pkgs.
func flattenImports(pkgs []*types.Package) []*types.Package {
var res []*types.Package
seen := make(map[*types.Package]struct{})
for _, pkg := range pkgs {
if _, ok := seen[pkg]; ok {
continue
}
seen[pkg] = struct{}{}
res = append(res, pkg)
// pkg.Imports() is already flattened.
for _, pkg := range pkg.Imports() {
if _, ok := seen[pkg]; ok {
continue
}
seen[pkg] = struct{}{}
res = append(res, pkg)
}
}
return res
}
// @@@ Types // @@@ Types
func (r *reader) typ() types.Type { func (r *reader) typ() types.Type {

View file

@ -8,10 +8,12 @@ package gocommand
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"os" "os"
"reflect"
"regexp" "regexp"
"runtime" "runtime"
"strconv" "strconv"
@ -215,6 +217,18 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error {
cmd := exec.Command("go", goArgs...) cmd := exec.Command("go", goArgs...)
cmd.Stdout = stdout cmd.Stdout = stdout
cmd.Stderr = stderr cmd.Stderr = stderr
// cmd.WaitDelay was added only in go1.20 (see #50436).
if waitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); waitDelay.IsValid() {
// https://go.dev/issue/59541: don't wait forever copying stderr
// after the command has exited.
// After CL 484741 we copy stdout manually, so we'll stop reading that as
// soon as ctx is done. However, we also don't want to wait around forever
// for stderr. Give a much-longer-than-reasonable delay and then assume that
// something has wedged in the kernel or runtime.
waitDelay.Set(reflect.ValueOf(30 * time.Second))
}
// On darwin the cwd gets resolved to the real path, which breaks anything that // On darwin the cwd gets resolved to the real path, which breaks anything that
// expects the working directory to keep the original path, including the // expects the working directory to keep the original path, including the
// go command when dealing with modules. // go command when dealing with modules.
@ -229,6 +243,7 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error {
cmd.Env = append(cmd.Env, "PWD="+i.WorkingDir) cmd.Env = append(cmd.Env, "PWD="+i.WorkingDir)
cmd.Dir = i.WorkingDir cmd.Dir = i.WorkingDir
} }
defer func(start time.Time) { log("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now()) defer func(start time.Time) { log("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now())
return runCmdContext(ctx, cmd) return runCmdContext(ctx, cmd)
@ -242,10 +257,85 @@ var DebugHangingGoCommands = false
// runCmdContext is like exec.CommandContext except it sends os.Interrupt // runCmdContext is like exec.CommandContext except it sends os.Interrupt
// before os.Kill. // before os.Kill.
func runCmdContext(ctx context.Context, cmd *exec.Cmd) error { func runCmdContext(ctx context.Context, cmd *exec.Cmd) (err error) {
if err := cmd.Start(); err != nil { // If cmd.Stdout is not an *os.File, the exec package will create a pipe and
// copy it to the Writer in a goroutine until the process has finished and
// either the pipe reaches EOF or command's WaitDelay expires.
//
// However, the output from 'go list' can be quite large, and we don't want to
// keep reading (and allocating buffers) if we've already decided we don't
// care about the output. We don't want to wait for the process to finish, and
// we don't want to wait for the WaitDelay to expire either.
//
// Instead, if cmd.Stdout requires a copying goroutine we explicitly replace
// it with a pipe (which is an *os.File), which we can close in order to stop
// copying output as soon as we realize we don't care about it.
var stdoutW *os.File
if cmd.Stdout != nil {
if _, ok := cmd.Stdout.(*os.File); !ok {
var stdoutR *os.File
stdoutR, stdoutW, err = os.Pipe()
if err != nil {
return err
}
prevStdout := cmd.Stdout
cmd.Stdout = stdoutW
stdoutErr := make(chan error, 1)
go func() {
_, err := io.Copy(prevStdout, stdoutR)
if err != nil {
err = fmt.Errorf("copying stdout: %w", err)
}
stdoutErr <- err
}()
defer func() {
// We started a goroutine to copy a stdout pipe.
// Wait for it to finish, or terminate it if need be.
var err2 error
select {
case err2 = <-stdoutErr:
stdoutR.Close()
case <-ctx.Done():
stdoutR.Close()
// Per https://pkg.go.dev/os#File.Close, the call to stdoutR.Close
// should cause the Read call in io.Copy to unblock and return
// immediately, but we still need to receive from stdoutErr to confirm
// that that has happened.
<-stdoutErr
err2 = ctx.Err()
}
if err == nil {
err = err2
}
}()
// Per https://pkg.go.dev/os/exec#Cmd, “If Stdout and Stderr are the
// same writer, and have a type that can be compared with ==, at most
// one goroutine at a time will call Write.”
//
// Since we're starting a goroutine that writes to cmd.Stdout, we must
// also update cmd.Stderr so that that still holds.
func() {
defer func() { recover() }()
if cmd.Stderr == prevStdout {
cmd.Stderr = cmd.Stdout
}
}()
}
}
err = cmd.Start()
if stdoutW != nil {
// The child process has inherited the pipe file,
// so close the copy held in this process.
stdoutW.Close()
stdoutW = nil
}
if err != nil {
return err return err
} }
resChan := make(chan error, 1) resChan := make(chan error, 1)
go func() { go func() {
resChan <- cmd.Wait() resChan <- cmd.Wait()
@ -253,11 +343,14 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error {
// If we're interested in debugging hanging Go commands, stop waiting after a // If we're interested in debugging hanging Go commands, stop waiting after a
// minute and panic with interesting information. // minute and panic with interesting information.
if DebugHangingGoCommands { debug := DebugHangingGoCommands
if debug {
timer := time.NewTimer(1 * time.Minute)
defer timer.Stop()
select { select {
case err := <-resChan: case err := <-resChan:
return err return err
case <-time.After(1 * time.Minute): case <-timer.C:
HandleHangingGoCommand(cmd.Process) HandleHangingGoCommand(cmd.Process)
case <-ctx.Done(): case <-ctx.Done():
} }
@ -270,30 +363,25 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error {
} }
// Cancelled. Interrupt and see if it ends voluntarily. // Cancelled. Interrupt and see if it ends voluntarily.
cmd.Process.Signal(os.Interrupt) if err := cmd.Process.Signal(os.Interrupt); err == nil {
select { // (We used to wait only 1s but this proved
case err := <-resChan: // fragile on loaded builder machines.)
return err timer := time.NewTimer(5 * time.Second)
case <-time.After(time.Second): defer timer.Stop()
select {
case err := <-resChan:
return err
case <-timer.C:
}
} }
// Didn't shut down in response to interrupt. Kill it hard. // Didn't shut down in response to interrupt. Kill it hard.
// TODO(rfindley): per advice from bcmills@, it may be better to send SIGQUIT // TODO(rfindley): per advice from bcmills@, it may be better to send SIGQUIT
// on certain platforms, such as unix. // on certain platforms, such as unix.
if err := cmd.Process.Kill(); err != nil && DebugHangingGoCommands { if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) && debug {
// Don't panic here as this reliably fails on windows with EINVAL.
log.Printf("error killing the Go command: %v", err) log.Printf("error killing the Go command: %v", err)
} }
// See above: don't wait indefinitely if we're debugging hanging Go commands.
if DebugHangingGoCommands {
select {
case err := <-resChan:
return err
case <-time.After(10 * time.Second): // a shorter wait as resChan should return quickly following Kill
HandleHangingGoCommand(cmd.Process)
}
}
return <-resChan return <-resChan
} }

View file

@ -23,21 +23,11 @@ import (
func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) { func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) {
inv.Verb = "list" inv.Verb = "list"
inv.Args = []string{"-e", "-f", `{{context.ReleaseTags}}`, `--`, `unsafe`} inv.Args = []string{"-e", "-f", `{{context.ReleaseTags}}`, `--`, `unsafe`}
inv.Env = append(append([]string{}, inv.Env...), "GO111MODULE=off") inv.BuildFlags = nil // This is not a build command.
// Unset any unneeded flags, and remove them from BuildFlags, if they're
// present.
inv.ModFile = ""
inv.ModFlag = "" inv.ModFlag = ""
var buildFlags []string inv.ModFile = ""
for _, flag := range inv.BuildFlags { inv.Env = append(inv.Env[:len(inv.Env):len(inv.Env)], "GO111MODULE=off")
// Flags can be prefixed by one or two dashes.
f := strings.TrimPrefix(strings.TrimPrefix(flag, "-"), "-")
if strings.HasPrefix(f, "mod=") || strings.HasPrefix(f, "modfile=") {
continue
}
buildFlags = append(buildFlags, flag)
}
inv.BuildFlags = buildFlags
stdoutBytes, err := r.Run(ctx, inv) stdoutBytes, err := r.Run(ctx, inv)
if err != nil { if err != nil {
return 0, err return 0, err

View file

@ -414,9 +414,16 @@ func (p *pass) fix() ([]*ImportFix, bool) {
}) })
} }
} }
// Collecting fixes involved map iteration, so sort for stability. See
// golang/go#59976.
sortFixes(fixes)
// collect selected fixes in a separate slice, so that it can be sorted
// separately. Note that these fixes must occur after fixes to existing
// imports. TODO(rfindley): figure out why.
var selectedFixes []*ImportFix
for _, imp := range selected { for _, imp := range selected {
fixes = append(fixes, &ImportFix{ selectedFixes = append(selectedFixes, &ImportFix{
StmtInfo: ImportInfo{ StmtInfo: ImportInfo{
Name: p.importSpecName(imp), Name: p.importSpecName(imp),
ImportPath: imp.ImportPath, ImportPath: imp.ImportPath,
@ -425,8 +432,25 @@ func (p *pass) fix() ([]*ImportFix, bool) {
FixType: AddImport, FixType: AddImport,
}) })
} }
sortFixes(selectedFixes)
return fixes, true return append(fixes, selectedFixes...), true
}
func sortFixes(fixes []*ImportFix) {
sort.Slice(fixes, func(i, j int) bool {
fi, fj := fixes[i], fixes[j]
if fi.StmtInfo.ImportPath != fj.StmtInfo.ImportPath {
return fi.StmtInfo.ImportPath < fj.StmtInfo.ImportPath
}
if fi.StmtInfo.Name != fj.StmtInfo.Name {
return fi.StmtInfo.Name < fj.StmtInfo.Name
}
if fi.IdentName != fj.IdentName {
return fi.IdentName < fj.IdentName
}
return fi.FixType < fj.FixType
})
} }
// importSpecName gets the import name of imp in the import spec. // importSpecName gets the import name of imp in the import spec.

View file

@ -7,7 +7,9 @@
package tokeninternal package tokeninternal
import ( import (
"fmt"
"go/token" "go/token"
"sort"
"sync" "sync"
"unsafe" "unsafe"
) )
@ -57,3 +59,93 @@ func GetLines(file *token.File) []int {
panic("unexpected token.File size") panic("unexpected token.File size")
} }
} }
// AddExistingFiles adds the specified files to the FileSet if they
// are not already present. It panics if any pair of files in the
// resulting FileSet would overlap.
func AddExistingFiles(fset *token.FileSet, files []*token.File) {
// Punch through the FileSet encapsulation.
type tokenFileSet struct {
// This type remained essentially consistent from go1.16 to go1.21.
mutex sync.RWMutex
base int
files []*token.File
_ *token.File // changed to atomic.Pointer[token.File] in go1.19
}
// If the size of token.FileSet changes, this will fail to compile.
const delta = int64(unsafe.Sizeof(tokenFileSet{})) - int64(unsafe.Sizeof(token.FileSet{}))
var _ [-delta * delta]int
type uP = unsafe.Pointer
var ptr *tokenFileSet
*(*uP)(uP(&ptr)) = uP(fset)
ptr.mutex.Lock()
defer ptr.mutex.Unlock()
// Merge and sort.
newFiles := append(ptr.files, files...)
sort.Slice(newFiles, func(i, j int) bool {
return newFiles[i].Base() < newFiles[j].Base()
})
// Reject overlapping files.
// Discard adjacent identical files.
out := newFiles[:0]
for i, file := range newFiles {
if i > 0 {
prev := newFiles[i-1]
if file == prev {
continue
}
if prev.Base()+prev.Size()+1 > file.Base() {
panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)",
prev.Name(), prev.Base(), prev.Base()+prev.Size(),
file.Name(), file.Base(), file.Base()+file.Size()))
}
}
out = append(out, file)
}
newFiles = out
ptr.files = newFiles
// Advance FileSet.Base().
if len(newFiles) > 0 {
last := newFiles[len(newFiles)-1]
newBase := last.Base() + last.Size() + 1
if ptr.base < newBase {
ptr.base = newBase
}
}
}
// FileSetFor returns a new FileSet containing a sequence of new Files with
// the same base, size, and line as the input files, for use in APIs that
// require a FileSet.
//
// Precondition: the input files must be non-overlapping, and sorted in order
// of their Base.
func FileSetFor(files ...*token.File) *token.FileSet {
fset := token.NewFileSet()
for _, f := range files {
f2 := fset.AddFile(f.Name(), f.Base(), f.Size())
lines := GetLines(f)
f2.SetLines(lines)
}
return fset
}
// CloneFileSet creates a new FileSet holding all files in fset. It does not
// create copies of the token.Files in fset: they are added to the resulting
// FileSet unmodified.
func CloneFileSet(fset *token.FileSet) *token.FileSet {
var files []*token.File
fset.Iterate(func(f *token.File) bool {
files = append(files, f)
return true
})
newFileSet := token.NewFileSet()
AddExistingFiles(newFileSet, files)
return newFileSet
}

View file

@ -87,7 +87,6 @@ func IsTypeParam(t types.Type) bool {
func OriginMethod(fn *types.Func) *types.Func { func OriginMethod(fn *types.Func) *types.Func {
recv := fn.Type().(*types.Signature).Recv() recv := fn.Type().(*types.Signature).Recv()
if recv == nil { if recv == nil {
return fn return fn
} }
base := recv.Type() base := recv.Type()

View file

@ -11,6 +11,8 @@ import (
"go/types" "go/types"
"reflect" "reflect"
"unsafe" "unsafe"
"golang.org/x/tools/go/types/objectpath"
) )
func SetUsesCgo(conf *types.Config) bool { func SetUsesCgo(conf *types.Config) bool {
@ -50,3 +52,10 @@ func ReadGo116ErrorData(err types.Error) (code ErrorCode, start, end token.Pos,
} }
var SetGoVersion = func(conf *types.Config, version string) bool { return false } var SetGoVersion = func(conf *types.Config, version string) bool { return false }
// NewObjectpathEncoder returns a function closure equivalent to
// objectpath.For but amortized for multiple (sequential) calls.
// It is a temporary workaround, pending the approval of proposal 58668.
//
//go:linkname NewObjectpathFunc golang.org/x/tools/go/types/objectpath.newEncoderFor
func NewObjectpathFunc() func(types.Object) (objectpath.Path, error)

16
vendor/modules.txt vendored
View file

@ -15,14 +15,14 @@ github.com/davecgh/go-spew/spew
# github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 # github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185
## explicit ## explicit
github.com/dchest/safefile github.com/dchest/safefile
# github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 # github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572
## explicit; go 1.13 ## explicit; go 1.13
github.com/go-task/slim-sprig github.com/go-task/slim-sprig
# github.com/golang/mock v1.6.0 # github.com/golang/mock v1.6.0
## explicit; go 1.11 ## explicit; go 1.11
github.com/golang/mock/mockgen github.com/golang/mock/mockgen
github.com/golang/mock/mockgen/model github.com/golang/mock/mockgen/model
# github.com/golang/protobuf v1.5.2 # github.com/golang/protobuf v1.5.3
## explicit; go 1.9 ## explicit; go 1.9
github.com/golang/protobuf/proto github.com/golang/protobuf/proto
github.com/golang/protobuf/ptypes github.com/golang/protobuf/ptypes
@ -70,10 +70,10 @@ github.com/k-sone/critbitgo
# github.com/kardianos/service v1.2.2 # github.com/kardianos/service v1.2.2
## explicit; go 1.12 ## explicit; go 1.12
github.com/kardianos/service github.com/kardianos/service
# github.com/miekg/dns v1.1.54 # github.com/miekg/dns v1.1.55
## explicit; go 1.19 ## explicit; go 1.19
github.com/miekg/dns github.com/miekg/dns
# github.com/onsi/ginkgo/v2 v2.2.0 # github.com/onsi/ginkgo/v2 v2.9.5
## explicit; go 1.18 ## explicit; go 1.18
github.com/onsi/ginkgo/v2/config github.com/onsi/ginkgo/v2/config
github.com/onsi/ginkgo/v2/formatter github.com/onsi/ginkgo/v2/formatter
@ -112,7 +112,7 @@ github.com/quic-go/qtls-go1-19
# github.com/quic-go/qtls-go1-20 v0.2.2 # github.com/quic-go/qtls-go1-20 v0.2.2
## explicit; go 1.20 ## explicit; go 1.20
github.com/quic-go/qtls-go1-20 github.com/quic-go/qtls-go1-20
# github.com/quic-go/quic-go v0.35.1 # github.com/quic-go/quic-go v0.36.0
## explicit; go 1.19 ## explicit; go 1.19
github.com/quic-go/quic-go github.com/quic-go/quic-go
github.com/quic-go/quic-go/http3 github.com/quic-go/quic-go/http3
@ -126,6 +126,7 @@ github.com/quic-go/quic-go/internal/qerr
github.com/quic-go/quic-go/internal/qtls github.com/quic-go/quic-go/internal/qtls
github.com/quic-go/quic-go/internal/utils github.com/quic-go/quic-go/internal/utils
github.com/quic-go/quic-go/internal/utils/linkedlist github.com/quic-go/quic-go/internal/utils/linkedlist
github.com/quic-go/quic-go/internal/utils/ringbuffer
github.com/quic-go/quic-go/internal/wire github.com/quic-go/quic-go/internal/wire
github.com/quic-go/quic-go/logging github.com/quic-go/quic-go/logging
github.com/quic-go/quic-go/quicvarint github.com/quic-go/quic-go/quicvarint
@ -153,7 +154,7 @@ golang.org/x/crypto/salsa20/salsa
# golang.org/x/exp v0.0.0-20221205204356-47842c84f3db # golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/exp/constraints golang.org/x/exp/constraints
# golang.org/x/mod v0.8.0 # golang.org/x/mod v0.10.0
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/mod/internal/lazyregexp golang.org/x/mod/internal/lazyregexp
golang.org/x/mod/modfile golang.org/x/mod/modfile
@ -189,13 +190,14 @@ golang.org/x/text/secure/bidirule
golang.org/x/text/transform golang.org/x/text/transform
golang.org/x/text/unicode/bidi golang.org/x/text/unicode/bidi
golang.org/x/text/unicode/norm golang.org/x/text/unicode/norm
# golang.org/x/tools v0.6.0 # golang.org/x/tools v0.9.1
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/tools/go/ast/astutil golang.org/x/tools/go/ast/astutil
golang.org/x/tools/go/ast/inspector golang.org/x/tools/go/ast/inspector
golang.org/x/tools/go/gcexportdata golang.org/x/tools/go/gcexportdata
golang.org/x/tools/go/internal/packagesdriver golang.org/x/tools/go/internal/packagesdriver
golang.org/x/tools/go/packages golang.org/x/tools/go/packages
golang.org/x/tools/go/types/objectpath
golang.org/x/tools/imports golang.org/x/tools/imports
golang.org/x/tools/internal/event golang.org/x/tools/internal/event
golang.org/x/tools/internal/event/core golang.org/x/tools/internal/event/core