Commit be36fec7 authored by Matthew Holt's avatar Matthew Holt

vendor: Update quic-go

parent 49e98a15
...@@ -25,7 +25,7 @@ type SentPacketHandler interface { ...@@ -25,7 +25,7 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface { type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame GetAckFrame() *frames.AckFrame
......
...@@ -12,7 +12,7 @@ var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet n ...@@ -12,7 +12,7 @@ var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet n
type receivedPacketHandler struct { type receivedPacketHandler struct {
largestObserved protocol.PacketNumber largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber lowerLimit protocol.PacketNumber
largestObservedReceivedTime time.Time largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
...@@ -39,33 +39,29 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe ...@@ -39,33 +39,29 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber return errInvalidPacketNumber
} }
if packetNumber > h.ignorePacketsBelow {
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
}
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now() h.largestObservedReceivedTime = time.Now()
} }
h.maybeQueueAck(packetNumber, shouldInstigateAck) if packetNumber <= h.lowerLimit {
return nil
}
func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) error {
// ignore if StopWaiting is unneeded, because we already received a StopWaiting with a higher LeastUnacked
if h.ignorePacketsBelow >= f.LeastUnacked {
return nil return nil
} }
h.ignorePacketsBelow = f.LeastUnacked - 1 if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
h.packetHistory.DeleteBelow(f.LeastUnacked) }
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil return nil
} }
// SetLowerLimit sets a lower limit for acking packets.
// Packets with packet numbers smaller or equal than p will not be acked.
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
h.lowerLimit = p
h.packetHistory.DeleteUpTo(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
......
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
// The receivedPacketHistory stores if a packet number has already been received.
// It does not store packet contents.
type receivedPacketHistory struct { type receivedPacketHistory struct {
ranges *utils.PacketIntervalList ranges *utils.PacketIntervalList
...@@ -84,20 +86,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error { ...@@ -84,20 +86,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return nil return nil
} }
// DeleteBelow deletes all entries below the leastUnacked packet number // DeleteUpTo deletes all entries up to (and including) p
func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) { func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked) h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
nextEl := h.ranges.Front() nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl { for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next() nextEl = el.Next()
if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End { if p >= el.Value.Start && p < el.Value.End {
for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range for i := el.Value.Start; i <= p; i++ { // adjust start value of a range
delete(h.receivedPacketNumbers, i) delete(h.receivedPacketNumbers, i)
} }
el.Value.Start = leastUnacked el.Value.Start = p + 1
} else if el.Value.End < leastUnacked { // delete a whole range } else if el.Value.End <= p { // delete a whole range
for i := el.Value.Start; i <= el.Value.End; i++ { for i := el.Value.Start; i <= el.Value.End; i++ {
delete(h.receivedPacketNumbers, i) delete(h.receivedPacketNumbers, i)
} }
......
...@@ -81,7 +81,7 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, ...@@ -81,7 +81,7 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber,
return nil, err return nil, err
} }
if leastUnackedDelta > uint64(packetNumber) { if leastUnackedDelta >= uint64(packetNumber) {
return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta") return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta")
} }
......
...@@ -3,6 +3,7 @@ package frames ...@@ -3,6 +3,7 @@ package frames
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
...@@ -70,11 +71,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { ...@@ -70,11 +71,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) {
} }
if dataLen != 0 { if dataLen != 0 {
frame.Data = make([]byte, dataLen) frame.Data = make([]byte, dataLen)
n, err := r.Read(frame.Data) if _, err := io.ReadFull(r, frame.Data); err != nil {
if n != int(dataLen) {
return nil, errors.New("BUG: StreamFrame could not read dataLen bytes")
}
if err != nil {
return nil, err return nil, err
} }
} }
......
...@@ -54,11 +54,15 @@ type cryptoSetupServer struct { ...@@ -54,11 +54,15 @@ type cryptoSetupServer struct {
var _ CryptoSetup = &cryptoSetupServer{} var _ CryptoSetup = &cryptoSetupServer{}
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO // ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO.
// this is an expiremnt implemented by Chrome in QUIC 36, which we don't support // This is an experiment implemented by Chrome in QUIC 36, which we don't support.
// TODO: remove this when dropping support for QUIC 36 // TODO: remove this when dropping support for QUIC 36
var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported") var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported")
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server // NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup( func NewCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
...@@ -120,6 +124,9 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][] ...@@ -120,6 +124,9 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment { if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment {
return false, ErrHOLExperiment return false, ErrHOLExperiment
} }
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI] sniSlice, ok := cryptoData[TagSNI]
if !ok { if !ok {
......
...@@ -87,7 +87,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { ...@@ -87,7 +87,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
v := data[Tag(t)] v := data[Tag(t)]
b.Write(v) b.Write(v)
offset += uint32(len(v)) offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], t) binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset) binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
} }
...@@ -95,14 +95,16 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) { ...@@ -95,14 +95,16 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
copy(b.Bytes()[indexStart:], indexData) copy(b.Bytes()[indexStart:], indexData)
} }
func (h *HandshakeMessage) getTagsSorted() []uint32 { func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]uint32, len(h.Data)) tags := make([]Tag, len(h.Data))
i := 0 i := 0
for t := range h.Data { for t := range h.Data {
tags[i] = uint32(t) tags[i] = t
i++ i++
} }
sort.Sort(utils.Uint32Slice(tags)) sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags return tags
} }
......
...@@ -54,6 +54,9 @@ const ( ...@@ -54,6 +54,9 @@ const (
// Chrome experiment (see https://codereview.chromium.org/2115033002) // Chrome experiment (see https://codereview.chromium.org/2115033002)
// unsupported by quic-go // unsupported by quic-go
TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24 TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token // TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16 TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
......
package integrationtests
import (
"crypto/md5"
"math/rand"
"time"
)
type dataManager struct {
data []byte
md5 []byte
}
func (m *dataManager) GenerateData(len int) error {
m.data = make([]byte, len)
r := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
_, err := r.Read(m.data)
if err != nil {
return err
}
sum := md5.Sum(m.data)
m.md5 = sum[:]
return nil
}
func (m *dataManager) GetData() []byte {
return m.data
}
func (m *dataManager) GetMD5() []byte {
return m.md5
}
package handshaketests
import (
"crypto/tls"
"fmt"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/proxy"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake integration tets", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
for {
_, _ = server.Accept()
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
expectedDuration := time.Duration(num) * rtt
Expect(testDuration).To(SatisfyAll(
BeNumerically(">=", expectedDuration),
BeNumerically("<", expectedDuration+rtt),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
Expect(len(protocol.SupportedVersions)).To(BeNumerically(">", 1))
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require an STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})
package testlog
import (
"flag"
"log"
"os"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
logFileName string // the log file set in the ginkgo flags
logFile *os.File
)
// read the logfile command line flag
// to set call ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if len(logFileName) > 0 {
var err error
logFile, err = os.Create("./log.txt")
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug)
}
})
var _ = AfterEach(func() {
if len(logFileName) > 0 {
_ = logFile.Close()
}
})
package testserver
import (
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
)
var (
PRData = GeneratePRData(dataLen)
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
port string
)
func init() {
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
sl := r.URL.Query().Get("len")
if sl != "" {
var err error
l, err := strconv.Atoi(sl)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(GeneratePRData(l))
Expect(err).NotTo(HaveOccurred())
} else {
_, err := w.Write(PRData)
Expect(err).NotTo(HaveOccurred())
}
})
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := w.Write(PRDataLong)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := io.WriteString(w, "Hello, World!\n")
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
body, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(body)
Expect(err).NotTo(HaveOccurred())
})
}
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
res := make([]byte, l)
seed := uint64(1)
for i := 0; i < l; i++ {
seed = seed * 48271 % 2147483647
res[i] = byte(seed)
}
return res
}
func StartQuicServer() {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
Expect(err).NotTo(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
go func() {
defer GinkgoRecover()
server.Serve(conn)
}()
}
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
}
func Port() string {
return port
}
package quic package quic
import ( import (
"context"
"io" "io"
"net" "net"
"time" "time"
...@@ -22,6 +23,10 @@ type Stream interface { ...@@ -22,6 +23,10 @@ type Stream interface {
StreamID() protocol.StreamID StreamID() protocol.StreamID
// Reset closes the stream with an error. // Reset closes the stream with an error.
Reset(error) Reset(error)
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// SetReadDeadline sets the deadline for future Read calls and // SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call. // any currently-blocked Read call.
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
...@@ -43,7 +48,7 @@ type Session interface { ...@@ -43,7 +48,7 @@ type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available. // AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server). // Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error) AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peeer's concurrent stream limit is reached. // OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
// New streams always have the smallest possible stream ID. // New streams always have the smallest possible stream ID.
// TODO: Enable testing for the special error // TODO: Enable testing for the special error
OpenStream() (Stream, error) OpenStream() (Stream, error)
...@@ -56,9 +61,9 @@ type Session interface { ...@@ -56,9 +61,9 @@ type Session interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error Close(error) error
// WaitUntilClosed() blocks until the session is closed. // The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
WaitUntilClosed() Context() context.Context
} }
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. // A NonFWSession is a QUIC connection between two peers half-way through the handshake.
......
...@@ -127,10 +127,3 @@ func WriteUint24(b *bytes.Buffer, i uint32) { ...@@ -127,10 +127,3 @@ func WriteUint24(b *bytes.Buffer, i uint32) {
func WriteUint16(b *bytes.Buffer, i uint16) { func WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i), uint8(i >> 8)}) b.Write([]byte{uint8(i), uint8(i >> 8)})
} }
// Uint32Slice attaches the methods of sort.Interface to []uint32, sorting in increasing order.
type Uint32Slice []uint32
func (s Uint32Slice) Len() int { return len(s) }
func (s Uint32Slice) Less(i, j int) bool { return s[i] < s[j] }
func (s Uint32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
...@@ -8,6 +8,7 @@ const ( ...@@ -8,6 +8,7 @@ const (
Version35 VersionNumber = 35 + iota Version35 VersionNumber = 35 + iota
Version36 Version36
Version37 Version37
Version38
VersionWhatever VersionNumber = 0 // for when the version doesn't matter VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnsupported VersionNumber = -1 VersionUnsupported VersionNumber = -1
) )
...@@ -15,7 +16,10 @@ const ( ...@@ -15,7 +16,10 @@ const (
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
// must be in sorted descending order // must be in sorted descending order
var SupportedVersions = []VersionNumber{ var SupportedVersions = []VersionNumber{
Version37, Version36, Version35, Version38,
Version37,
Version36,
Version35,
} }
// VersionNumberToTag maps version numbers ('32') to tags ('Q032') // VersionNumberToTag maps version numbers ('32') to tags ('Q032')
......
...@@ -3,6 +3,7 @@ package quic ...@@ -3,6 +3,7 @@ package quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
...@@ -166,9 +167,7 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub ...@@ -166,9 +167,7 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
// see https://github.com/lucas-clemente/quic-go/issues/232 // see https://github.com/lucas-clemente/quic-go/issues/232
if !header.VersionFlag && !header.ResetFlag { if !header.VersionFlag && !header.ResetFlag {
header.DiversificationNonce = make([]byte, 32) header.DiversificationNonce = make([]byte, 32)
// this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil {
_, err = b.Read(header.DiversificationNonce)
if err != nil {
return nil, err return nil, err
} }
} }
......
package quic package quic
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
...@@ -78,11 +79,11 @@ type session struct { ...@@ -78,11 +79,11 @@ type session struct {
sendingScheduled chan struct{} sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate. // closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError closeChan chan closeError
// runClosed is closed once the run loop exits
// it is used to block Close() and WaitUntilClosed()
runClosed chan struct{}
closeOnce sync.Once closeOnce sync.Once
ctx context.Context
ctxCancel context.CancelFunc
// when we receive too many undecryptable packets during the handshake, we send a Public reset // when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed // but only after a time of protocol.PublicResetTimeout has passed
undecryptablePackets []*receivedPacket undecryptablePackets []*receivedPacket
...@@ -167,12 +168,12 @@ func (s *session) setup( ...@@ -167,12 +168,12 @@ func (s *session) setup(
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3) handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan s.handshakeChan = handshakeChan
s.runClosed = make(chan struct{})
s.handshakeCompleteChan = make(chan error, 1) s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.timer = utils.NewTimer() s.timer = utils.NewTimer()
now := time.Now() now := time.Now()
...@@ -333,12 +334,12 @@ runLoop: ...@@ -333,12 +334,12 @@ runLoop:
s.handshakeChan <- handshakeEvent{err: closeErr.err} s.handshakeChan <- handshakeEvent{err: closeErr.err}
} }
s.handleCloseError(closeErr) s.handleCloseError(closeErr)
close(s.runClosed) defer s.ctxCancel()
return closeErr.err return closeErr.err
} }
func (s *session) WaitUntilClosed() { func (s *session) Context() context.Context {
<-s.runClosed return s.ctx
} }
func (s *session) maybeResetTimer() { func (s *session) maybeResetTimer() {
...@@ -445,7 +446,9 @@ func (s *session) handleFrames(fs []frames.Frame) error { ...@@ -445,7 +446,9 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.GoawayFrame: case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames") err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame: case *frames.StopWaitingFrame:
err = s.receivedPacketHandler.ReceivedStopWaiting(frame) // LeastUnacked is guaranteed to have LeastUnacked > 0
// therefore this will never underflow
s.receivedPacketHandler.SetLowerLimit(frame.LeastUnacked - 1)
case *frames.RstStreamFrame: case *frames.RstStreamFrame:
err = s.handleRstStreamFrame(frame) err = s.handleRstStreamFrame(frame)
case *frames.WindowUpdateFrame: case *frames.WindowUpdateFrame:
...@@ -543,7 +546,7 @@ func (s *session) closeRemote(e error) { ...@@ -543,7 +546,7 @@ func (s *session) closeRemote(e error) {
// It waits until the run loop has stopped before returning // It waits until the run loop has stopped before returning
func (s *session) Close(e error) error { func (s *session) Close(e error) error {
s.closeLocal(e) s.closeLocal(e)
<-s.runClosed <-s.ctx.Done()
return nil return nil
} }
...@@ -575,7 +578,9 @@ func (s *session) handleCloseError(closeErr closeError) error { ...@@ -575,7 +578,9 @@ func (s *session) handleCloseError(closeErr closeError) error {
return nil return nil
} }
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment { if quicErr.ErrorCode == qerr.DecryptionFailure ||
quicErr == handshake.ErrHOLExperiment ||
quicErr == handshake.ErrNSTPExperiment {
return s.sendPublicReset(s.lastRcvdPacketNumber) return s.sendPublicReset(s.lastRcvdPacketNumber)
} }
return s.sendConnectionClose(quicErr) return s.sendConnectionClose(quicErr)
......
package quic package quic
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
...@@ -19,6 +20,9 @@ import ( ...@@ -19,6 +20,9 @@ import (
type stream struct { type stream struct {
mutex sync.Mutex mutex sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
streamID protocol.StreamID streamID protocol.StreamID
onData func() onData func()
// onReset is a callback that should send a RST_STREAM // onReset is a callback that should send a RST_STREAM
...@@ -55,6 +59,8 @@ type stream struct { ...@@ -55,6 +59,8 @@ type stream struct {
flowControlManager flowcontrol.FlowControlManager flowControlManager flowcontrol.FlowControlManager
} }
var _ Stream = &stream{}
type deadlineError struct{} type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" } func (deadlineError) Error() string { return "deadline exceeded" }
...@@ -68,7 +74,7 @@ func newStream(StreamID protocol.StreamID, ...@@ -68,7 +74,7 @@ func newStream(StreamID protocol.StreamID,
onData func(), onData func(),
onReset func(protocol.StreamID, protocol.ByteCount), onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream { flowControlManager flowcontrol.FlowControlManager) *stream {
return &stream{ s := &stream{
onData: onData, onData: onData,
onReset: onReset, onReset: onReset,
streamID: StreamID, streamID: StreamID,
...@@ -77,6 +83,8 @@ func newStream(StreamID protocol.StreamID, ...@@ -77,6 +83,8 @@ func newStream(StreamID protocol.StreamID,
readChan: make(chan struct{}, 1), readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1), writeChan: make(chan struct{}, 1),
} }
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
} }
// Read implements io.Reader. It is not thread safe! // Read implements io.Reader. It is not thread safe!
...@@ -180,6 +188,9 @@ func (s *stream) Write(p []byte) (int, error) { ...@@ -180,6 +188,9 @@ func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() || s.err != nil { if s.resetLocally.Get() || s.err != nil {
return 0, s.err return 0, s.err
} }
if s.finishedWriting.Get() {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
...@@ -254,6 +265,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { ...@@ -254,6 +265,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
// Close implements io.Closer // Close implements io.Closer
func (s *stream) Close() error { func (s *stream) Close() error {
s.finishedWriting.Set(true) s.finishedWriting.Set(true)
s.ctxCancel()
s.onData() s.onData()
return nil return nil
} }
...@@ -349,6 +361,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) { ...@@ -349,6 +361,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) {
func (s *stream) Cancel(err error) { func (s *stream) Cancel(err error) {
s.mutex.Lock() s.mutex.Lock()
s.cancelled.Set(true) s.cancelled.Set(true)
s.ctxCancel()
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
...@@ -365,6 +378,7 @@ func (s *stream) Reset(err error) { ...@@ -365,6 +378,7 @@ func (s *stream) Reset(err error) {
} }
s.mutex.Lock() s.mutex.Lock()
s.resetLocally.Set(true) s.resetLocally.Set(true)
s.ctxCancel()
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
...@@ -385,6 +399,7 @@ func (s *stream) RegisterRemoteError(err error) { ...@@ -385,6 +399,7 @@ func (s *stream) RegisterRemoteError(err error) {
} }
s.mutex.Lock() s.mutex.Lock()
s.resetRemotely.Set(true) s.resetRemotely.Set(true)
s.ctxCancel()
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
...@@ -409,6 +424,10 @@ func (s *stream) finished() bool { ...@@ -409,6 +424,10 @@ func (s *stream) finished() bool {
(s.finishedWriteAndSentFin() && s.resetRemotely.Get()) (s.finishedWriteAndSentFin() && s.resetRemotely.Get())
} }
func (s *stream) Context() context.Context {
return s.ctx
}
func (s *stream) StreamID() protocol.StreamID { func (s *stream) StreamID() protocol.StreamID {
return s.streamID return s.streamID
} }
...@@ -129,7 +129,7 @@ ...@@ -129,7 +129,7 @@
"importpath": "github.com/lucas-clemente/quic-go", "importpath": "github.com/lucas-clemente/quic-go",
"repository": "https://github.com/lucas-clemente/quic-go", "repository": "https://github.com/lucas-clemente/quic-go",
"vcs": "git", "vcs": "git",
"revision": "3e012f77c8d4c06297b2fa40261e82193134b224", "revision": "a9e2a28315406f825cdfe41f8652110addeb84a5",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment