Commit be36fec7 authored by Matthew Holt's avatar Matthew Holt

vendor: Update quic-go

parent 49e98a15
......@@ -25,7 +25,7 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame
......
......@@ -12,7 +12,7 @@ var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet n
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber
lowerLimit protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
......@@ -39,33 +39,29 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber
}
if packetNumber > h.ignorePacketsBelow {
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
}
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) error {
// ignore if StopWaiting is unneeded, because we already received a StopWaiting with a higher LeastUnacked
if h.ignorePacketsBelow >= f.LeastUnacked {
if packetNumber <= h.lowerLimit {
return nil
}
h.ignorePacketsBelow = f.LeastUnacked - 1
h.packetHistory.DeleteBelow(f.LeastUnacked)
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
// SetLowerLimit sets a lower limit for acking packets.
// Packets with packet numbers smaller or equal than p will not be acked.
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
h.lowerLimit = p
h.packetHistory.DeleteUpTo(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
......
......@@ -7,6 +7,8 @@ import (
"github.com/lucas-clemente/quic-go/qerr"
)
// The receivedPacketHistory stores if a packet number has already been received.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
......@@ -84,20 +86,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return nil
}
// DeleteBelow deletes all entries below the leastUnacked packet number
func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked)
// DeleteUpTo deletes all entries up to (and including) p
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End {
for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range
if p >= el.Value.Start && p < el.Value.End {
for i := el.Value.Start; i <= p; i++ { // adjust start value of a range
delete(h.receivedPacketNumbers, i)
}
el.Value.Start = leastUnacked
} else if el.Value.End < leastUnacked { // delete a whole range
el.Value.Start = p + 1
} else if el.Value.End <= p { // delete a whole range
for i := el.Value.Start; i <= el.Value.End; i++ {
delete(h.receivedPacketNumbers, i)
}
......
......@@ -81,7 +81,7 @@ func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber,
return nil, err
}
if leastUnackedDelta > uint64(packetNumber) {
if leastUnackedDelta >= uint64(packetNumber) {
return nil, qerr.Error(qerr.InvalidStopWaitingData, "invalid LeastUnackedDelta")
}
......
......@@ -3,6 +3,7 @@ package frames
import (
"bytes"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
......@@ -70,11 +71,7 @@ func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) {
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
n, err := r.Read(frame.Data)
if n != int(dataLen) {
return nil, errors.New("BUG: StreamFrame could not read dataLen bytes")
}
if err != nil {
if _, err := io.ReadFull(r, frame.Data); err != nil {
return nil, err
}
}
......
......@@ -54,11 +54,15 @@ type cryptoSetupServer struct {
var _ CryptoSetup = &cryptoSetupServer{}
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO
// this is an expiremnt implemented by Chrome in QUIC 36, which we don't support
// ErrHOLExperiment is returned when the client sends the FHL2 tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 36, which we don't support.
// TODO: remove this when dropping support for QUIC 36
var ErrHOLExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "HOL experiment. Unsupported")
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup(
connID protocol.ConnectionID,
......@@ -120,6 +124,9 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
if _, isHOLExperiment := cryptoData[TagFHL2]; isHOLExperiment {
return false, ErrHOLExperiment
}
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI]
if !ok {
......
......@@ -87,7 +87,7 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
v := data[Tag(t)]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], t)
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
}
......@@ -95,14 +95,16 @@ func (h HandshakeMessage) Write(b *bytes.Buffer) {
copy(b.Bytes()[indexStart:], indexData)
}
func (h *HandshakeMessage) getTagsSorted() []uint32 {
tags := make([]uint32, len(h.Data))
func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]Tag, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = uint32(t)
tags[i] = t
i++
}
sort.Sort(utils.Uint32Slice(tags))
sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags
}
......
......@@ -54,6 +54,9 @@ const (
// Chrome experiment (see https://codereview.chromium.org/2115033002)
// unsupported by quic-go
TagFHL2 Tag = 'F' + 'H'<<8 + 'L'<<16 + '2'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
......
package integrationtests
import (
"crypto/md5"
"math/rand"
"time"
)
type dataManager struct {
data []byte
md5 []byte
}
func (m *dataManager) GenerateData(len int) error {
m.data = make([]byte, len)
r := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
_, err := r.Read(m.data)
if err != nil {
return err
}
sum := md5.Sum(m.data)
m.md5 = sum[:]
return nil
}
func (m *dataManager) GetData() []byte {
return m.data
}
func (m *dataManager) GetMD5() []byte {
return m.md5
}
package handshaketests
import (
"crypto/tls"
"fmt"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/proxy"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake integration tets", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
for {
_, _ = server.Accept()
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
expectedDuration := time.Duration(num) * rtt
Expect(testDuration).To(SatisfyAll(
BeNumerically(">=", expectedDuration),
BeNumerically("<", expectedDuration+rtt),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
Expect(len(protocol.SupportedVersions)).To(BeNumerically(">", 1))
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require an STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})
package testlog
import (
"flag"
"log"
"os"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
logFileName string // the log file set in the ginkgo flags
logFile *os.File
)
// read the logfile command line flag
// to set call ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if len(logFileName) > 0 {
var err error
logFile, err = os.Create("./log.txt")
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug)
}
})
var _ = AfterEach(func() {
if len(logFileName) > 0 {
_ = logFile.Close()
}
})
package testserver
import (
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
)
var (
PRData = GeneratePRData(dataLen)
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
port string
)
func init() {
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
sl := r.URL.Query().Get("len")
if sl != "" {
var err error
l, err := strconv.Atoi(sl)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(GeneratePRData(l))
Expect(err).NotTo(HaveOccurred())
} else {
_, err := w.Write(PRData)
Expect(err).NotTo(HaveOccurred())
}
})
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := w.Write(PRDataLong)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := io.WriteString(w, "Hello, World!\n")
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
body, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(body)
Expect(err).NotTo(HaveOccurred())
})
}
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
res := make([]byte, l)
seed := uint64(1)
for i := 0; i < l; i++ {
seed = seed * 48271 % 2147483647
res[i] = byte(seed)
}
return res
}
func StartQuicServer() {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
Expect(err).NotTo(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
go func() {
defer GinkgoRecover()
server.Serve(conn)
}()
}
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
}
func Port() string {
return port
}
package quic
import (
"context"
"io"
"net"
"time"
......@@ -22,6 +23,10 @@ type Stream interface {
StreamID() protocol.StreamID
// Reset closes the stream with an error.
Reset(error)
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
......@@ -43,7 +48,7 @@ type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// Since stream 1 is reserved for the crypto stream, the first stream is either 2 (for a client) or 3 (for a server).
AcceptStream() (Stream, error)
// OpenStream opens a new QUIC stream, returning a special error when the peeer's concurrent stream limit is reached.
// OpenStream opens a new QUIC stream, returning a special error when the peer's concurrent stream limit is reached.
// New streams always have the smallest possible stream ID.
// TODO: Enable testing for the special error
OpenStream() (Stream, error)
......@@ -56,9 +61,9 @@ type Session interface {
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
......
......@@ -127,10 +127,3 @@ func WriteUint24(b *bytes.Buffer, i uint32) {
func WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i), uint8(i >> 8)})
}
// Uint32Slice attaches the methods of sort.Interface to []uint32, sorting in increasing order.
type Uint32Slice []uint32
func (s Uint32Slice) Len() int { return len(s) }
func (s Uint32Slice) Less(i, j int) bool { return s[i] < s[j] }
func (s Uint32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
......@@ -8,6 +8,7 @@ const (
Version35 VersionNumber = 35 + iota
Version36
Version37
Version38
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnsupported VersionNumber = -1
)
......@@ -15,7 +16,10 @@ const (
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{
Version37, Version36, Version35,
Version38,
Version37,
Version36,
Version35,
}
// VersionNumberToTag maps version numbers ('32') to tags ('Q032')
......
......@@ -3,6 +3,7 @@ package quic
import (
"bytes"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
......@@ -166,9 +167,7 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
// see https://github.com/lucas-clemente/quic-go/issues/232
if !header.VersionFlag && !header.ResetFlag {
header.DiversificationNonce = make([]byte, 32)
// this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number
_, err = b.Read(header.DiversificationNonce)
if err != nil {
if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil {
return nil, err
}
}
......
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
......@@ -78,11 +79,11 @@ type session struct {
sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError
// runClosed is closed once the run loop exits
// it is used to block Close() and WaitUntilClosed()
runClosed chan struct{}
closeOnce sync.Once
ctx context.Context
ctxCancel context.CancelFunc
// when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed
undecryptablePackets []*receivedPacket
......@@ -167,12 +168,12 @@ func (s *session) setup(
s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan
s.runClosed = make(chan struct{})
s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.timer = utils.NewTimer()
now := time.Now()
......@@ -333,12 +334,12 @@ runLoop:
s.handshakeChan <- handshakeEvent{err: closeErr.err}
}
s.handleCloseError(closeErr)
close(s.runClosed)
defer s.ctxCancel()
return closeErr.err
}
func (s *session) WaitUntilClosed() {
<-s.runClosed
func (s *session) Context() context.Context {
return s.ctx
}
func (s *session) maybeResetTimer() {
......@@ -445,7 +446,9 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame:
err = s.receivedPacketHandler.ReceivedStopWaiting(frame)
// LeastUnacked is guaranteed to have LeastUnacked > 0
// therefore this will never underflow
s.receivedPacketHandler.SetLowerLimit(frame.LeastUnacked - 1)
case *frames.RstStreamFrame:
err = s.handleRstStreamFrame(frame)
case *frames.WindowUpdateFrame:
......@@ -543,7 +546,7 @@ func (s *session) closeRemote(e error) {
// It waits until the run loop has stopped before returning
func (s *session) Close(e error) error {
s.closeLocal(e)
<-s.runClosed
<-s.ctx.Done()
return nil
}
......@@ -575,7 +578,9 @@ func (s *session) handleCloseError(closeErr closeError) error {
return nil
}
if quicErr.ErrorCode == qerr.DecryptionFailure || quicErr == handshake.ErrHOLExperiment {
if quicErr.ErrorCode == qerr.DecryptionFailure ||
quicErr == handshake.ErrHOLExperiment ||
quicErr == handshake.ErrNSTPExperiment {
return s.sendPublicReset(s.lastRcvdPacketNumber)
}
return s.sendConnectionClose(quicErr)
......
package quic
import (
"context"
"fmt"
"io"
"net"
......@@ -19,6 +20,9 @@ import (
type stream struct {
mutex sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
streamID protocol.StreamID
onData func()
// onReset is a callback that should send a RST_STREAM
......@@ -55,6 +59,8 @@ type stream struct {
flowControlManager flowcontrol.FlowControlManager
}
var _ Stream = &stream{}
type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" }
......@@ -68,7 +74,7 @@ func newStream(StreamID protocol.StreamID,
onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream {
return &stream{
s := &stream{
onData: onData,
onReset: onReset,
streamID: StreamID,
......@@ -77,6 +83,8 @@ func newStream(StreamID protocol.StreamID,
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
// Read implements io.Reader. It is not thread safe!
......@@ -180,6 +188,9 @@ func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() || s.err != nil {
return 0, s.err
}
if s.finishedWriting.Get() {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if len(p) == 0 {
return 0, nil
}
......@@ -254,6 +265,7 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
// Close implements io.Closer
func (s *stream) Close() error {
s.finishedWriting.Set(true)
s.ctxCancel()
s.onData()
return nil
}
......@@ -349,6 +361,7 @@ func (s *stream) CloseRemote(offset protocol.ByteCount) {
func (s *stream) Cancel(err error) {
s.mutex.Lock()
s.cancelled.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
......@@ -365,6 +378,7 @@ func (s *stream) Reset(err error) {
}
s.mutex.Lock()
s.resetLocally.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
......@@ -385,6 +399,7 @@ func (s *stream) RegisterRemoteError(err error) {
}
s.mutex.Lock()
s.resetRemotely.Set(true)
s.ctxCancel()
// errors must not be changed!
if s.err == nil {
s.err = err
......@@ -409,6 +424,10 @@ func (s *stream) finished() bool {
(s.finishedWriteAndSentFin() && s.resetRemotely.Get())
}
func (s *stream) Context() context.Context {
return s.ctx
}
func (s *stream) StreamID() protocol.StreamID {
return s.streamID
}
......@@ -129,7 +129,7 @@
"importpath": "github.com/lucas-clemente/quic-go",
"repository": "https://github.com/lucas-clemente/quic-go",
"vcs": "git",
"revision": "3e012f77c8d4c06297b2fa40261e82193134b224",
"revision": "a9e2a28315406f825cdfe41f8652110addeb84a5",
"branch": "master",
"notests": true
},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment