Commit a48e4ecb authored by Matthew Holt's avatar Matthew Holt

vendor: Update dependencies

parent 74940af6
...@@ -174,11 +174,11 @@ func sizeFixed32(x uint64) int { ...@@ -174,11 +174,11 @@ func sizeFixed32(x uint64) int {
// This is the format used for the sint64 protocol buffer type. // This is the format used for the sint64 protocol buffer type.
func (p *Buffer) EncodeZigzag64(x uint64) error { func (p *Buffer) EncodeZigzag64(x uint64) error {
// use signed number to get arithmetic right shift. // use signed number to get arithmetic right shift.
return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return p.EncodeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
func sizeZigzag64(x uint64) int { func sizeZigzag64(x uint64) int {
return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return sizeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
// EncodeZigzag32 writes a zigzag-encoded 32-bit integer // EncodeZigzag32 writes a zigzag-encoded 32-bit integer
......
...@@ -865,7 +865,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error { ...@@ -865,7 +865,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
return p.readStruct(fv, terminator) return p.readStruct(fv, terminator)
case reflect.Uint32: case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
fv.SetUint(uint64(x)) fv.SetUint(x)
return nil return nil
} }
case reflect.Uint64: case reflect.Uint64:
......
...@@ -6,9 +6,8 @@ ...@@ -6,9 +6,8 @@
// //
// Overview // Overview
// //
// The Conn type represents a WebSocket connection. A server application uses // The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an Upgrader object with a HTTP request handler // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
// to get a pointer to a Conn:
// //
// var upgrader = websocket.Upgrader{ // var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024, // ReadBufferSize: 1024,
...@@ -33,7 +32,7 @@ ...@@ -33,7 +32,7 @@
// if err != nil { // if err != nil {
// return // return
// } // }
// if err = conn.WriteMessage(messageType, p); err != nil { // if err := conn.WriteMessage(messageType, p); err != nil {
// return err // return err
// } // }
// } // }
...@@ -147,9 +146,9 @@ ...@@ -147,9 +146,9 @@
// CheckOrigin: func(r *http.Request) bool { return true }, // CheckOrigin: func(r *http.Request) bool { return true },
// } // }
// //
// The deprecated Upgrade function does not enforce an origin policy. It's the // The deprecated package-level Upgrade function does not perform origin
// application's responsibility to check the Origin header before calling // checking. The application is responsible for checking the Origin header
// Upgrade. // before calling the Upgrade function.
// //
// Compression EXPERIMENTAL // Compression EXPERIMENTAL
// //
......
...@@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { ...@@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
} }
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
client.hub.register <- client client.hub.register <- client
// Allow collection of memory referenced by the caller by doing all work in
// new goroutines.
go client.writePump() go client.writePump()
client.readPump() go client.readPump()
} }
...@@ -9,12 +9,14 @@ import ( ...@@ -9,12 +9,14 @@ import (
"io" "io"
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
// WriteJSON writes the JSON encoding of v to the connection. // WriteJSON writes the JSON encoding of v as a message.
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
...@@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error { ...@@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2 return err2
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }
......
...@@ -230,10 +230,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade ...@@ -230,10 +230,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// This function is deprecated, use websocket.Upgrader instead. // Deprecated: Use websocket.Upgrader instead.
// //
// The application is responsible for checking the request origin before // Upgrade does not perform origin checking. The application is responsible for
// calling Upgrade. An example implementation of the same origin policy is: // checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
// //
// if req.Header.Get("Origin") != "http://"+req.Host { // if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403) // http.Error(w, "Origin not allowed", 403)
......
...@@ -111,14 +111,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) { ...@@ -111,14 +111,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
case escape: case escape:
escape = false escape = false
p[j] = b p[j] = b
j += 1 j++
case b == '\\': case b == '\\':
escape = true escape = true
case b == '"': case b == '"':
return string(p[:j]), s[i+1:] return string(p[:j]), s[i+1:]
default: default:
p[j] = b p[j] = b
j += 1 j++
} }
} }
return "", "" return "", ""
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
...@@ -26,5 +27,6 @@ type ReceivedPacketHandler interface { ...@@ -26,5 +27,6 @@ type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame GetAckFrame() *frames.AckFrame
} }
...@@ -8,13 +8,6 @@ import ( ...@@ -8,13 +8,6 @@ import (
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
var (
// ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct { type receivedPacketHandler struct {
...@@ -30,20 +23,14 @@ type receivedPacketHandler struct { ...@@ -30,20 +23,14 @@ type receivedPacketHandler struct {
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
ackQueued bool ackQueued bool
ackAlarm time.Time ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { func NewReceivedPacketHandler() ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback, ackSendDelay: protocol.AckSendDelay,
ackSendDelay: protocol.AckSendDelay,
} }
} }
...@@ -52,19 +39,10 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe ...@@ -52,19 +39,10 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber return errInvalidPacketNumber
} }
// if the packet number is smaller than the largest LeastUnacked value of a StopWaiting we received, we cannot detect if this packet has a duplicate number if packetNumber > h.ignorePacketsBelow {
// the packet has to be ignored anyway if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
if packetNumber <= h.ignorePacketsBelow { return err
return ErrPacketSmallerThanLastStopWaiting }
}
if h.packetHistory.IsDuplicate(packetNumber) {
return ErrDuplicatePacket
}
err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
return err
} }
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
...@@ -89,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) ...@@ -89,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
if shouldInstigateAck { if shouldInstigateAck {
...@@ -124,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber ...@@ -124,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
} else { } else {
if h.ackAlarm.IsZero() { if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay) h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
} }
} }
} }
...@@ -132,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber ...@@ -132,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
if h.ackQueued { if h.ackQueued {
// cancel the ack alarm // cancel the ack alarm
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
} }
} }
...@@ -164,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame { ...@@ -164,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
return ack return ack
} }
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
...@@ -2,9 +2,9 @@ package ackhandler ...@@ -2,9 +2,9 @@ package ackhandler
import ( import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type receivedPacketHistory struct { type receivedPacketHistory struct {
......
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
)
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
res := make([]frames.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f frames.Frame) bool {
switch f.(type) {
case *frames.StopWaitingFrame:
return false
case *frames.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []frames.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}
...@@ -7,9 +7,9 @@ import ( ...@@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (
...@@ -106,26 +106,27 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { ...@@ -106,26 +106,27 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
} }
} }
h.lastSentPacketNumber = packet.PacketNumber
now := time.Now() now := time.Now()
packet.SendTime = now
if packet.Length == 0 {
return errors.New("SentPacketHandler: packet cannot be empty")
}
h.bytesInFlight += packet.Length
h.lastSentPacketNumber = packet.PacketNumber packet.Frames = stripNonRetransmittableFrames(packet.Frames)
h.packetHistory.PushBack(*packet) isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
}
h.congestion.OnPacketSent( h.congestion.OnPacketSent(
now, now,
h.bytesInFlight, h.bytesInFlight,
packet.PacketNumber, packet.PacketNumber,
packet.Length, packet.Length,
true, /* TODO: is retransmittable */ isRetransmittable,
) )
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm()
return nil return nil
} }
...@@ -310,10 +311,11 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet { ...@@ -310,10 +311,11 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 { if len(h.retransmissionQueue) == 0 {
return nil return nil
} }
queueLen := len(h.retransmissionQueue) packet := h.retransmissionQueue[0]
// packets are usually NACKed in descending order. So use the slice as a stack // Shift the slice and don't retain anything that isn't needed.
packet := h.retransmissionQueue[queueLen-1] copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue = h.retransmissionQueue[:queueLen-1] h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet return packet
} }
...@@ -333,7 +335,11 @@ func (h *sentPacketHandler) SendingAllowed() bool { ...@@ -333,7 +335,11 @@ func (h *sentPacketHandler) SendingAllowed() bool {
h.bytesInFlight, h.bytesInFlight,
h.congestion.GetCongestionWindow()) h.congestion.GetCongestionWindow())
} }
return !(congestionLimited || maxTrackedLimited) // Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
} }
func (h *sentPacketHandler) retransmitOldestTwoPackets() { func (h *sentPacketHandler) retransmitOldestTwoPackets() {
......
...@@ -2,6 +2,7 @@ package quic ...@@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
...@@ -9,9 +10,9 @@ import ( ...@@ -9,9 +10,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type client struct { type client struct {
...@@ -24,6 +25,7 @@ type client struct { ...@@ -24,6 +25,7 @@ type client struct {
errorChan chan struct{} errorChan chan struct{}
handshakeChan <-chan handshakeEvent handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config config *Config
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has version negotiation completed yet
...@@ -39,7 +41,7 @@ var ( ...@@ -39,7 +41,7 @@ var (
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) { func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) { ...@@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Dial(udpConn, udpAddr, addr, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server. // DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -62,20 +68,33 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { ...@@ -62,20 +68,33 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return DialNonFWSecure(udpConn, udpAddr, addr, config) return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. // DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) { func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID() connID, err := utils.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostname, _, err := net.SplitHostPort(host) var hostname string
if err != nil { if tlsConf != nil {
return nil, err hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
} }
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
...@@ -83,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con ...@@ -83,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
tlsConf: tlsConf,
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
...@@ -93,15 +113,21 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con ...@@ -93,15 +113,21 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
return nil, err return nil, err
} }
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection() return c.session.(NonFWSession), c.establishSecureConnection()
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { func Dial(
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config) pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -112,16 +138,38 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config ...@@ -112,16 +138,38 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return sess, nil return sess, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config { func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig, Versions: versions,
Versions: versions, HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
} }
} }
...@@ -163,31 +211,46 @@ func (c *client) listen() { ...@@ -163,31 +211,46 @@ func (c *client) listen() {
} }
data = data[:n] data = data[:n]
err = c.handlePacket(addr, data) c.handlePacket(addr, data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
break
}
} }
} }
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now() rcvTime := time.Now()
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil { if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
return
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag { if c.versionNegotiated && hdr.VersionFlag {
return nil return
} }
// this is the first packet after the client sent a packet with the VersionFlag set // this is the first packet after the client sent a packet with the VersionFlag set
...@@ -198,7 +261,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { ...@@ -198,7 +261,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if hdr.VersionFlag { if hdr.VersionFlag {
// version negotiation packets have no payload // version negotiation packets have no payload
return c.handlePacketWithVersionFlag(hdr) if err := c.handlePacketWithVersionFlag(hdr); err != nil {
c.session.Close(err)
}
return
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
...@@ -207,7 +273,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { ...@@ -207,7 +273,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
data: packet[len(packet)-r.Len():], data: packet[len(packet)-r.Len():],
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil
} }
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
...@@ -246,6 +311,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e ...@@ -246,6 +311,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.connectionID,
c.tlsConf,
c.config, c.config,
negotiatedVersions, negotiatedVersions,
) )
......
...@@ -4,8 +4,8 @@ import ( ...@@ -4,8 +4,8 @@ import (
"math" "math"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// This cubic implementation is based on the one found in Chromiums's QUIC // This cubic implementation is based on the one found in Chromiums's QUIC
......
...@@ -3,8 +3,8 @@ package congestion ...@@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (
......
...@@ -3,8 +3,8 @@ package congestion ...@@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// Note(pwestin): the magic clamping numbers come from the original code in // Note(pwestin): the magic clamping numbers come from the original code in
......
...@@ -3,8 +3,8 @@ package congestion ...@@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
......
...@@ -3,7 +3,7 @@ package congestion ...@@ -3,7 +3,7 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
const ( const (
......
...@@ -102,3 +102,12 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -102,3 +102,12 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &c.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type entryType uint8 type entryType uint8
......
...@@ -107,15 +107,14 @@ func (c *certManager) Verify(hostname string) error { ...@@ -107,15 +107,14 @@ func (c *certManager) Verify(hostname string) error {
var opts x509.VerifyOptions var opts x509.VerifyOptions
if c.config != nil { if c.config != nil {
opts.Roots = c.config.RootCAs opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil { if c.config.Time == nil {
opts.CurrentTime = time.Now() opts.CurrentTime = time.Now()
} else { } else {
opts.CurrentTime = c.config.Time() opts.CurrentTime = c.config.Time()
} }
} else {
opts.DNSName = hostname
} }
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates // the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 { if len(c.chain) > 1 {
......
// +build go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}
// +build !go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
return c, nil
}
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"crypto/sha256" "crypto/sha256"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
func main() { func main() {
...@@ -24,7 +24,7 @@ func main() { ...@@ -24,7 +24,7 @@ func main() {
utils.SetLogTimeFormat("") utils.SetLogTimeFormat("")
hclient := &http.Client{ hclient := &http.Client{
Transport: &h2quic.QuicRoundTripper{}, Transport: &h2quic.RoundTripper{},
} }
var wg sync.WaitGroup var wg sync.WaitGroup
......
...@@ -31,10 +31,7 @@ func main() { ...@@ -31,10 +31,7 @@ func main() {
// Start a server that echos all data on the first stream opened by the client // Start a server that echos all data on the first stream opened by the client
func echoServer() error { func echoServer() error {
cfgServer := &quic.Config{ listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
TLSConfig: generateTLSConfig(),
}
listener, err := quic.ListenAddr(addr, cfgServer)
if err != nil { if err != nil {
return err return err
} }
...@@ -52,10 +49,7 @@ func echoServer() error { ...@@ -52,10 +49,7 @@ func echoServer() error {
} }
func clientMain() error { func clientMain() error {
cfgClient := &quic.Config{ session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
session, err := quic.DialAddr(addr, cfgClient)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type binds []string type binds []string
......
...@@ -7,9 +7,9 @@ import ( ...@@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowControlManager struct { type flowControlManager struct {
...@@ -78,7 +78,7 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset ...@@ -78,7 +78,7 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
...@@ -107,7 +107,7 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b ...@@ -107,7 +107,7 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
...@@ -157,6 +157,11 @@ func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (proto ...@@ -157,6 +157,11 @@ func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (proto
f.mutex.RLock() f.mutex.RLock()
defer f.mutex.RUnlock() defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID) flowController, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return 0, err return 0, err
......
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowController struct { type flowController struct {
......
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"errors" "errors"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
......
...@@ -3,8 +3,8 @@ package frames ...@@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A BlockedFrame in QUIC // A BlockedFrame in QUIC
......
...@@ -6,9 +6,9 @@ import ( ...@@ -6,9 +6,9 @@ import (
"io" "io"
"math" "math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A ConnectionCloseFrame in QUIC // A ConnectionCloseFrame in QUIC
......
...@@ -4,9 +4,9 @@ import ( ...@@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A GoawayFrame is a GOAWAY frame // A GoawayFrame is a GOAWAY frame
......
package frames package frames
import "github.com/lucas-clemente/quic-go/utils" import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received // LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) { func LogFrame(frame Frame, sent bool) {
......
...@@ -3,8 +3,8 @@ package frames ...@@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A RstStreamFrame in QUIC // A RstStreamFrame in QUIC
......
...@@ -4,9 +4,9 @@ import ( ...@@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StopWaitingFrame in QUIC // A StopWaitingFrame in QUIC
......
...@@ -4,9 +4,9 @@ import ( ...@@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StreamFrame of QUIC // A StreamFrame of QUIC
......
...@@ -3,8 +3,8 @@ package frames ...@@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A WindowUpdateFrame in QUIC // A WindowUpdateFrame in QUIC
......
...@@ -15,59 +15,72 @@ import ( ...@@ -15,59 +15,72 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// Client is a HTTP2 client doing QUIC requests type roundTripperOpts struct {
type Client struct { DisableCompression bool
mutex sync.RWMutex }
var dialAddr = quic.DialAddr
dialAddr func(hostname string, config *quic.Config) (quic.Session, error) // client is a HTTP2 client doing QUIC requests
config *quic.Config type client struct {
mutex sync.RWMutex
t *QuicRoundTripper tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
handshakeErr error handshakeErr error
dialChan chan struct{} // will be closed once the handshake is complete and the header stream has been opened dialOnce sync.Once
session quic.Session session quic.Session
headerStream quic.Stream headerStream quic.Stream
headerErr *qerr.QuicError headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
} }
var _ h2quicClient = &Client{} var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
}
// NewClient creates a new client // newClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { func newClient(
return &Client{ hostname string,
t: t, tlsConfig *tls.Config,
dialAddr: quic.DialAddr, opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted, encryptionLevel: protocol.EncryptionUnencrypted,
config: &quic.Config{ tlsConf: tlsConfig,
TLSConfig: tlsConfig, config: config,
RequestConnectionIDTruncation: true, opts: opts,
}, headerErrored: make(chan struct{}),
dialChan: make(chan struct{}),
} }
} }
// Dial dials the connection // dial dials the connection
func (c *Client) Dial() (err error) { func (c *client) dial() error {
defer func() { var err error
c.handshakeErr = err c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
close(c.dialChan)
}()
c.session, err = c.dialAddr(c.hostname, c.config)
if err != nil { if err != nil {
return err return err
} }
...@@ -82,10 +95,10 @@ func (c *Client) Dial() (err error) { ...@@ -82,10 +95,10 @@ func (c *Client) Dial() (err error) {
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
return return nil
} }
func (c *Client) handleHeaderStream() { func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream) h2framer := http2.NewFramer(nil, c.headerStream)
...@@ -111,7 +124,7 @@ func (c *Client) handleHeaderStream() { ...@@ -111,7 +124,7 @@ func (c *Client) handleHeaderStream() {
} }
c.mutex.RLock() c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock() c.mutex.RUnlock()
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
...@@ -122,41 +135,38 @@ func (c *Client) handleHeaderStream() { ...@@ -122,41 +135,38 @@ func (c *Client) handleHeaderStream() {
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) c.headerErr = qerr.Error(qerr.InternalError, err.Error())
} }
headerChan <- rsp responseChan <- rsp
} }
// stop all running request // stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock() close(c.headerErrored)
for _, responseChan := range c.responses {
close(responseChan)
}
c.mutex.Unlock()
} }
// Do executes a request and returns a response // Roundtrip executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one // TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" { if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme") return nil, errors.New("quic http2: unsupported scheme")
} }
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname) return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
} }
hasBody := (req.Body != nil) c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
// wait until the handshake is complete
<-c.dialChan
if c.handshakeErr != nil { if c.handshakeErr != nil {
return nil, c.handshakeErr return nil, c.handshakeErr
} }
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response) responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock() c.mutex.Lock()
...@@ -164,14 +174,14 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { ...@@ -164,14 +174,14 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Unlock() c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true requestedGzip = true
} }
// TODO: add support for trailers // TODO: add support for trailers
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
...@@ -198,15 +208,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { ...@@ -198,15 +208,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Lock() c.mutex.Lock()
delete(c.responses, dataStream.StreamID()) delete(c.responses, dataStream.StreamID())
c.mutex.Unlock() c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc: case err := <-resc:
bodySent = true bodySent = true
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-c.headerErrored:
// an error occured on the header stream
_ = c.CloseWithError(c.headerErr)
return nil, c.headerErr
} }
} }
...@@ -230,11 +240,10 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { ...@@ -230,11 +240,10 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
} }
res.Request = req res.Request = req
return res, nil return res, nil
} }
func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() { defer func() {
cerr := body.Close() cerr := body.Close()
if err == nil { if err == nil {
...@@ -252,8 +261,15 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e ...@@ -252,8 +261,15 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
} }
// Close closes the client // Close closes the client
func (c *Client) Close(e error) { func (c *client) CloseWithError(e error) error {
_ = c.session.Close(e) if c.session == nil {
return nil
}
return c.session.Close(e)
}
func (c *client) Close() error {
return c.CloseWithError(nil)
} }
// copied from net/transport.go // copied from net/transport.go
......
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type requestWriter struct { type requestWriter struct {
......
...@@ -8,8 +8,8 @@ import ( ...@@ -8,8 +8,8 @@ import (
"sync" "sync"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
......
...@@ -4,20 +4,23 @@ import ( ...@@ -4,20 +4,23 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
) )
type h2quicClient interface { type roundTripCloser interface {
Dial() error http.RoundTripper
Do(*http.Request) (*http.Response, error) io.Closer
} }
// QuicRoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from // DisableCompression, if true, prevents the Transport from
...@@ -34,13 +37,29 @@ type QuicRoundTripper struct { ...@@ -34,13 +37,29 @@ type QuicRoundTripper struct {
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
clients map[string]h2quicClient // QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
} }
var _ http.RoundTripper = &QuicRoundTripper{} var _ roundTripCloser = &RoundTripper{}
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTrip does a round trip // RoundTripOpt is like RoundTrip, but takes options.
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil { if req.URL == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL") return nil, errors.New("quic: nil Request.URL")
...@@ -76,35 +95,48 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) ...@@ -76,35 +95,48 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname) cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.Do(req) return cl.RoundTrip(req)
} }
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { // RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]h2quicClient) r.clients = make(map[string]roundTripCloser)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
client = NewClient(r, r.TLSClientConfig, hostname) if onlyCached {
err := client.Dial() return nil, ErrNoCachedConn
if err != nil {
return nil, err
} }
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client r.clients[hostname] = client
} }
return client, nil return client, nil
} }
func (r *QuicRoundTripper) disableCompression() bool { // Close closes the QUIC connections that this RoundTripper has used
return r.DisableCompression func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
} }
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {
......
...@@ -13,9 +13,9 @@ import ( ...@@ -13,9 +13,9 @@ import (
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
...@@ -29,10 +29,20 @@ type remoteCloser interface { ...@@ -29,10 +29,20 @@ type remoteCloser interface {
CloseRemote(protocol.ByteCount) CloseRemote(protocol.ByteCount)
} }
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections. // Server is a HTTP2 server listening for QUIC connections.
type Server struct { type Server struct {
*http.Server *http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use // Private flag for demo, do not use
CloseAfterFirstRequest bool CloseAfterFirstRequest bool
...@@ -69,11 +79,11 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { ...@@ -69,11 +79,11 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
} }
// Serve an existing UDP connection. // Serve an existing UDP connection.
func (s *Server) Serve(conn *net.UDPConn) error { func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn) return s.serveImpl(s.TLSConfig, conn)
} }
func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
...@@ -83,17 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { ...@@ -83,17 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
config := quic.Config{
TLSConfig: tlsConfig,
Versions: protocol.SupportedVersions,
}
var ln quic.Listener var ln quic.Listener
var err error var err error
if conn == nil { if conn == nil {
ln, err = quic.ListenAddr(s.Addr, &config) ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else { } else {
ln, err = quic.Listen(conn, &config) ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
} }
if err != nil { if err != nil {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
...@@ -255,7 +260,6 @@ func (s *Server) CloseGracefully(timeout time.Duration) error { ...@@ -255,7 +260,6 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error { func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port) port := atomic.LoadUint32(&s.port)
...@@ -283,7 +287,6 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error { ...@@ -283,7 +287,6 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
} }
} }
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil return nil
......
...@@ -5,9 +5,9 @@ import ( ...@@ -5,9 +5,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// ConnectionParametersManager negotiates and stores the connection parameters // ConnectionParametersManager negotiates and stores the connection parameters
...@@ -50,6 +50,8 @@ type connectionParametersManager struct { ...@@ -50,6 +50,8 @@ type connectionParametersManager struct {
sendConnectionFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
} }
var _ ConnectionParametersManager = &connectionParametersManager{} var _ ConnectionParametersManager = &connectionParametersManager{}
...@@ -61,14 +63,19 @@ var ( ...@@ -61,14 +63,19 @@ var (
) )
// NewConnectionParamatersManager creates a new connection parameters manager // NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber) ConnectionParametersManager { func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{ h := &connectionParametersManager{
perspective: pers, perspective: pers,
version: v, version: v,
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
if h.perspective == protocol.PerspectiveServer { if h.perspective == protocol.PerspectiveServer {
...@@ -207,10 +214,7 @@ func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protoc ...@@ -207,10 +214,7 @@ func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protoc
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveStreamFlowControlWindow
return protocol.MaxReceiveStreamFlowControlWindowServer
}
return protocol.MaxReceiveStreamFlowControlWindowClient
} }
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
...@@ -222,10 +226,7 @@ func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() pr ...@@ -222,10 +226,7 @@ func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() pr
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveConnectionFlowControlWindow
return protocol.MaxReceiveConnectionFlowControlWindowServer
}
return protocol.MaxReceiveConnectionFlowControlWindowClient
} }
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection // GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
......
...@@ -12,9 +12,9 @@ import ( ...@@ -12,9 +12,9 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type cryptoSetupClient struct { type cryptoSetupClient struct {
...@@ -332,7 +332,6 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu ...@@ -332,7 +332,6 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { } else if h.secureAEAD != nil {
...@@ -342,6 +341,10 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { ...@@ -342,6 +341,10 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
} }
} }
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
......
...@@ -10,9 +10,9 @@ import ( ...@@ -10,9 +10,9 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
...@@ -214,12 +214,16 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu ...@@ -214,12 +214,16 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
if h.forwardSecureAEAD != nil && h.sentSHLO {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { }
// secureAEAD and forwardSecureAEAD are created at the same time (when receiving the CHLO) return protocol.EncryptionUnencrypted, h.sealUnencrypted
// make sure that the SHLO isn't sent forward-secure }
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure return protocol.EncryptionSecure, h.sealSecure
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.sealUnencrypted
...@@ -251,7 +255,6 @@ func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protoc ...@@ -251,7 +255,6 @@ func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protoc
} }
func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
h.sentSHLO = true
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
} }
......
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
......
...@@ -7,9 +7,9 @@ import ( ...@@ -7,9 +7,9 @@ import (
"io" "io"
"sort" "sort"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A HandshakeMessage is a handshake message // A HandshakeMessage is a handshake message
......
...@@ -15,6 +15,7 @@ type CryptoSetup interface { ...@@ -15,6 +15,7 @@ type CryptoSetup interface {
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// TransportParameters are parameters sent to the peer during the handshake // TransportParameters are parameters sent to the peer during the handshake
......
...@@ -8,8 +8,8 @@ import ( ...@@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type serverConfigClient struct { type serverConfigClient struct {
......
package handshaketests
import (
"crypto/tls"
"fmt"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/proxy"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake integration tets", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
for {
_, _ = server.Accept()
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
expectedDuration := time.Duration(num) * rtt
Expect(testDuration).To(SatisfyAll(
BeNumerically(">=", expectedDuration),
BeNumerically("<", expectedDuration+rtt),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
Expect(len(protocol.SupportedVersions)).To(BeNumerically(">", 1))
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require an STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})
package quic package quic
import ( import (
"crypto/tls"
"io" "io"
"net" "net"
"time" "time"
...@@ -11,12 +10,32 @@ import ( ...@@ -11,12 +10,32 @@ import (
// Stream is the interface implemented by QUIC streams // Stream is the interface implemented by QUIC streams
type Stream interface { type Stream interface {
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
io.Reader io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
io.Writer io.Writer
io.Closer io.Closer
StreamID() protocol.StreamID StreamID() protocol.StreamID
// Reset closes the stream with an error. // Reset closes the stream with an error.
Reset(error) Reset(error)
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
} }
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
...@@ -37,6 +56,9 @@ type Session interface { ...@@ -37,6 +56,9 @@ type Session interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
} }
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. // A NonFWSession is a QUIC connection between two peers half-way through the handshake.
...@@ -61,21 +83,31 @@ type STK struct { ...@@ -61,21 +83,31 @@ type STK struct {
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct { type Config struct {
TLSConfig *tls.Config
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.
// If not set, it uses all versions available. // If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon. // Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header. // Ask the server to truncate the connection ID sent in the Public Header.
// If not set, the default checks if
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client. // Currently only valid for the client.
RequestConnectionIDTruncation bool RequestConnectionIDTruncation bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted. // AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK. // It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours // If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// This option is only valid for the server. // This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool AcceptSTK func(clientAddr net.Addr, stk *STK) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
} }
// A Listener for incoming QUIC connections // A Listener for incoming QUIC connections
......
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/handshake (interfaces: ConnectionParametersManager)
package mocks
import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/handshake"
protocol "github.com/lucas-clemente/quic-go/protocol"
time "time"
)
// Mock of ConnectionParametersManager interface
type MockConnectionParametersManager struct {
ctrl *gomock.Controller
recorder *_MockConnectionParametersManagerRecorder
}
// Recorder for MockConnectionParametersManager (not exported)
type _MockConnectionParametersManagerRecorder struct {
mock *MockConnectionParametersManager
}
func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager {
mock := &MockConnectionParametersManager{ctrl: ctrl}
mock.recorder = &_MockConnectionParametersManagerRecorder{mock}
return mock
}
func (_m *MockConnectionParametersManager) EXPECT() *_MockConnectionParametersManagerRecorder {
return _m.recorder
}
func (_m *MockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) {
ret := _m.ctrl.Call(_m, "GetHelloMap")
ret0, _ := ret[0].(map[handshake.Tag][]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockConnectionParametersManagerRecorder) GetHelloMap() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetHelloMap")
}
func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetIdleConnectionStateLifetime() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetIdleConnectionStateLifetime")
}
func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxIncomingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxIncomingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxOutgoingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxOutgoingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) SetFromMap(_param0 map[handshake.Tag][]byte) error {
ret := _m.ctrl.Call(_m, "SetFromMap", _param0)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) SetFromMap(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SetFromMap", arg0)
}
func (_m *MockConnectionParametersManager) TruncateConnectionID() bool {
ret := _m.ctrl.Call(_m, "TruncateConnectionID")
ret0, _ := ret[0].(bool)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) TruncateConnectionID() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "TruncateConnectionID")
}
package mocks
//go:generate mockgen -destination mocks_fc/flow_control_manager.go -package mocks_fc github.com/lucas-clemente/quic-go/flowcontrol FlowControlManager
//go:generate mockgen -destination cpm.go -package mocks github.com/lucas-clemente/quic-go/handshake ConnectionParametersManager
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/flowcontrol (interfaces: FlowControlManager)
package mocks_fc
import (
gomock "github.com/golang/mock/gomock"
flowcontrol "github.com/lucas-clemente/quic-go/flowcontrol"
protocol "github.com/lucas-clemente/quic-go/protocol"
)
// Mock of FlowControlManager interface
type MockFlowControlManager struct {
ctrl *gomock.Controller
recorder *_MockFlowControlManagerRecorder
}
// Recorder for MockFlowControlManager (not exported)
type _MockFlowControlManagerRecorder struct {
mock *MockFlowControlManager
}
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
mock := &MockFlowControlManager{ctrl: ctrl}
mock.recorder = &_MockFlowControlManagerRecorder{mock}
return mock
}
func (_m *MockFlowControlManager) EXPECT() *_MockFlowControlManagerRecorder {
return _m.recorder
}
func (_m *MockFlowControlManager) AddBytesRead(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesRead", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesRead", arg0, arg1)
}
func (_m *MockFlowControlManager) AddBytesSent(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesSent", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesSent", arg0, arg1)
}
func (_m *MockFlowControlManager) GetReceiveWindow(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "GetReceiveWindow", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveWindow", arg0)
}
func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
ret := _m.ctrl.Call(_m, "GetWindowUpdates")
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) GetWindowUpdates() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetWindowUpdates")
}
func (_m *MockFlowControlManager) NewStream(_param0 protocol.StreamID, _param1 bool) {
_m.ctrl.Call(_m, "NewStream", _param0, _param1)
}
func (_mr *_MockFlowControlManagerRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "NewStream", arg0, arg1)
}
func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) RemainingConnectionWindowSize() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemainingConnectionWindowSize")
}
func (_m *MockFlowControlManager) RemoveStream(_param0 protocol.StreamID) {
_m.ctrl.Call(_m, "RemoveStream", _param0)
}
func (_mr *_MockFlowControlManagerRecorder) RemoveStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemoveStream", arg0)
}
func (_m *MockFlowControlManager) ResetStream(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "ResetStream", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "ResetStream", arg0, arg1)
}
func (_m *MockFlowControlManager) SendWindowSize(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "SendWindowSize", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) SendWindowSize(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SendWindowSize", arg0)
}
func (_m *MockFlowControlManager) UpdateHighestReceived(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateHighestReceived", arg0, arg1)
}
func (_m *MockFlowControlManager) UpdateWindow(_param0 protocol.StreamID, _param1 protocol.ByteCount) (bool, error) {
ret := _m.ctrl.Call(_m, "UpdateWindow", _param0, _param1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateWindow", arg0, arg1)
}
package utils package utils
import ( import (
"fmt"
"log" "log"
"os" "os"
"strconv"
"time" "time"
) )
// LogLevel of quic-go // LogLevel of quic-go
type LogLevel uint8 type LogLevel uint8
const ( const logEnv = "QUIC_GO_LOG_LEVEL"
logEnv = "QUIC_GO_LOG_LEVEL"
// LogLevelDebug enables debug logs (e.g. packet contents) const (
LogLevelDebug LogLevel = iota // LogLevelNothing disables
// LogLevelInfo enables info logs (e.g. packets) LogLevelNothing LogLevel = iota
LogLevelInfo
// LogLevelError enables err logs // LogLevelError enables err logs
LogLevelError LogLevelError
// LogLevelNothing disables // LogLevelInfo enables info logs (e.g. packets)
LogLevelNothing LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
) )
var ( var (
...@@ -49,14 +49,14 @@ func Debugf(format string, args ...interface{}) { ...@@ -49,14 +49,14 @@ func Debugf(format string, args ...interface{}) {
// Infof logs something // Infof logs something
func Infof(format string, args ...interface{}) { func Infof(format string, args ...interface{}) {
if logLevel <= LogLevelInfo { if logLevel >= LogLevelInfo {
logMessage(format, args...) logMessage(format, args...)
} }
} }
// Errorf logs something // Errorf logs something
func Errorf(format string, args ...interface{}) { func Errorf(format string, args ...interface{}) {
if logLevel <= LogLevelError { if logLevel >= LogLevelError {
logMessage(format, args...) logMessage(format, args...)
} }
} }
...@@ -79,13 +79,16 @@ func init() { ...@@ -79,13 +79,16 @@ func init() {
} }
func readLoggingEnv() { func readLoggingEnv() {
env := os.Getenv(logEnv) switch os.Getenv(logEnv) {
if env == "" { case "":
return
}
level, err := strconv.Atoi(env)
if err != nil {
return return
case "DEBUG":
logLevel = LogLevelDebug
case "INFO":
logLevel = LogLevelInfo
case "ERROR":
logLevel = LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
} }
logLevel = LogLevel(level)
} }
package utils
import "time"
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(deadline.Sub(time.Now()))
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
type quicAEAD interface { type quicAEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
} }
......
...@@ -31,7 +31,7 @@ type StreamID uint32 ...@@ -31,7 +31,7 @@ type StreamID uint32
type ByteCount uint64 type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount // MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = math.MaxUint64 const MaxByteCount = ByteCount(math.MaxUint64)
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on // MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
......
...@@ -39,21 +39,21 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB ...@@ -39,21 +39,21 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using // This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using // This is the value that Chromium is using
...@@ -128,8 +128,8 @@ const MaxIdleTimeoutServer = 1 * time.Minute ...@@ -128,8 +128,8 @@ const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server // MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed // ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted // after this time all information about the old connection will be deleted
......
...@@ -4,9 +4,9 @@ import ( ...@@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment