Commit a674c005 authored by Matthew Holt's avatar Matthew Holt

vendor: Update quic and lego/acme dependencies

parent 98de336a
...@@ -240,12 +240,14 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -240,12 +240,14 @@ func (c *ACMEClient) Obtain(name string) error {
for attempts := 0; attempts < 2; attempts++ { for attempts := 0; attempts < 2; attempts++ {
namesObtaining.Add([]string{name}) namesObtaining.Add([]string{name})
acmeMu.Lock() acmeMu.Lock()
certificate, failures := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple) certificate, err := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
acmeMu.Unlock() acmeMu.Unlock()
namesObtaining.Remove([]string{name}) namesObtaining.Remove([]string{name})
if len(failures) > 0 { if err != nil {
// Error - try to fix it or report it to the user and abort // Error - try to fix it or report it to the user and abort
if failures, ok := err.(acme.ObtainError); ok && len(failures) > 0 {
// in this case, we can enumerate the error per-domain
var errMsg string // combine all the failures into a single error message var errMsg string // combine all the failures into a single error message
for errDomain, obtainErr := range failures { for errDomain, obtainErr := range failures {
if obtainErr == nil { if obtainErr == nil {
...@@ -253,13 +255,15 @@ func (c *ACMEClient) Obtain(name string) error { ...@@ -253,13 +255,15 @@ func (c *ACMEClient) Obtain(name string) error {
} }
errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr) errMsg += fmt.Sprintf("[%s] failed to get certificate: %v\n", errDomain, obtainErr)
} }
return errors.New(errMsg) return errors.New(errMsg)
} }
return fmt.Errorf("[%s] failed to obtain certificate: %v", name, err)
}
// double-check that we actually got a certificate; check a couple fields // double-check that we actually got a certificate; check a couple fields
// TODO: This is a temporary workaround for what I think is a bug in the acmev2 package (March 2018) // TODO: This is a temporary workaround for what I think is a bug in the acmev2 package (March 2018)
// but it might not hurt to keep this extra check in place // but it might not hurt to keep this extra check in place (April 18, 2018: might be fixed now.)
if certificate.Domain == "" || certificate.Certificate == nil { if certificate.Domain == "" || certificate.Certificate == nil {
return errors.New("returned certificate was empty; probably an unchecked error obtaining it") return errors.New("returned certificate was empty; probably an unchecked error obtaining it")
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build !amd64,!s390x // +build !amd64
package aes12 package aes12
......
...@@ -38,6 +38,8 @@ type client struct { ...@@ -38,6 +38,8 @@ type client struct {
version protocol.VersionNumber version protocol.VersionNumber
session packetHandler session packetHandler
logger utils.Logger
} }
var ( var (
...@@ -102,9 +104,10 @@ func Dial( ...@@ -102,9 +104,10 @@ func Dial(
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}), versionNegotiationChan: make(chan struct{}),
logger: utils.DefaultLogger,
} }
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil { if err := c.dial(); err != nil {
return nil, err return nil, err
...@@ -197,7 +200,7 @@ func (c *client) dialTLS() error { ...@@ -197,7 +200,7 @@ func (c *client) dialTLS() error {
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
} }
csc := handshake.NewCryptoStreamConn(nil) csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version) extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil { if err != nil {
return err return err
...@@ -214,7 +217,7 @@ func (c *client) dialTLS() error { ...@@ -214,7 +217,7 @@ func (c *client) dialTLS() error {
if err != handshake.ErrCloseSessionForRetry { if err != handshake.ErrCloseSessionForRetry {
return err return err
} }
utils.Infof("Received a Retry packet. Recreating session.") c.logger.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err return err
} }
...@@ -237,7 +240,7 @@ func (c *client) establishSecureConnection() error { ...@@ -237,7 +240,7 @@ func (c *client) establishSecureConnection() error {
go func() { go func() {
runErr = c.session.run() // returns as soon as the session is closed runErr = c.session.run() // returns as soon as the session is closed
close(errorChan) close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID) c.logger.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion { if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close() c.conn.Close()
} }
...@@ -291,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { ...@@ -291,7 +294,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r, c.version) hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil { if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) c.logger.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the header // drop this packet if we can't parse the header
return return
} }
...@@ -314,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { ...@@ -314,15 +317,15 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
// check if the remote address and the connection ID match // check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.") c.logger.Infof("Received a spoofed Public Reset. Ignoring.")
return return
} }
pr, err := wire.ParsePublicReset(r) pr, err := wire.ParsePublicReset(r)
if err != nil { if err != nil {
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err) c.logger.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return return
} }
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber) c.logger.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber))) c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return return
} }
...@@ -368,7 +371,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { ...@@ -368,7 +371,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
} }
} }
utils.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
...@@ -385,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { ...@@ -385,7 +388,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
if err != nil { if err != nil {
return err return err
} }
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) c.logger.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion) c.session.Close(errCloseSessionForNewVersion)
return nil return nil
} }
...@@ -402,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) { ...@@ -402,6 +405,7 @@ func (c *client) createNewGQUICSession() (err error) {
c.config, c.config,
c.initialVersion, c.initialVersion,
c.negotiatedVersions, c.negotiatedVersions,
c.logger,
) )
return err return err
} }
...@@ -421,6 +425,7 @@ func (c *client) createNewTLSSession( ...@@ -421,6 +425,7 @@ func (c *client) createNewTLSSession(
c.tls, c.tls,
paramsChan, paramsChan,
1, 1,
c.logger,
) )
return err return err
} }
...@@ -19,12 +19,14 @@ func main() { ...@@ -19,12 +19,14 @@ func main() {
flag.Parse() flag.Parse()
urls := flag.Args() urls := flag.Args()
logger := utils.DefaultLogger
if *verbose { if *verbose {
utils.SetLogLevel(utils.LogLevelDebug) logger.SetLogLevel(utils.LogLevelDebug)
} else { } else {
utils.SetLogLevel(utils.LogLevelInfo) logger.SetLogLevel(utils.LogLevelInfo)
} }
utils.SetLogTimeFormat("") logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions versions := protocol.SupportedVersions
if *tls { if *tls {
...@@ -42,21 +44,21 @@ func main() { ...@@ -42,21 +44,21 @@ func main() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(urls)) wg.Add(len(urls))
for _, addr := range urls { for _, addr := range urls {
utils.Infof("GET %s", addr) logger.Infof("GET %s", addr)
go func(addr string) { go func(addr string) {
rsp, err := hclient.Get(addr) rsp, err := hclient.Get(addr)
if err != nil { if err != nil {
panic(err) panic(err)
} }
utils.Infof("Got response for %s: %#v", addr, rsp) logger.Infof("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{} body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body) _, err = io.Copy(body, rsp.Body)
if err != nil { if err != nil {
panic(err) panic(err)
} }
utils.Infof("Request Body:") logger.Infof("Request Body:")
utils.Infof("%s", body.Bytes()) logger.Infof("%s", body.Bytes())
wg.Done() wg.Done()
}(addr) }(addr)
} }
......
...@@ -91,7 +91,7 @@ func init() { ...@@ -91,7 +91,7 @@ func init() {
} }
} }
if err != nil { if err != nil {
utils.Infof("Error receiving upload: %#v", err) utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
} }
} }
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data"> io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
...@@ -126,12 +126,14 @@ func main() { ...@@ -126,12 +126,14 @@ func main() {
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)") tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse() flag.Parse()
logger := utils.DefaultLogger
if *verbose { if *verbose {
utils.SetLogLevel(utils.LogLevelDebug) logger.SetLogLevel(utils.LogLevelDebug)
} else { } else {
utils.SetLogLevel(utils.LogLevelInfo) logger.SetLogLevel(utils.LogLevelInfo)
} }
utils.SetLogTimeFormat("") logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions versions := protocol.SupportedVersions
if *tls { if *tls {
......
...@@ -46,6 +46,8 @@ type client struct { ...@@ -46,6 +46,8 @@ type client struct {
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
} }
var _ http.RoundTripper = &client{} var _ http.RoundTripper = &client{}
...@@ -75,6 +77,7 @@ func newClient( ...@@ -75,6 +77,7 @@ func newClient(
opts: opts, opts: opts,
headerErrored: make(chan struct{}), headerErrored: make(chan struct{}),
dialer: dialer, dialer: dialer,
logger: utils.DefaultLogger,
} }
} }
...@@ -95,7 +98,7 @@ func (c *client) dial() error { ...@@ -95,7 +98,7 @@ func (c *client) dial() error {
if err != nil { if err != nil {
return err return err
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream() go c.handleHeaderStream()
return nil return nil
} }
...@@ -109,7 +112,7 @@ func (c *client) handleHeaderStream() { ...@@ -109,7 +112,7 @@ func (c *client) handleHeaderStream() {
err = c.readResponse(h2framer, decoder) err = c.readResponse(h2framer, decoder)
} }
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway { if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
utils.Debugf("Error handling header stream: %s", err) c.logger.Debugf("Error handling header stream: %s", err)
} }
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request // stop all running request
......
...@@ -23,13 +23,16 @@ type requestWriter struct { ...@@ -23,13 +23,16 @@ type requestWriter struct {
henc *hpack.Encoder henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
} }
const defaultUserAgent = "quic-go" const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream) *requestWriter { func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{ rw := &requestWriter{
headerStream: headerStream, headerStream: headerStream,
logger: logger,
} }
rw.henc = hpack.NewEncoder(&rw.hbuf) rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw return rw
...@@ -156,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra ...@@ -156,7 +159,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra
} }
func (w *requestWriter) writeHeader(name, value string) { func (w *requestWriter) writeHeader(name, value string) {
utils.Debugf("http2: Transport encoding header %q = %q", name, value) w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
} }
......
...@@ -24,15 +24,24 @@ type responseWriter struct { ...@@ -24,15 +24,24 @@ type responseWriter struct {
header http.Header header http.Header
status int // status code passed to WriteHeader status int // status code passed to WriteHeader
headerWritten bool headerWritten bool
logger utils.Logger
} }
func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID protocol.StreamID) *responseWriter { func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{ return &responseWriter{
header: http.Header{}, header: http.Header{},
headerStream: headerStream, headerStream: headerStream,
headerStreamMutex: headerStreamMutex, headerStreamMutex: headerStreamMutex,
dataStream: dataStream, dataStream: dataStream,
dataStreamID: dataStreamID, dataStreamID: dataStreamID,
logger: logger,
} }
} }
...@@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) { ...@@ -57,7 +66,7 @@ func (w *responseWriter) WriteHeader(status int) {
} }
} }
utils.Infof("Responding with %d", status) w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock() w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock() defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil) h2framer := http2.NewFramer(w.headerStream, nil)
...@@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) { ...@@ -67,7 +76,7 @@ func (w *responseWriter) WriteHeader(status int) {
BlockFragment: headers.Bytes(), BlockFragment: headers.Bytes(),
}) })
if err != nil { if err != nil {
utils.Errorf("could not write h2 header: %s", err.Error()) w.logger.Errorf("could not write h2 header: %s", err.Error())
} }
} }
......
...@@ -53,6 +53,8 @@ type Server struct { ...@@ -53,6 +53,8 @@ type Server struct {
closed bool closed bool
supportedVersionsAsString string supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
} }
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections. // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
...@@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error { ...@@ -88,6 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
s.logger = utils.DefaultLogger
s.listenerMutex.Lock() s.listenerMutex.Lock()
if s.closed { if s.closed {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
...@@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) { ...@@ -138,7 +141,7 @@ func (s *Server) handleHeaderStream(session streamCreator) {
// In this case, the session has already logged the error, so we don't // In this case, the session has already logged the error, so we don't
// need to log it again. // need to log it again.
if _, ok := err.(*qerr.QuicError); !ok { if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error()) s.logger.Errorf("error handling h2 request: %s", err.Error())
} }
session.Close(err) session.Close(err)
return return
...@@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -160,7 +163,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
} }
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment()) headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil { if err != nil {
utils.Errorf("invalid http2 headers encoding: %s", err.Error()) s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err return err
} }
...@@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -169,10 +172,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return err return err
} }
if utils.Debug() { if s.logger.Debug() {
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID) s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else { } else {
utils.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
} }
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID)) dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
...@@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -201,7 +204,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
req.RemoteAddr = session.RemoteAddr().String() req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler handler := s.Handler
if handler == nil { if handler == nil {
...@@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, ...@@ -215,7 +218,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
const size = 64 << 10 const size = 64 << 10
buf := make([]byte, size) buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)] buf = buf[:runtime.Stack(buf, false)]
utils.Errorf("http: panic serving: %v\n%s", p, buf) s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true panicked = true
} }
}() }()
......
...@@ -95,6 +95,8 @@ type QuicProxy struct { ...@@ -95,6 +95,8 @@ type QuicProxy struct {
// Mapping from client addresses (as host:port) to connection // Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection clientDict map[string]*connection
logger utils.Logger
} }
// NewQuicProxy creates a new UDP proxy // NewQuicProxy creates a new UDP proxy
...@@ -132,9 +134,10 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu ...@@ -132,9 +134,10 @@ func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*Qu
dropPacket: packetDropper, dropPacket: packetDropper,
delayPacket: packetDelayer, delayPacket: packetDelayer,
version: version, version: version,
logger: utils.DefaultLogger,
} }
utils.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy() go p.runProxy()
return &p, nil return &p, nil
} }
...@@ -200,8 +203,8 @@ func (p *QuicProxy) runProxy() error { ...@@ -200,8 +203,8 @@ func (p *QuicProxy) runProxy() error {
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) { if p.dropPacket(DirectionIncoming, packetCount) {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n) p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
} }
continue continue
} }
...@@ -209,16 +212,16 @@ func (p *QuicProxy) runProxy() error { ...@@ -209,16 +212,16 @@ func (p *QuicProxy) runProxy() error {
// Send the packet to the server // Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount) delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 { if delay != 0 {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay) p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
} }
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = conn.ServerConn.Write(raw) _, _ = conn.ServerConn.Write(raw)
}) })
} else { } else {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr()) p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
} }
if _, err := conn.ServerConn.Write(raw); err != nil { if _, err := conn.ServerConn.Write(raw); err != nil {
return err return err
...@@ -240,24 +243,24 @@ func (p *QuicProxy) runConnection(conn *connection) error { ...@@ -240,24 +243,24 @@ func (p *QuicProxy) runConnection(conn *connection) error {
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1) packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) { if p.dropPacket(DirectionOutgoing, packetCount) {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n) p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
} }
continue continue
} }
delay := p.delayPacket(DirectionOutgoing, packetCount) delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 { if delay != 0 {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay) p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
} }
time.AfterFunc(delay, func() { time.AfterFunc(delay, func() {
// TODO: handle error // TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr) _, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
}) })
} else { } else {
if utils.Debug() { if p.logger.Debug() {
utils.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr) p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
} }
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err return err
......
...@@ -30,7 +30,7 @@ var _ = BeforeEach(func() { ...@@ -30,7 +30,7 @@ var _ = BeforeEach(func() {
logFile, err = os.Create(logFileName) logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile) log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug) utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
} }
}) })
......
...@@ -16,6 +16,9 @@ type StreamID = protocol.StreamID ...@@ -16,6 +16,9 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number. // A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber type VersionNumber = protocol.VersionNumber
// VersionGQUIC39 is gQUIC version 39.
const VersionGQUIC39 = protocol.Version39
// A Cookie can be used to verify the ownership of the client address. // A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie type Cookie = handshake.Cookie
...@@ -118,6 +121,8 @@ type Session interface { ...@@ -118,6 +121,8 @@ type Session interface {
AcceptUniStream() (ReceiveStream, error) AcceptUniStream() (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream. // OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached. // It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error // TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error) OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream. // OpenStreamSync opens a new bidirectional QUIC stream.
......
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
...@@ -8,35 +8,22 @@ import ( ...@@ -8,35 +8,22 @@ import (
) )
// A Packet is a packet // A Packet is a packet
// +gen linkedlist
type Packet struct { type Packet struct {
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
PacketType protocol.PacketType PacketType protocol.PacketType
Frames []wire.Frame Frames []wire.Frame
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
queuedForRetransmission bool // There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber retransmissionOf protocol.PacketNumber
} }
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
var fs []wire.Frame
for _, frame := range p.Frames {
switch frame.(type) {
case *wire.AckFrame:
continue
case *wire.StopWaitingFrame:
continue
}
fs = append(fs, frame)
}
return fs
}
// Generated by: main // This file was automatically generated by genny.
// TypeWriter: linkedlist // Any changes will be lost if this file is regenerated.
// Directive: +gen on Packet // see https://github.com/cheekybits/genny
package ackhandler package ackhandler
// List is a modification of http://golang.org/pkg/container/list/ // Linked list implementation from the Go standard library.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// PacketElement is an element of a linked list. // PacketElement is an element of a linked list.
type PacketElement struct { type PacketElement struct {
...@@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement { ...@@ -41,8 +38,7 @@ func (e *PacketElement) Prev() *PacketElement {
return nil return nil
} }
// PacketList represents a doubly linked list. // PacketList is a linked list of Packets.
// The zero value for PacketList is an empty list ready to use.
type PacketList struct { type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element len int // current list length excluding (this) sentinel element
...@@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() } ...@@ -63,7 +59,7 @@ func NewPacketList() *PacketList { return new(PacketList).Init() }
// The complexity is O(1). // The complexity is O(1).
func (l *PacketList) Len() int { return l.len } func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement { func (l *PacketList) Front() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement { ...@@ -71,7 +67,7 @@ func (l *PacketList) Front() *PacketElement {
return l.root.next return l.root.next
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement { func (l *PacketList) Back() *PacketElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement { ...@@ -79,7 +75,7 @@ func (l *PacketList) Back() *PacketElement {
return l.root.prev return l.root.prev
} }
// lazyInit lazily initializes a zero PacketList value. // lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() { func (l *PacketList) lazyInit() {
if l.root.next == nil { if l.root.next == nil {
l.Init() l.Init()
...@@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement { ...@@ -98,7 +94,7 @@ func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
return e return e
} }
// insertValue is a convenience wrapper for insert(&PacketElement{Value: v}, at). // insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at) return l.insert(&PacketElement{Value: v}, at)
} }
...@@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement { ...@@ -116,10 +112,11 @@ func (l *PacketList) remove(e *PacketElement) *PacketElement {
// Remove removes e from l if e is an element of list l. // Remove removes e from l if e is an element of list l.
// It returns the element value e.Value. // It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet { func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l { if e.list == l {
// if e.list == l, l must have been initialized when e was inserted // if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketElement) and l.remove will crash // in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e) l.remove(e)
} }
return e.Value return e.Value
...@@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement { ...@@ -139,46 +136,51 @@ func (l *PacketList) PushBack(v Packet) *PacketElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e. // InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev) return l.insertValue(v, mark.prev)
} }
// InsertAfter inserts a new element e with value v immediately after mark and returns e. // InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark) return l.insertValue(v, mark)
} }
// MoveToFront moves element e to the front of list l. // MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) { func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e { if e.list != l || l.root.next == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root) l.insert(l.remove(e), &l.root)
} }
// MoveToBack moves element e to the back of list l. // MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) { func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e { if e.list != l || l.root.prev == e {
return return
} }
// see comment in PacketList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark. // MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) { func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) { ...@@ -187,7 +189,8 @@ func (l *PacketList) MoveBefore(e, mark *PacketElement) {
} }
// MoveAfter moves element e to its new position after mark. // MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) { func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) { ...@@ -196,7 +199,7 @@ func (l *PacketList) MoveAfter(e, mark *PacketElement) {
} }
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) { func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
...@@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) { ...@@ -205,7 +208,7 @@ func (l *PacketList) PushBackList(other *PacketList) {
} }
// PushFrontList inserts a copy of an other list at the front of list l. // PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) { func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
......
...@@ -12,6 +12,8 @@ const ( ...@@ -12,6 +12,8 @@ const (
SendAck SendAck
// SendRetransmission means that retransmissions should be sent // SendRetransmission means that retransmissions should be sent
SendRetransmission SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendAny packet should be sent // SendAny packet should be sent
SendAny SendAny
) )
...@@ -24,6 +26,8 @@ func (s SendMode) String() string { ...@@ -24,6 +26,8 @@ func (s SendMode) String() string {
return "ack" return "ack"
case SendRetransmission: case SendRetransmission:
return "retransmission" return "retransmission"
case SendRTO:
return "rto"
case SendAny: case SendAny:
return "any" return "any"
default: default:
......
...@@ -87,26 +87,23 @@ func (h *sentPacketHistory) FirstOutstanding() *Packet { ...@@ -87,26 +87,23 @@ func (h *sentPacketHistory) FirstOutstanding() *Packet {
// QueuePacketForRetransmission marks a packet for retransmission. // QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once. // A packet can only be queued once.
func (h *sentPacketHistory) QueuePacketForRetransmission(pn protocol.PacketNumber) (*Packet, error) { func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn] el, ok := h.packetMap[pn]
if !ok { if !ok {
return nil, fmt.Errorf("sent packet history: packet %d not found", pn) return fmt.Errorf("sent packet history: packet %d not found", pn)
} }
if el.Value.queuedForRetransmission { el.Value.canBeRetransmitted = false
return nil, fmt.Errorf("sent packet history BUG: packet %d already queued for retransmission", pn)
}
el.Value.queuedForRetransmission = true
if el == h.firstOutstanding { if el == h.firstOutstanding {
h.readjustFirstOutstanding() h.readjustFirstOutstanding()
} }
return &el.Value, nil return nil
} }
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet. // readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted. // This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() { func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next() el := h.firstOutstanding.Next()
for el != nil && el.Value.queuedForRetransmission { for el != nil && !el.Value.canBeRetransmitted {
el = el.Next() el = el.Next()
} }
h.firstOutstanding = el h.firstOutstanding = el
......
...@@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() { ...@@ -292,11 +292,3 @@ func (c *cubicSender) OnConnectionMigration() {
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) { func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled c.slowStartLargeReduction = enabled
} }
// RetransmissionDelay gives the time to retransmission
func (c *cubicSender) RetransmissionDelay() time.Duration {
if c.rttStats.SmoothedRTT() == 0 {
return 0
}
return c.rttStats.SmoothedRTT() + c.rttStats.MeanDeviation()*4
}
...@@ -17,7 +17,6 @@ type SendAlgorithm interface { ...@@ -17,7 +17,6 @@ type SendAlgorithm interface {
SetNumEmulatedConnections(n int) SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool) OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration() OnConnectionMigration()
RetransmissionDelay() time.Duration
// Experiments // Experiments
SetSlowStartLargeReduction(enabled bool) SetSlowStartLargeReduction(enabled bool)
......
...@@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) { ...@@ -84,7 +84,6 @@ func (r *RTTStats) SetRecentMinRTTwindow(recentMinRTTwindow time.Duration) {
// UpdateRTT updates the RTT based on a new sample. // UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 { if sendDelta == utils.InfDuration || sendDelta <= 0 {
utils.Debugf("Ignoring measured sendDelta, because it's is either infinite, zero, or negative: %d", sendDelta/time.Microsecond)
return return
} }
......
...@@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) { ...@@ -21,10 +21,6 @@ func NewCurve25519KEX() (KeyExchange, error) {
if _, err := rand.Read(c.secret[:]); err != nil { if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key") return nil, errors.New("Curve25519: could not create private key")
} }
// See https://cr.yp.to/ecdh.html
c.secret[0] &= 248
c.secret[31] &= 127
c.secret[31] |= 64
curve25519.ScalarBaseMult(&c.public, &c.secret) curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil return c, nil
} }
......
...@@ -20,9 +20,7 @@ type TLSExporter interface { ...@@ -20,9 +20,7 @@ type TLSExporter interface {
} }
func qhkdfExpand(secret []byte, label string, length int) []byte { func qhkdfExpand(secret []byte, label string, length int) []byte {
// The last byte should be 0x0. qlabel := make([]byte, 2+1+5+len(label))
// Since Go initializes the slice to 0, we don't need to set it explicitly.
qlabel := make([]byte, 2+1+5+len(label)+1)
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length)) binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label)) qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label)) copy(qlabel[3:], []byte("QUIC "+label))
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39} var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) { func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID) clientSecret, serverSecret := computeSecrets(connectionID)
......
package crypto package crypto
import ( import (
"encoding/binary" "bytes"
"errors" "errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/fnv128a"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
...@@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb ...@@ -21,7 +22,7 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
} }
hash := fnv128a.New() hash := fnv.New128a()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src[12:]) hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer { if n.perspective == protocol.PerspectiveServer {
...@@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb ...@@ -29,13 +30,13 @@ func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumb
} else { } else {
hash.Write([]byte("Server")) hash.Write([]byte("Server"))
} }
testHigh, testLow := hash.Sum128() sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
low := binary.LittleEndian.Uint64(src) if !bytes.Equal(sum[:12], src[:12]) {
high := binary.LittleEndian.Uint32(src[8:]) return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
if uint32(testHigh&0xffffffff) != high || testLow != low {
return nil, errors.New("NullAEAD: failed to authenticate received data")
} }
return src[12:], nil return src[12:], nil
} }
...@@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb ...@@ -48,7 +49,7 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
dst = dst[:12+len(src)] dst = dst[:12+len(src)]
} }
hash := fnv128a.New() hash := fnv.New128a()
hash.Write(associatedData) hash.Write(associatedData)
hash.Write(src) hash.Write(src)
...@@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb ...@@ -57,15 +58,22 @@ func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumb
} else { } else {
hash.Write([]byte("Client")) hash.Write([]byte("Client"))
} }
sum := make([]byte, 0, 16)
high, low := hash.Sum128() sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src) copy(dst[12:], src)
binary.LittleEndian.PutUint64(dst, low) copy(dst, sum[:12])
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
return dst return dst
} }
func (n *nullAEADFNV128a) Overhead() int { func (n *nullAEADFNV128a) Overhead() int {
return 12 return 12
} }
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}
...@@ -25,6 +25,8 @@ type baseFlowController struct { ...@@ -25,6 +25,8 @@ type baseFlowController struct {
epochStartTime time.Time epochStartTime time.Time
epochStartOffset protocol.ByteCount epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
logger utils.Logger
} }
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
......
...@@ -22,6 +22,7 @@ func NewConnectionFlowController( ...@@ -22,6 +22,7 @@ func NewConnectionFlowController(
receiveWindow protocol.ByteCount, receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController { ) ConnectionFlowController {
return &connectionFlowController{ return &connectionFlowController{
baseFlowController: baseFlowController{ baseFlowController: baseFlowController{
...@@ -29,6 +30,7 @@ func NewConnectionFlowController( ...@@ -29,6 +30,7 @@ func NewConnectionFlowController(
receiveWindow: receiveWindow, receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
}, },
} }
} }
...@@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { ...@@ -65,7 +67,7 @@ func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize { if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
} }
c.mutex.Unlock() c.mutex.Unlock()
return offset return offset
......
...@@ -31,6 +31,7 @@ func NewStreamFlowController( ...@@ -31,6 +31,7 @@ func NewStreamFlowController(
maxReceiveWindow protocol.ByteCount, maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount, initialSendWindow protocol.ByteCount,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController { ) StreamFlowController {
return &streamFlowController{ return &streamFlowController{
streamID: streamID, streamID: streamID,
...@@ -42,6 +43,7 @@ func NewStreamFlowController( ...@@ -42,6 +43,7 @@ func NewStreamFlowController(
receiveWindowSize: receiveWindow, receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow, maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow, sendWindow: initialSendWindow,
logger: logger,
}, },
} }
} }
...@@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { ...@@ -137,7 +139,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
oldWindowSize := c.receiveWindowSize oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate() offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection { if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
} }
......
...@@ -12,14 +12,15 @@ import ( ...@@ -12,14 +12,15 @@ import (
// By including the cookie in its ClientHello, a client can proof ownership of its source address. // By including the cookie in its ClientHello, a client can proof ownership of its source address.
type CookieHandler struct { type CookieHandler struct {
callback func(net.Addr, *Cookie) bool callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator cookieGenerator *CookieGenerator
logger utils.Logger
} }
var _ mint.CookieHandler = &CookieHandler{} var _ mint.CookieHandler = &CookieHandler{}
// NewCookieHandler creates a new CookieHandler. // NewCookieHandler creates a new CookieHandler.
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) { func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator() cookieGenerator, err := NewCookieGenerator()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -27,6 +28,7 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er ...@@ -27,6 +28,7 @@ func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, er
return &CookieHandler{ return &CookieHandler{
callback: callback, callback: callback,
cookieGenerator: cookieGenerator, cookieGenerator: cookieGenerator,
logger: logger,
}, nil }, nil
} }
...@@ -42,7 +44,7 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) { ...@@ -42,7 +44,7 @@ func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool { func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token) data, err := h.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
return false return false
} }
return h.callback(conn.RemoteAddr(), data) return h.callback(conn.RemoteAddr(), data)
......
...@@ -38,13 +38,12 @@ type cryptoSetupClient struct { ...@@ -38,13 +38,12 @@ type cryptoSetupClient struct {
lastSentCHLO []byte lastSentCHLO []byte
certManager crypto.CertManager certManager crypto.CertManager
divNonceChan chan []byte divNonceChan <-chan []byte
diversificationNonce []byte diversificationNonce []byte
clientHelloCounter int clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
receivedSecurePacket bool receivedSecurePacket bool
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
...@@ -55,6 +54,8 @@ type cryptoSetupClient struct { ...@@ -55,6 +54,8 @@ type cryptoSetupClient struct {
handshakeEvent chan<- struct{} handshakeEvent chan<- struct{}
params *TransportParameters params *TransportParameters
logger utils.Logger
} }
var _ CryptoSetup = &cryptoSetupClient{} var _ CryptoSetup = &cryptoSetupClient{}
...@@ -77,12 +78,14 @@ func NewCryptoSetupClient( ...@@ -77,12 +78,14 @@ func NewCryptoSetupClient(
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { logger utils.Logger,
) (CryptoSetup, chan<- []byte, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return &cryptoSetupClient{ divNonceChan := make(chan []byte)
cs := &cryptoSetupClient{
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
hostname: hostname, hostname: hostname,
connID: connID, connID: connID,
...@@ -90,14 +93,15 @@ func NewCryptoSetupClient( ...@@ -90,14 +93,15 @@ func NewCryptoSetupClient(
certManager: crypto.NewCertManager(tlsConfig), certManager: crypto.NewCertManager(tlsConfig),
params: params, params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
paramsChan: paramsChan, paramsChan: paramsChan,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
initialVersion: initialVersion, initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte), divNonceChan: divNonceChan,
}, nil logger: logger,
}
return cs, divNonceChan, nil
} }
func (h *cryptoSetupClient) HandleCryptoStream() error { func (h *cryptoSetupClient) HandleCryptoStream() error {
...@@ -146,7 +150,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { ...@@ -146,7 +150,7 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
return err return err
} }
utils.Debugf("Got %s", message) h.logger.Debugf("Got %s", message)
switch message.Tag { switch message.Tag {
case TagREJ: case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil { if err := h.handleREJMessage(message.Data); err != nil {
...@@ -211,7 +215,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { ...@@ -211,7 +215,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
err = h.certManager.Verify(h.hostname) err = h.certManager.Verify(h.hostname)
if err != nil { if err != nil {
utils.Infof("Certificate validation failed: %s", err.Error()) h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid return qerr.ProofInvalid
} }
} }
...@@ -219,7 +223,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { ...@@ -219,7 +223,7 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil { if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get()) validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof { if !validProof {
utils.Infof("Server proof verification failed") h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid return qerr.ProofInvalid
} }
...@@ -373,14 +377,6 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry ...@@ -373,14 +377,6 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
return nil, errors.New("CryptoSetupClient: no encryption level specified") return nil, errors.New("CryptoSetupClient: no encryption level specified")
} }
func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient")
}
func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState { func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
...@@ -408,7 +404,7 @@ func (h *cryptoSetupClient) sendCHLO() error { ...@@ -408,7 +404,7 @@ func (h *cryptoSetupClient) sendCHLO() error {
Data: tags, Data: tags,
} }
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
message.Write(b) message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes()) _, err = h.cryptoStream.Write(b.Bytes())
......
...@@ -19,7 +19,7 @@ import ( ...@@ -19,7 +19,7 @@ import (
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX // KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() crypto.KeyExchange type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session // The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct { type cryptoSetupServer struct {
...@@ -54,6 +54,8 @@ type cryptoSetupServer struct { ...@@ -54,6 +54,8 @@ type cryptoSetupServer struct {
params *TransportParameters params *TransportParameters
sni string // need to fill out the ConnectionState sni string // need to fill out the ConnectionState
logger utils.Logger
} }
var _ CryptoSetup = &cryptoSetupServer{} var _ CryptoSetup = &cryptoSetupServer{}
...@@ -73,12 +75,14 @@ func NewCryptoSetup( ...@@ -73,12 +75,14 @@ func NewCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig, scfg *ServerConfig,
params *TransportParameters, params *TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil { if err != nil {
...@@ -90,6 +94,7 @@ func NewCryptoSetup( ...@@ -90,6 +94,7 @@ func NewCryptoSetup(
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
version: version, version: version,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg, scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
...@@ -99,6 +104,7 @@ func NewCryptoSetup( ...@@ -99,6 +104,7 @@ func NewCryptoSetup(
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
paramsChan: paramsChan, paramsChan: paramsChan,
handshakeEvent: handshakeEvent, handshakeEvent: handshakeEvent,
logger: logger,
}, nil }, nil
} }
...@@ -114,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error { ...@@ -114,7 +120,7 @@ func (h *cryptoSetupServer) HandleCryptoStream() error {
return qerr.InvalidCryptoMessageType return qerr.InvalidCryptoMessageType
} }
utils.Debugf("Got %s", message) h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data) done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil { if err != nil {
return err return err
...@@ -297,7 +303,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt ...@@ -297,7 +303,7 @@ func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byt
func (h *cryptoSetupServer) acceptSTK(token []byte) bool { func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token) stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("STK invalid: %s", err.Error()) h.logger.Debugf("STK invalid: %s", err.Error())
return false return false
} }
return h.acceptSTKCallback(h.remoteAddr, stk) return h.acceptSTKCallback(h.remoteAddr, stk)
...@@ -340,7 +346,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa ...@@ -340,7 +346,7 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa
var serverReply bytes.Buffer var serverReply bytes.Buffer
message.Write(&serverReply) message.Write(&serverReply)
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil return serverReply.Bytes(), nil
} }
...@@ -364,11 +370,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T ...@@ -364,11 +370,6 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err return nil, err
} }
h.diversificationNonce = make([]byte, 32)
if _, err = rand.Read(h.diversificationNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC] clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce) err = h.validateClientNonce(clientNonce)
if err != nil { if err != nil {
...@@ -405,7 +406,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T ...@@ -405,7 +406,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
var fsNonce bytes.Buffer var fsNonce bytes.Buffer
fsNonce.Write(clientNonce) fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce) fsNonce.Write(serverNonce)
ephermalKex := h.keyExchange() ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -443,19 +447,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T ...@@ -443,19 +447,10 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
} }
var reply bytes.Buffer var reply bytes.Buffer
message.Write(&reply) message.Write(&reply)
utils.Debugf("Sending %s", message) h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil return reply.Bytes(), nil
} }
// DiversificationNonce returns the diversification nonce
func (h *cryptoSetupServer) DiversificationNonce() []byte {
return h.diversificationNonce
}
func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState { func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
......
...@@ -31,6 +31,8 @@ type cryptoSetupTLS struct { ...@@ -31,6 +31,8 @@ type cryptoSetupTLS struct {
handshakeEvent chan<- struct{} handshakeEvent chan<- struct{}
} }
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer( func NewCryptoSetupTLSServer(
tls MintTLS, tls MintTLS,
...@@ -38,7 +40,7 @@ func NewCryptoSetupTLSServer( ...@@ -38,7 +40,7 @@ func NewCryptoSetupTLSServer(
nullAEAD crypto.AEAD, nullAEAD crypto.AEAD,
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
version protocol.VersionNumber, version protocol.VersionNumber,
) CryptoSetup { ) CryptoSetupTLS {
return &cryptoSetupTLS{ return &cryptoSetupTLS{
tls: tls, tls: tls,
cryptoStream: cryptoStream, cryptoStream: cryptoStream,
...@@ -57,7 +59,7 @@ func NewCryptoSetupTLSClient( ...@@ -57,7 +59,7 @@ func NewCryptoSetupTLSClient(
handshakeEvent chan<- struct{}, handshakeEvent chan<- struct{},
tls MintTLS, tls MintTLS,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -107,22 +109,18 @@ handshakeLoop: ...@@ -107,22 +109,18 @@ handshakeLoop:
return nil return nil
} }
func (h *cryptoSetupTLS) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.aead != nil { if h.aead == nil {
data, err := h.aead.Open(dst, src, packetNumber, associatedData) return nil, errors.New("no 1-RTT sealer")
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return data, protocol.EncryptionForwardSecure, nil
}
data, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
} }
return data, protocol.EncryptionUnencrypted, nil return h.aead.Open(dst, src, packetNumber, associatedData)
} }
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
...@@ -157,14 +155,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S ...@@ -157,14 +155,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) SetDiversificationNonce([]byte) {
panic("diversification nonce not needed for TLS")
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState { func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
var ( var (
...@@ -24,13 +23,13 @@ var ( ...@@ -24,13 +23,13 @@ var (
// used for all connections for 60 seconds is negligible. Thus we can amortise // used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a // the Diffie-Hellman key generation at the server over all the connections in a
// small time span. // small time span.
func getEphermalKEX() (res crypto.KeyExchange) { func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock() kexMutex.RLock()
res = kexCurrent res := kexCurrent
t := kexCurrentTime t := kexCurrentTime
kexMutex.RUnlock() kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime { if res != nil && time.Since(t) < kexLifetime {
return res return res, nil
} }
kexMutex.Lock() kexMutex.Lock()
...@@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) { ...@@ -39,12 +38,11 @@ func getEphermalKEX() (res crypto.KeyExchange) {
if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime { if kexCurrent == nil || time.Since(kexCurrentTime) > kexLifetime {
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
utils.Errorf("could not set KEX: %s", err.Error()) return nil, err
return kexCurrent
} }
kexCurrent = kex kexCurrent = kex
kexCurrentTime = time.Now() kexCurrentTime = time.Now()
return kexCurrent return kexCurrent, nil
} }
return kexCurrent return kexCurrent, nil
} }
...@@ -35,13 +35,8 @@ type MintTLS interface { ...@@ -35,13 +35,8 @@ type MintTLS interface {
SetCryptoStream(io.ReadWriter) SetCryptoStream(io.ReadWriter)
} }
// CryptoSetup is a crypto setup type baseCryptoSetup interface {
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
ConnectionState() ConnectionState ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
...@@ -49,6 +44,21 @@ type CryptoSetup interface { ...@@ -49,6 +44,21 @@ type CryptoSetup interface {
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ConnectionState records basic details about the QUIC connection. // ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon. // Warning: This API should not be considered stable and might change soon.
type ConnectionState struct { type ConnectionState struct {
......
...@@ -19,6 +19,8 @@ type extensionHandlerClient struct { ...@@ -19,6 +19,8 @@ type extensionHandlerClient struct {
initialVersion protocol.VersionNumber initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
version protocol.VersionNumber version protocol.VersionNumber
logger utils.Logger
} }
var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ mint.AppExtensionHandler = &extensionHandlerClient{}
...@@ -30,6 +32,7 @@ func NewExtensionHandlerClient( ...@@ -30,6 +32,7 @@ func NewExtensionHandlerClient(
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler { ) TLSExtensionHandler {
// The client reads the transport parameters from the Encrypted Extensions message. // The client reads the transport parameters from the Encrypted Extensions message.
// The paramsChan is used in the session's run loop's select statement. // The paramsChan is used in the session's run loop's select statement.
...@@ -41,6 +44,7 @@ func NewExtensionHandlerClient( ...@@ -41,6 +44,7 @@ func NewExtensionHandlerClient(
initialVersion: initialVersion, initialVersion: initialVersion,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
version: version, version: version,
logger: logger,
} }
} }
...@@ -49,7 +53,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi ...@@ -49,7 +53,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
return nil return nil
} }
utils.Debugf("Sending Transport Parameters: %s", h.ourParams) h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(clientHelloTransportParameters{ data, err := syntax.Marshal(clientHelloTransportParameters{
InitialVersion: uint32(h.initialVersion), InitialVersion: uint32(h.initialVersion),
Parameters: h.ourParams.getTransportParameters(), Parameters: h.ourParams.getTransportParameters(),
...@@ -122,7 +126,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte ...@@ -122,7 +126,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
if err != nil { if err != nil {
return err return err
} }
utils.Debugf("Received Transport Parameters: %s", params) h.logger.Debugf("Received Transport Parameters: %s", params)
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
......
...@@ -19,6 +19,8 @@ type extensionHandlerServer struct { ...@@ -19,6 +19,8 @@ type extensionHandlerServer struct {
version protocol.VersionNumber version protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
logger utils.Logger
} }
var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ mint.AppExtensionHandler = &extensionHandlerServer{}
...@@ -29,6 +31,7 @@ func NewExtensionHandlerServer( ...@@ -29,6 +31,7 @@ func NewExtensionHandlerServer(
params *TransportParameters, params *TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler { ) TLSExtensionHandler {
// Processing the ClientHello is performed statelessly (and from a single go-routine). // Processing the ClientHello is performed statelessly (and from a single go-routine).
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine. // Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
...@@ -38,6 +41,7 @@ func NewExtensionHandlerServer( ...@@ -38,6 +41,7 @@ func NewExtensionHandlerServer(
paramsChan: paramsChan, paramsChan: paramsChan,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
version: version, version: version,
logger: logger,
} }
} }
...@@ -56,7 +60,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi ...@@ -56,7 +60,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
for i, v := range supportedVersions { for i, v := range supportedVersions {
versions[i] = uint32(v) versions[i] = uint32(v)
} }
utils.Debugf("Sending Transport Parameters: %s", h.ourParams) h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
data, err := syntax.Marshal(encryptedExtensionsTransportParameters{ data, err := syntax.Marshal(encryptedExtensionsTransportParameters{
NegotiatedVersion: uint32(h.version), NegotiatedVersion: uint32(h.version),
SupportedVersions: versions, SupportedVersions: versions,
...@@ -108,7 +112,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte ...@@ -108,7 +112,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
if err != nil { if err != nil {
return err return err
} }
utils.Debugf("Received Transport Parameters: %s", params) h.logger.Debugf("Received Transport Parameters: %s", params)
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
......
...@@ -109,18 +109,6 @@ func (mr *MockSendAlgorithmMockRecorder) OnRetransmissionTimeout(arg0 interface{ ...@@ -109,18 +109,6 @@ func (mr *MockSendAlgorithmMockRecorder) OnRetransmissionTimeout(arg0 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithm)(nil).OnRetransmissionTimeout), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithm)(nil).OnRetransmissionTimeout), arg0)
} }
// RetransmissionDelay mocks base method
func (m *MockSendAlgorithm) RetransmissionDelay() time.Duration {
ret := m.ctrl.Call(m, "RetransmissionDelay")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// RetransmissionDelay indicates an expected call of RetransmissionDelay
func (mr *MockSendAlgorithmMockRecorder) RetransmissionDelay() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetransmissionDelay", reflect.TypeOf((*MockSendAlgorithm)(nil).RetransmissionDelay))
}
// SetNumEmulatedConnections mocks base method // SetNumEmulatedConnections mocks base method
func (m *MockSendAlgorithm) SetNumEmulatedConnections(arg0 int) { func (m *MockSendAlgorithm) SetNumEmulatedConnections(arg0 int) {
m.ctrl.Call(m, "SetNumEmulatedConnections", arg0) m.ctrl.Call(m, "SetNumEmulatedConnections", arg0)
......
// Generated by: main // This file was automatically generated by genny.
// TypeWriter: linkedlist // Any changes will be lost if this file is regenerated.
// Directive: +gen on ByteInterval // see https://github.com/cheekybits/genny
package utils package utils
// List is a modification of http://golang.org/pkg/container/list/ // Linked list implementation from the Go standard library.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// ByteIntervalElement is an element of a linked list. // ByteIntervalElement is an element of a linked list.
type ByteIntervalElement struct { type ByteIntervalElement struct {
...@@ -41,8 +38,7 @@ func (e *ByteIntervalElement) Prev() *ByteIntervalElement { ...@@ -41,8 +38,7 @@ func (e *ByteIntervalElement) Prev() *ByteIntervalElement {
return nil return nil
} }
// ByteIntervalList represents a doubly linked list. // ByteIntervalList is a linked list of ByteIntervals.
// The zero value for ByteIntervalList is an empty list ready to use.
type ByteIntervalList struct { type ByteIntervalList struct {
root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element len int // current list length excluding (this) sentinel element
...@@ -63,7 +59,7 @@ func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init ...@@ -63,7 +59,7 @@ func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init
// The complexity is O(1). // The complexity is O(1).
func (l *ByteIntervalList) Len() int { return l.len } func (l *ByteIntervalList) Len() int { return l.len }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil if the list is empty.
func (l *ByteIntervalList) Front() *ByteIntervalElement { func (l *ByteIntervalList) Front() *ByteIntervalElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -71,7 +67,7 @@ func (l *ByteIntervalList) Front() *ByteIntervalElement { ...@@ -71,7 +67,7 @@ func (l *ByteIntervalList) Front() *ByteIntervalElement {
return l.root.next return l.root.next
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil if the list is empty.
func (l *ByteIntervalList) Back() *ByteIntervalElement { func (l *ByteIntervalList) Back() *ByteIntervalElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -79,7 +75,7 @@ func (l *ByteIntervalList) Back() *ByteIntervalElement { ...@@ -79,7 +75,7 @@ func (l *ByteIntervalList) Back() *ByteIntervalElement {
return l.root.prev return l.root.prev
} }
// lazyInit lazily initializes a zero ByteIntervalList value. // lazyInit lazily initializes a zero List value.
func (l *ByteIntervalList) lazyInit() { func (l *ByteIntervalList) lazyInit() {
if l.root.next == nil { if l.root.next == nil {
l.Init() l.Init()
...@@ -98,7 +94,7 @@ func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalEleme ...@@ -98,7 +94,7 @@ func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalEleme
return e return e
} }
// insertValue is a convenience wrapper for insert(&ByteIntervalElement{Value: v}, at). // insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement {
return l.insert(&ByteIntervalElement{Value: v}, at) return l.insert(&ByteIntervalElement{Value: v}, at)
} }
...@@ -116,10 +112,11 @@ func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { ...@@ -116,10 +112,11 @@ func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement {
// Remove removes e from l if e is an element of list l. // Remove removes e from l if e is an element of list l.
// It returns the element value e.Value. // It returns the element value e.Value.
// The element must not be nil.
func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval {
if e.list == l { if e.list == l {
// if e.list == l, l must have been initialized when e was inserted // if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero ByteIntervalElement) and l.remove will crash // in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e) l.remove(e)
} }
return e.Value return e.Value
...@@ -139,46 +136,51 @@ func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { ...@@ -139,46 +136,51 @@ func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e. // InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in ByteIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev) return l.insertValue(v, mark.prev)
} }
// InsertAfter inserts a new element e with value v immediately after mark and returns e. // InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in ByteIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark) return l.insertValue(v, mark)
} }
// MoveToFront moves element e to the front of list l. // MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) {
if e.list != l || l.root.next == e { if e.list != l || l.root.next == e {
return return
} }
// see comment in ByteIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root) l.insert(l.remove(e), &l.root)
} }
// MoveToBack moves element e to the back of list l. // MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) {
if e.list != l || l.root.prev == e { if e.list != l || l.root.prev == e {
return return
} }
// see comment in ByteIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark. // MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -187,7 +189,8 @@ func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { ...@@ -187,7 +189,8 @@ func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) {
} }
// MoveAfter moves element e to its new position after mark. // MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -196,7 +199,7 @@ func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { ...@@ -196,7 +199,7 @@ func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) {
} }
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
...@@ -205,7 +208,7 @@ func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { ...@@ -205,7 +208,7 @@ func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) {
} }
// PushFrontList inserts a copy of an other list at the front of list l. // PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
......
package utils
//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval
//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval
package linkedlist
import "github.com/cheekybits/genny/generic"
// Linked list implementation from the Go standard library.
// Item is a generic type.
type Item generic.Type
// ItemElement is an element of a linked list.
type ItemElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *ItemElement
// The list to which this element belongs.
list *ItemList
// The value stored with this element.
Value Item
}
// Next returns the next list element or nil.
func (e *ItemElement) Next() *ItemElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *ItemElement) Prev() *ItemElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// ItemList is a linked list of Items.
type ItemList struct {
root ItemElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *ItemList) Init() *ItemList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewItemList returns an initialized list.
func NewItemList() *ItemList { return new(ItemList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *ItemList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *ItemList) Front() *ItemElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *ItemList) Back() *ItemElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *ItemList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *ItemList) insert(e, at *ItemElement) *ItemElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *ItemList) insertValue(v Item, at *ItemElement) *ItemElement {
return l.insert(&ItemElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *ItemList) remove(e *ItemElement) *ItemElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *ItemList) Remove(e *ItemElement) Item {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *ItemList) PushFront(v Item) *ItemElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *ItemList) PushBack(v Item) *ItemElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ItemList) InsertBefore(v Item, mark *ItemElement) *ItemElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ItemList) InsertAfter(v Item, mark *ItemElement) *ItemElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ItemList) MoveToFront(e *ItemElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ItemList) MoveToBack(e *ItemElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ItemList) MoveBefore(e, mark *ItemElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ItemList) MoveAfter(e, mark *ItemElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ItemList) PushBackList(other *ItemList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ItemList) PushFrontList(other *ItemList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
// LogLevel of quic-go // LogLevel of quic-go
type LogLevel uint8 type LogLevel uint8
const logEnv = "QUIC_GO_LOG_LEVEL"
const ( const (
// LogLevelNothing disables // LogLevelNothing disables
LogLevelNothing LogLevel = iota LogLevelNothing LogLevel = iota
...@@ -24,72 +22,92 @@ const ( ...@@ -24,72 +22,92 @@ const (
LogLevelDebug LogLevelDebug
) )
var ( const logEnv = "QUIC_GO_LOG_LEVEL"
logLevel = LogLevelNothing
timeFormat = "" // A Logger logs.
) type Logger interface {
SetLogLevel(LogLevel)
SetLogTimeFormat(format string)
Debug() bool
Errorf(format string, args ...interface{})
Infof(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// DefaultLogger is used by quic-go for logging.
var DefaultLogger Logger
type defaultLogger struct {
logLevel LogLevel
timeFormat string
}
var _ Logger = &defaultLogger{}
// SetLogLevel sets the log level // SetLogLevel sets the log level
func SetLogLevel(level LogLevel) { func (l *defaultLogger) SetLogLevel(level LogLevel) {
logLevel = level l.logLevel = level
} }
// SetLogTimeFormat sets the format of the timestamp // SetLogTimeFormat sets the format of the timestamp
// an empty string disables the logging of timestamps // an empty string disables the logging of timestamps
func SetLogTimeFormat(format string) { func (l *defaultLogger) SetLogTimeFormat(format string) {
log.SetFlags(0) // disable timestamp logging done by the log package log.SetFlags(0) // disable timestamp logging done by the log package
timeFormat = format l.timeFormat = format
} }
// Debugf logs something // Debugf logs something
func Debugf(format string, args ...interface{}) { func (l *defaultLogger) Debugf(format string, args ...interface{}) {
if logLevel == LogLevelDebug { if l.logLevel == LogLevelDebug {
logMessage(format, args...) l.logMessage(format, args...)
} }
} }
// Infof logs something // Infof logs something
func Infof(format string, args ...interface{}) { func (l *defaultLogger) Infof(format string, args ...interface{}) {
if logLevel >= LogLevelInfo { if l.logLevel >= LogLevelInfo {
logMessage(format, args...) l.logMessage(format, args...)
} }
} }
// Errorf logs something // Errorf logs something
func Errorf(format string, args ...interface{}) { func (l *defaultLogger) Errorf(format string, args ...interface{}) {
if logLevel >= LogLevelError { if l.logLevel >= LogLevelError {
logMessage(format, args...) l.logMessage(format, args...)
} }
} }
func logMessage(format string, args ...interface{}) { func (l *defaultLogger) logMessage(format string, args ...interface{}) {
if len(timeFormat) > 0 { if len(l.timeFormat) > 0 {
log.Printf(time.Now().Format(timeFormat)+" "+format, args...) log.Printf(time.Now().Format(l.timeFormat)+" "+format, args...)
} else { } else {
log.Printf(format, args...) log.Printf(format, args...)
} }
} }
// Debug returns true if the log level is LogLevelDebug // Debug returns true if the log level is LogLevelDebug
func Debug() bool { func (l *defaultLogger) Debug() bool {
return logLevel == LogLevelDebug return l.logLevel == LogLevelDebug
} }
func init() { func init() {
readLoggingEnv() DefaultLogger = &defaultLogger{}
DefaultLogger.SetLogLevel(readLoggingEnv())
} }
func readLoggingEnv() { func readLoggingEnv() LogLevel {
switch strings.ToLower(os.Getenv(logEnv)) { switch strings.ToLower(os.Getenv(logEnv)) {
case "": case "":
return return LogLevelNothing
case "debug": case "debug":
logLevel = LogLevelDebug return LogLevelDebug
case "info": case "info":
logLevel = LogLevelInfo return LogLevelInfo
case "error": case "error":
logLevel = LogLevelError return LogLevelError
default: default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
return LogLevelNothing
} }
} }
...@@ -3,7 +3,6 @@ package utils ...@@ -3,7 +3,6 @@ package utils
import "github.com/lucas-clemente/quic-go/internal/protocol" import "github.com/lucas-clemente/quic-go/internal/protocol"
// PacketInterval is an interval from one PacketNumber to the other // PacketInterval is an interval from one PacketNumber to the other
// +gen linkedlist
type PacketInterval struct { type PacketInterval struct {
Start protocol.PacketNumber Start protocol.PacketNumber
End protocol.PacketNumber End protocol.PacketNumber
......
// Generated by: main // This file was automatically generated by genny.
// TypeWriter: linkedlist // Any changes will be lost if this file is regenerated.
// Directive: +gen on PacketInterval // see https://github.com/cheekybits/genny
package utils package utils
// List is a modification of http://golang.org/pkg/container/list/ // Linked list implementation from the Go standard library.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// PacketIntervalElement is an element of a linked list. // PacketIntervalElement is an element of a linked list.
type PacketIntervalElement struct { type PacketIntervalElement struct {
...@@ -41,8 +38,7 @@ func (e *PacketIntervalElement) Prev() *PacketIntervalElement { ...@@ -41,8 +38,7 @@ func (e *PacketIntervalElement) Prev() *PacketIntervalElement {
return nil return nil
} }
// PacketIntervalList represents a doubly linked list. // PacketIntervalList is a linked list of PacketIntervals.
// The zero value for PacketIntervalList is an empty list ready to use.
type PacketIntervalList struct { type PacketIntervalList struct {
root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element len int // current list length excluding (this) sentinel element
...@@ -63,7 +59,7 @@ func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList ...@@ -63,7 +59,7 @@ func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList
// The complexity is O(1). // The complexity is O(1).
func (l *PacketIntervalList) Len() int { return l.len } func (l *PacketIntervalList) Len() int { return l.len }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil if the list is empty.
func (l *PacketIntervalList) Front() *PacketIntervalElement { func (l *PacketIntervalList) Front() *PacketIntervalElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -71,7 +67,7 @@ func (l *PacketIntervalList) Front() *PacketIntervalElement { ...@@ -71,7 +67,7 @@ func (l *PacketIntervalList) Front() *PacketIntervalElement {
return l.root.next return l.root.next
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil if the list is empty.
func (l *PacketIntervalList) Back() *PacketIntervalElement { func (l *PacketIntervalList) Back() *PacketIntervalElement {
if l.len == 0 { if l.len == 0 {
return nil return nil
...@@ -79,7 +75,7 @@ func (l *PacketIntervalList) Back() *PacketIntervalElement { ...@@ -79,7 +75,7 @@ func (l *PacketIntervalList) Back() *PacketIntervalElement {
return l.root.prev return l.root.prev
} }
// lazyInit lazily initializes a zero PacketIntervalList value. // lazyInit lazily initializes a zero List value.
func (l *PacketIntervalList) lazyInit() { func (l *PacketIntervalList) lazyInit() {
if l.root.next == nil { if l.root.next == nil {
l.Init() l.Init()
...@@ -98,7 +94,7 @@ func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketInterva ...@@ -98,7 +94,7 @@ func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketInterva
return e return e
} }
// insertValue is a convenience wrapper for insert(&PacketIntervalElement{Value: v}, at). // insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement {
return l.insert(&PacketIntervalElement{Value: v}, at) return l.insert(&PacketIntervalElement{Value: v}, at)
} }
...@@ -116,10 +112,11 @@ func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalEle ...@@ -116,10 +112,11 @@ func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalEle
// Remove removes e from l if e is an element of list l. // Remove removes e from l if e is an element of list l.
// It returns the element value e.Value. // It returns the element value e.Value.
// The element must not be nil.
func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval {
if e.list == l { if e.list == l {
// if e.list == l, l must have been initialized when e was inserted // if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero PacketIntervalElement) and l.remove will crash // in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e) l.remove(e)
} }
return e.Value return e.Value
...@@ -139,46 +136,51 @@ func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { ...@@ -139,46 +136,51 @@ func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement {
// InsertBefore inserts a new element e with value v immediately before mark and returns e. // InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev) return l.insertValue(v, mark.prev)
} }
// InsertAfter inserts a new element e with value v immediately after mark and returns e. // InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified. // If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement {
if mark.list != l { if mark.list != l {
return nil return nil
} }
// see comment in PacketIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
return l.insertValue(v, mark) return l.insertValue(v, mark)
} }
// MoveToFront moves element e to the front of list l. // MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) {
if e.list != l || l.root.next == e { if e.list != l || l.root.next == e {
return return
} }
// see comment in PacketIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root) l.insert(l.remove(e), &l.root)
} }
// MoveToBack moves element e to the back of list l. // MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified. // If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) {
if e.list != l || l.root.prev == e { if e.list != l || l.root.prev == e {
return return
} }
// see comment in PacketIntervalList.Remove about initialization of l // see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark. // MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -187,7 +189,8 @@ func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { ...@@ -187,7 +189,8 @@ func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) {
} }
// MoveAfter moves element e to its new position after mark. // MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified. // If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) {
if e.list != l || e == mark || mark.list != l { if e.list != l || e == mark || mark.list != l {
return return
...@@ -196,7 +199,7 @@ func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { ...@@ -196,7 +199,7 @@ func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) {
} }
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
...@@ -205,7 +208,7 @@ func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { ...@@ -205,7 +208,7 @@ func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) {
} }
// PushFrontList inserts a copy of an other list at the front of list l. // PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. // The lists l and other may be the same. They must not be nil.
func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) {
l.lazyInit() l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
......
...@@ -3,7 +3,6 @@ package utils ...@@ -3,7 +3,6 @@ package utils
import "github.com/lucas-clemente/quic-go/internal/protocol" import "github.com/lucas-clemente/quic-go/internal/protocol"
// ByteInterval is an interval from one ByteCount to the other // ByteInterval is an interval from one ByteCount to the other
// +gen linkedlist
type ByteInterval struct { type ByteInterval struct {
Start protocol.ByteCount Start protocol.ByteCount
End protocol.ByteCount End protocol.ByteCount
......
...@@ -24,8 +24,8 @@ type AckFrame struct { ...@@ -24,8 +24,8 @@ type AckFrame struct {
DelayTime time.Duration DelayTime time.Duration
} }
// ParseAckFrame reads an ACK frame // parseAckFrame reads an ACK frame
func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) { func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return parseAckFrameLegacy(r, version) return parseAckFrameLegacy(r, version)
} }
......
...@@ -12,8 +12,8 @@ type BlockedFrame struct { ...@@ -12,8 +12,8 @@ type BlockedFrame struct {
Offset protocol.ByteCount Offset protocol.ByteCount
} }
// ParseBlockedFrame parses a BLOCKED frame // parseBlockedFrame parses a BLOCKED frame
func ParseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) { func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
return nil, err return nil, err
} }
......
...@@ -11,11 +11,11 @@ type blockedFrameLegacy struct { ...@@ -11,11 +11,11 @@ type blockedFrameLegacy struct {
StreamID protocol.StreamID StreamID protocol.StreamID
} }
// ParseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format) // parseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format)
// The frame returned is // The frame returned is
// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream // * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream
// * a BLOCKED frame, if the BLOCKED applies to the connection // * a BLOCKED frame, if the BLOCKED applies to the connection
func ParseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { func parseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
...@@ -17,8 +17,8 @@ type ConnectionCloseFrame struct { ...@@ -17,8 +17,8 @@ type ConnectionCloseFrame struct {
ReasonPhrase string ReasonPhrase string
} }
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame // parseConnectionCloseFrame reads a CONNECTION_CLOSE frame
func ParseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) { func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
package wire
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// ParseNextFrame parses the next frame
// It skips PADDING frames.
func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) {
if r.Len() == 0 {
return nil, nil
}
typeByte, _ := r.ReadByte()
if typeByte == 0x0 { // PADDING frame
return ParseNextFrame(r, hdr, v)
}
r.UnreadByte()
if !v.UsesIETFFrameFormat() {
return parseGQUICFrame(r, typeByte, hdr, v)
}
return parseIETFFrame(r, typeByte, v)
}
func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
var frame Frame
var err error
if typeByte&0xf8 == 0x10 {
frame, err = parseStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error())
}
return frame, err
}
// TODO: implement all IETF QUIC frame types
switch typeByte {
case 0x1:
frame, err = parseRstStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = parseConnectionCloseFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x4:
frame, err = parseMaxDataFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x5:
frame, err = parseMaxStreamDataFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
frame, err = parseMaxStreamIDFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = parsePingFrame(r, v)
case 0x8:
frame, err = parseBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9:
frame, err = parseStreamBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0xa:
frame, err = parseStreamIDBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xc:
frame, err = parseStopSendingFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xe:
frame, err = parseAckFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
}
return frame, err
}
func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.VersionNumber) (Frame, error) {
var frame Frame
var err error
if typeByte&0x80 == 0x80 {
frame, err = parseStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error())
}
return frame, err
} else if typeByte&0xc0 == 0x40 {
frame, err = parseAckFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
return frame, err
}
switch typeByte {
case 0x1:
frame, err = parseRstStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = parseConnectionCloseFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x3:
frame, err = parseGoawayFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidGoawayData, err.Error())
}
case 0x4:
frame, err = parseWindowUpdateFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x5:
frame, err = parseBlockedFrameLegacy(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x6:
frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v)
if err != nil {
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
}
case 0x7:
frame, err = parsePingFrame(r, v)
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
}
return frame, err
}
...@@ -16,8 +16,8 @@ type GoawayFrame struct { ...@@ -16,8 +16,8 @@ type GoawayFrame struct {
ReasonPhrase string ReasonPhrase string
} }
// ParseGoawayFrame parses a GOAWAY frame // parseGoawayFrame parses a GOAWAY frame
func ParseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) { func parseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) {
frame := &GoawayFrame{} frame := &GoawayFrame{}
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// Header is the header of a QUIC packet. // Header is the header of a QUIC packet.
...@@ -103,10 +104,10 @@ func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNu ...@@ -103,10 +104,10 @@ func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNu
} }
// Log logs the Header // Log logs the Header
func (h *Header) Log() { func (h *Header) Log(logger utils.Logger) {
if h.isPublicHeader { if h.isPublicHeader {
h.logPublicHeader() h.logPublicHeader(logger)
} else { } else {
h.logHeader() h.logHeader(logger)
} }
} }
...@@ -174,14 +174,14 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) { ...@@ -174,14 +174,14 @@ func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
return length, nil return length, nil
} }
func (h *Header) logHeader() { func (h *Header) logHeader(logger utils.Logger) {
if h.IsLongHeader { if h.IsLongHeader {
utils.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version) logger.Debugf(" Long Header{Type: %s, ConnectionID: %#x, PacketNumber: %#x, Version: %s}", h.Type, h.ConnectionID, h.PacketNumber, h.Version)
} else { } else {
connID := "(omitted)" connID := "(omitted)"
if !h.OmitConnectionID { if !h.OmitConnectionID {
connID = fmt.Sprintf("%#x", h.ConnectionID) connID = fmt.Sprintf("%#x", h.ConnectionID)
} }
utils.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) logger.Debugf(" Short Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", connID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
} }
} }
...@@ -3,8 +3,8 @@ package wire ...@@ -3,8 +3,8 @@ package wire
import "github.com/lucas-clemente/quic-go/internal/utils" import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received // LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) { func LogFrame(logger utils.Logger, frame Frame, sent bool) {
if !utils.Debug() { if !logger.Debug() {
return return
} }
dir := "<-" dir := "<-"
...@@ -13,16 +13,16 @@ func LogFrame(frame Frame, sent bool) { ...@@ -13,16 +13,16 @@ func LogFrame(frame Frame, sent bool) {
} }
switch f := frame.(type) { switch f := frame.(type) {
case *StreamFrame: case *StreamFrame:
utils.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen()) logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame: case *StopWaitingFrame:
if sent { if sent {
utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen) logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else { } else {
utils.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked) logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
} }
case *AckFrame: case *AckFrame:
utils.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String()) logger.Debugf("\t%s &wire.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
default: default:
utils.Debugf("\t%s %#v", dir, frame) logger.Debugf("\t%s %#v", dir, frame)
} }
} }
...@@ -12,8 +12,8 @@ type MaxDataFrame struct { ...@@ -12,8 +12,8 @@ type MaxDataFrame struct {
ByteOffset protocol.ByteCount ByteOffset protocol.ByteCount
} }
// ParseMaxDataFrame parses a MAX_DATA frame // parseMaxDataFrame parses a MAX_DATA frame
func ParseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) { func parseMaxDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxDataFrame, error) {
// read the TypeByte // read the TypeByte
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
return nil, err return nil, err
......
...@@ -13,8 +13,8 @@ type MaxStreamDataFrame struct { ...@@ -13,8 +13,8 @@ type MaxStreamDataFrame struct {
ByteOffset protocol.ByteCount ByteOffset protocol.ByteCount
} }
// ParseMaxStreamDataFrame parses a MAX_STREAM_DATA frame // parseMaxStreamDataFrame parses a MAX_STREAM_DATA frame
func ParseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) { func parseMaxStreamDataFrame(r *bytes.Reader, version protocol.VersionNumber) (*MaxStreamDataFrame, error) {
frame := &MaxStreamDataFrame{} frame := &MaxStreamDataFrame{}
// read the TypeByte // read the TypeByte
......
...@@ -12,8 +12,8 @@ type MaxStreamIDFrame struct { ...@@ -12,8 +12,8 @@ type MaxStreamIDFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
} }
// ParseMaxStreamIDFrame parses a MAX_STREAM_ID frame // parseMaxStreamIDFrame parses a MAX_STREAM_ID frame
func ParseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) { func parseMaxStreamIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamIDFrame, error) {
// read the Type byte // read the Type byte
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
return nil, err return nil, err
......
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
// A PingFrame is a ping frame // A PingFrame is a ping frame
type PingFrame struct{} type PingFrame struct{}
// ParsePingFrame parses a Ping frame // parsePingFrame parses a Ping frame
func ParsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) { func parsePingFrame(r *bytes.Reader, version protocol.VersionNumber) (*PingFrame, error) {
frame := &PingFrame{} frame := &PingFrame{}
_, err := r.ReadByte() _, err := r.ReadByte()
......
...@@ -20,6 +20,9 @@ var ( ...@@ -20,6 +20,9 @@ var (
// writePublicHeader writes a Public Header. // writePublicHeader writes a Public Header.
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error { func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
if h.VersionFlag && pers == protocol.PerspectiveServer {
return errors.New("PublicHeader: Writing of Version Negotiation Packets not supported")
}
if h.VersionFlag && h.ResetFlag { if h.VersionFlag && h.ResetFlag {
return errResetAndVersionFlagSet return errResetAndVersionFlagSet
} }
...@@ -228,7 +231,7 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { ...@@ -228,7 +231,7 @@ func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool {
return true return true
} }
func (h *Header) logPublicHeader() { func (h *Header) logPublicHeader(logger utils.Logger) {
connID := "(omitted)" connID := "(omitted)"
if !h.OmitConnectionID { if !h.OmitConnectionID {
connID = fmt.Sprintf("%#x", h.ConnectionID) connID = fmt.Sprintf("%#x", h.ConnectionID)
...@@ -237,5 +240,5 @@ func (h *Header) logPublicHeader() { ...@@ -237,5 +240,5 @@ func (h *Header) logPublicHeader() {
if h.Version != 0 { if h.Version != 0 {
ver = h.Version.String() ver = h.Version.String()
} }
utils.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce) logger.Debugf(" Public Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", connID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
} }
...@@ -16,8 +16,8 @@ type RstStreamFrame struct { ...@@ -16,8 +16,8 @@ type RstStreamFrame struct {
ByteOffset protocol.ByteCount ByteOffset protocol.ByteCount
} }
// ParseRstStreamFrame parses a RST_STREAM frame // parseRstStreamFrame parses a RST_STREAM frame
func ParseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) { func parseRstStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*RstStreamFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
...@@ -13,8 +13,8 @@ type StopSendingFrame struct { ...@@ -13,8 +13,8 @@ type StopSendingFrame struct {
ErrorCode protocol.ApplicationErrorCode ErrorCode protocol.ApplicationErrorCode
} }
// ParseStopSendingFrame parses a STOP_SENDING frame // parseStopSendingFrame parses a STOP_SENDING frame
func ParseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
...@@ -56,8 +56,8 @@ func (f *StopWaitingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { ...@@ -56,8 +56,8 @@ func (f *StopWaitingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + protocol.ByteCount(f.PacketNumberLen) return 1 + protocol.ByteCount(f.PacketNumberLen)
} }
// ParseStopWaitingFrame parses a StopWaiting frame // parseStopWaitingFrame parses a StopWaiting frame
func ParseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, _ protocol.VersionNumber) (*StopWaitingFrame, error) { func parseStopWaitingFrame(r *bytes.Reader, packetNumber protocol.PacketNumber, packetNumberLen protocol.PacketNumberLen, _ protocol.VersionNumber) (*StopWaitingFrame, error) {
frame := &StopWaitingFrame{} frame := &StopWaitingFrame{}
// read the TypeByte // read the TypeByte
......
...@@ -13,8 +13,8 @@ type StreamBlockedFrame struct { ...@@ -13,8 +13,8 @@ type StreamBlockedFrame struct {
Offset protocol.ByteCount Offset protocol.ByteCount
} }
// ParseStreamBlockedFrame parses a STREAM_BLOCKED frame // parseStreamBlockedFrame parses a STREAM_BLOCKED frame
func ParseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) { func parseStreamBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
...@@ -19,8 +19,8 @@ type StreamFrame struct { ...@@ -19,8 +19,8 @@ type StreamFrame struct {
Data []byte Data []byte
} }
// ParseStreamFrame reads a STREAM frame // parseStreamFrame reads a STREAM frame
func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) { func parseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamFrame, error) {
if !version.UsesIETFFrameFormat() { if !version.UsesIETFFrameFormat() {
return parseLegacyStreamFrame(r, version) return parseLegacyStreamFrame(r, version)
} }
...@@ -76,7 +76,8 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF ...@@ -76,7 +76,8 @@ func ParseStreamFrame(r *bytes.Reader, version protocol.VersionNumber) (*StreamF
if frame.Offset+frame.DataLen() > protocol.MaxByteCount { if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset") return nil, qerr.Error(qerr.InvalidStreamData, "data overflows maximum offset")
} }
if !frame.FinBit && frame.DataLen() == 0 { // empty frames are only allowed if they have offset 0 or the FIN bit set
if frame.DataLen() == 0 && !frame.FinBit && frame.Offset != 0 {
return nil, qerr.EmptyStreamFrameNoFin return nil, qerr.EmptyStreamFrameNoFin
} }
return frame, nil return frame, nil
......
...@@ -12,8 +12,8 @@ type StreamIDBlockedFrame struct { ...@@ -12,8 +12,8 @@ type StreamIDBlockedFrame struct {
StreamID protocol.StreamID StreamID protocol.StreamID
} }
// ParseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame // parseStreamIDBlockedFrame parses a STREAM_ID_BLOCKED frame
func ParseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) { func parseStreamIDBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamIDBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil { if _, err := r.ReadByte(); err != nil {
return nil, err return nil, err
} }
......
...@@ -10,17 +10,9 @@ import ( ...@@ -10,17 +10,9 @@ import (
// ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC // ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC
func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte {
buf := &bytes.Buffer{} buf := bytes.NewBuffer(make([]byte, 0, 1+8+len(versions)*4))
ph := Header{ buf.Write([]byte{0x1 | 0x8}) // type byte
ConnectionID: connID, utils.BigEndian.WriteUint64(buf, uint64(connID))
PacketNumber: 1,
VersionFlag: true,
IsVersionNegotiation: true,
}
if err := ph.writePublicHeader(buf, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error())
return nil
}
for _, v := range versions { for _, v := range versions {
utils.BigEndian.WriteUint32(buf, uint32(v)) utils.BigEndian.WriteUint32(buf, uint32(v))
} }
......
...@@ -12,11 +12,11 @@ type windowUpdateFrame struct { ...@@ -12,11 +12,11 @@ type windowUpdateFrame struct {
ByteOffset protocol.ByteCount ByteOffset protocol.ByteCount
} }
// ParseWindowUpdateFrame parses a WINDOW_UPDATE frame // parseWindowUpdateFrame parses a WINDOW_UPDATE frame
// The frame returned is // The frame returned is
// * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream // * a MAX_STREAM_DATA frame, if the WINDOW_UPDATE applies to a stream
// * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection // * a MAX_DATA frame, if the WINDOW_UPDATE applies to the connection
func ParseWindowUpdateFrame(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) { func parseWindowUpdateFrame(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err return nil, err
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"io" "io"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
...@@ -76,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf ...@@ -76,6 +77,7 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
mconf.ServerName = tlsConf.ServerName mconf.ServerName = tlsConf.ServerName
mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
mconf.RootCAs = tlsConf.RootCAs
mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate
for i, certChain := range tlsConf.Certificates { for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{ mconf.Certificates[i] = &mint.Certificate{
...@@ -106,39 +108,45 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf ...@@ -106,39 +108,45 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets // unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0. // These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) { func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) {
unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version} decrypted, err := aead.Open(data[:0], data, hdr.PacketNumber, hdr.Raw)
packet, err := unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var frame *wire.StreamFrame var frame *wire.StreamFrame
for _, f := range packet.frames { r := bytes.NewReader(decrypted)
for {
f, err := wire.ParseNextFrame(r, hdr, version)
if err != nil {
return nil, err
}
var ok bool var ok bool
frame, ok = f.(*wire.StreamFrame) if frame, ok = f.(*wire.StreamFrame); ok || frame == nil {
if ok {
break break
} }
} }
if frame == nil { if frame == nil {
return nil, errors.New("Packet doesn't contain a STREAM_FRAME") return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
} }
if frame.StreamID != version.CryptoStreamID() {
return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID)
}
// We don't need a check for the stream ID here. // We don't need a check for the stream ID here.
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream. // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
if frame.Offset != 0 { if frame.Offset != 0 {
return nil, errors.New("received stream data with non-zero offset") return nil, errors.New("received stream data with non-zero offset")
} }
if utils.Debug() { if logger.Debug() {
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID) logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
hdr.Log() hdr.Log(logger)
wire.LogFrame(frame, false) wire.LogFrame(logger, frame, false)
} }
return frame, nil return frame, nil
} }
// packUnencryptedPacket provides a low-overhead way to pack a packet. // packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) {
raw := *getPacketBuffer() raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0]) buffer := bytes.NewBuffer(raw[:0])
if err := hdr.Write(buffer, pers, hdr.Version); err != nil { if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
...@@ -151,10 +159,10 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, per ...@@ -151,10 +159,10 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, per
raw = raw[0:buffer.Len()] raw = raw[0:buffer.Len()]
_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex]) _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+aead.Overhead()] raw = raw[0 : buffer.Len()+aead.Overhead()]
if utils.Debug() { if logger.Debug() {
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
hdr.Log() hdr.Log(logger)
wire.LogFrame(f, true) wire.LogFrame(logger, f, true)
} }
return raw, nil return raw, nil
} }
...@@ -9,4 +9,8 @@ package quic ...@@ -9,4 +9,8 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go" //go:generate sh -c "sed -i '' 's/quic_go.//g' mock_stream_getter_test.go mock_stream_manager_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker"
//go:generate sh -c "sed -i '' 's/quic_go.//g' mock_unpacker_test.go mock_unpacker_test.go"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
//go:generate sh -c "goimports -w mock*_test.go" //go:generate sh -c "goimports -w mock*_test.go"
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"sync" "sync"
"time"
"github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
...@@ -28,9 +29,16 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { ...@@ -28,9 +29,16 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
Frames: p.frames, Frames: p.frames,
Length: protocol.ByteCount(len(p.raw)), Length: protocol.ByteCount(len(p.raw)),
EncryptionLevel: p.encryptionLevel, EncryptionLevel: p.encryptionLevel,
SendTime: time.Now(),
} }
} }
type sealingManager interface {
GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
}
type streamFrameSource interface { type streamFrameSource interface {
HasCryptoStreamData() bool HasCryptoStreamData() bool
PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame
...@@ -41,7 +49,8 @@ type packetPacker struct { ...@@ -41,7 +49,8 @@ type packetPacker struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
cryptoSetup handshake.CryptoSetup divNonce []byte
cryptoSetup sealingManager
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen
...@@ -62,7 +71,8 @@ func newPacketPacker(connectionID protocol.ConnectionID, ...@@ -62,7 +71,8 @@ func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
remoteAddr net.Addr, // only used for determining the max packet size remoteAddr net.Addr, // only used for determining the max packet size
cryptoSetup handshake.CryptoSetup, divNonce []byte,
cryptoSetup sealingManager,
streamFramer streamFrameSource, streamFramer streamFrameSource,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
...@@ -82,6 +92,7 @@ func newPacketPacker(connectionID protocol.ConnectionID, ...@@ -82,6 +92,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
} }
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
divNonce: divNonce,
connectionID: connectionID, connectionID: connectionID,
perspective: perspective, perspective: perspective,
version: version, version: version,
...@@ -455,7 +466,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header ...@@ -455,7 +466,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
} }
if !p.version.UsesTLS() { if !p.version.UsesTLS() {
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
header.DiversificationNonce = p.cryptoSetup.DiversificationNonce() header.DiversificationNonce = p.divNonce
} }
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
header.VersionFlag = true header.VersionFlag = true
......
...@@ -2,7 +2,6 @@ package quic ...@@ -2,7 +2,6 @@ package quic
import ( import (
"bytes" "bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
...@@ -14,188 +13,115 @@ type unpackedPacket struct { ...@@ -14,188 +13,115 @@ type unpackedPacket struct {
frames []wire.Frame frames []wire.Frame
} }
type quicAEAD interface { type gQUICAEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
} }
type packetUnpacker struct { type quicAEAD interface {
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
type packetUnpackerBase struct {
version protocol.VersionNumber version protocol.VersionNumber
aead quicAEAD
} }
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { func (u *packetUnpackerBase) parseFrames(decrypted []byte, hdr *wire.Header) ([]wire.Frame, error) {
buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)
decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary)
if err != nil {
// Wrap err in quicError so that public reset is sent by session
return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
}
r := bytes.NewReader(decrypted) r := bytes.NewReader(decrypted)
if r.Len() == 0 { if r.Len() == 0 {
return nil, qerr.MissingPayload return nil, qerr.MissingPayload
} }
fs := make([]wire.Frame, 0, 2) fs := make([]wire.Frame, 0, 2)
// Read all frames in the packet // Read all frames in the packet
for r.Len() > 0 { for {
typeByte, _ := r.ReadByte() frame, err := wire.ParseNextFrame(r, hdr, u.version)
if typeByte == 0x0 { // PADDING frame
continue
}
r.UnreadByte()
frame, err := u.parseFrame(r, typeByte, hdr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if sf, ok := frame.(*wire.StreamFrame); ok { if frame == nil {
if sf.StreamID != u.version.CryptoStreamID() && encryptionLevel <= protocol.EncryptionUnencrypted { break
return nil, qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", sf.StreamID))
} }
}
if frame != nil {
fs = append(fs, frame) fs = append(fs, frame)
} }
} return fs, nil
}
return &unpackedPacket{ // The packetUnpackerGQUIC unpacks gQUIC packets.
encryptionLevel: encryptionLevel, type packetUnpackerGQUIC struct {
frames: fs, packetUnpackerBase
}, nil aead gQUICAEAD
} }
func (u *packetUnpacker) parseFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { var _ unpacker = &packetUnpackerGQUIC{}
if u.version.UsesIETFFrameFormat() {
return u.parseIETFFrame(r, typeByte, hdr) func newPacketUnpackerGQUIC(aead gQUICAEAD, version protocol.VersionNumber) unpacker {
return &packetUnpackerGQUIC{
packetUnpackerBase: packetUnpackerBase{version: version},
aead: aead,
} }
return u.parseGQUICFrame(r, typeByte, hdr)
} }
func (u *packetUnpacker) parseIETFFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { func (u *packetUnpackerGQUIC) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var frame wire.Frame decrypted, encryptionLevel, err := u.aead.Open(data[:0], data, hdr.PacketNumber, headerBinary)
var err error
if typeByte&0xf8 == 0x10 {
frame, err = wire.ParseStreamFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error())
}
return frame, err
}
// TODO: implement all IETF QUIC frame types
switch typeByte {
case 0x1:
frame, err = wire.ParseRstStreamFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = wire.ParseConnectionCloseFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x4:
frame, err = wire.ParseMaxDataFrame(r, u.version)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error()) // Wrap err in quicError so that public reset is sent by session
} return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
case 0x5:
frame, err = wire.ParseMaxStreamDataFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
frame, err = wire.ParseMaxStreamIDFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = wire.ParsePingFrame(r, u.version)
case 0x8:
frame, err = wire.ParseBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9:
frame, err = wire.ParseStreamBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0xa:
frame, err = wire.ParseStreamIDBlockedFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xc:
frame, err = wire.ParseStopSendingFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
} }
case 0xe:
frame, err = wire.ParseAckFrame(r, u.version) fs, err := u.parseFrames(decrypted, hdr)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error()) return nil, err
} }
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte)) return &unpackedPacket{
encryptionLevel: encryptionLevel,
frames: fs,
}, nil
}
// The packetUnpacker unpacks IETF QUIC packets.
type packetUnpacker struct {
packetUnpackerBase
aead quicAEAD
}
var _ unpacker = &packetUnpacker{}
func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
return &packetUnpacker{
packetUnpackerBase: packetUnpackerBase{version: version},
aead: aead,
} }
return frame, err
} }
func (u *packetUnpacker) parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *wire.Header) (wire.Frame, error) { func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var frame wire.Frame buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)
var decrypted []byte
var encryptionLevel protocol.EncryptionLevel
var err error var err error
if typeByte&0x80 == 0x80 { if hdr.IsLongHeader {
frame, err = wire.ParseStreamFrame(r, u.version) decrypted, err = u.aead.OpenHandshake(buf, data, hdr.PacketNumber, headerBinary)
if err != nil { encryptionLevel = protocol.EncryptionUnencrypted
err = qerr.Error(qerr.InvalidStreamData, err.Error()) } else {
} decrypted, err = u.aead.Open1RTT(buf, data, hdr.PacketNumber, headerBinary)
return frame, err encryptionLevel = protocol.EncryptionForwardSecure
} else if typeByte&0xc0 == 0x40 {
frame, err = wire.ParseAckFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
return frame, err
}
switch typeByte {
case 0x1:
frame, err = wire.ParseRstStreamFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = wire.ParseConnectionCloseFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x3:
frame, err = wire.ParseGoawayFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidGoawayData, err.Error())
}
case 0x4:
frame, err = wire.ParseWindowUpdateFrame(r, u.version)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
} }
case 0x5:
frame, err = wire.ParseBlockedFrameLegacy(r, u.version)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error()) // Wrap err in quicError so that public reset is sent by session
return nil, qerr.Error(qerr.DecryptionFailure, err.Error())
} }
case 0x6:
frame, err = wire.ParseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, u.version) fs, err := u.parseFrames(decrypted, hdr)
if err != nil { if err != nil {
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error()) return nil, err
}
case 0x7:
frame, err = wire.ParsePingFrame(r, u.version)
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
} }
return frame, err
return &unpackedPacket{
encryptionLevel: encryptionLevel,
frames: fs,
}, nil
} }
...@@ -2,8 +2,6 @@ package qerr ...@@ -2,8 +2,6 @@ package qerr
import ( import (
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// ErrorCode can be used as a normal error without reason. // ErrorCode can be used as a normal error without reason.
...@@ -51,6 +49,5 @@ func ToQuicError(err error) *QuicError { ...@@ -51,6 +49,5 @@ func ToQuicError(err error) *QuicError {
case ErrorCode: case ErrorCode:
return Error(e, "") return Error(e, "")
} }
utils.Errorf("Internal error: %v", err)
return Error(InternalError, err.Error()) return Error(InternalError, err.Error())
} }
...@@ -50,8 +50,10 @@ type server struct { ...@@ -50,8 +50,10 @@ type server struct {
errorChan chan struct{} errorChan chan struct{}
// set as members, so they can be set in the tests // set as members, so they can be set in the tests
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, error) newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config, logger utils.Logger) (packetHandler, error)
deleteClosedSessionsAfter time.Duration deleteClosedSessionsAfter time.Duration
logger utils.Logger
} }
var _ Listener = &server{} var _ Listener = &server{}
...@@ -110,6 +112,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, ...@@ -110,6 +112,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
sessionQueue: make(chan Session, 5), sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
supportsTLS: supportsTLS, supportsTLS: supportsTLS,
logger: utils.DefaultLogger,
} }
if supportsTLS { if supportsTLS {
if err := s.setupTLS(); err != nil { if err := s.setupTLS(); err != nil {
...@@ -117,16 +120,16 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, ...@@ -117,16 +120,16 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
} }
} }
go s.serve() go s.serve()
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil return s, nil
} }
func (s *server) setupTLS() error { func (s *server) setupTLS() error {
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie) cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger)
if err != nil { if err != nil {
return err return err
} }
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf) serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf, s.logger)
if err != nil { if err != nil {
return err return err
} }
...@@ -245,7 +248,7 @@ func (s *server) serve() { ...@@ -245,7 +248,7 @@ func (s *server) serve() {
} }
data = data[:n] data = data[:n]
if err := s.handlePacket(s.conn, remoteAddr, data); err != nil { if err := s.handlePacket(s.conn, remoteAddr, data); err != nil {
utils.Errorf("error handling packet: %s", err.Error()) s.logger.Errorf("error handling packet: %s", err.Error())
} }
} }
} }
...@@ -328,12 +331,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet ...@@ -328,12 +331,12 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
var pr *wire.PublicReset var pr *wire.PublicReset
pr, err = wire.ParsePublicReset(r) pr, err = wire.ParsePublicReset(r)
if err != nil { if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID) s.logger.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.", hdr.ConnectionID)
} else { } else {
utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber) s.logger.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.RejectedPacketNumber)
} }
} else { } else {
utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID) s.logger.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID)
} }
return nil return nil
} }
...@@ -360,7 +363,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet ...@@ -360,7 +363,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
if len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) { if len(packet) < protocol.MinClientHelloSize+len(hdr.Raw) {
return errors.New("dropping small packet with unknown version") return errors.New("dropping small packet with unknown version")
} }
utils.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
_, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr)
return err return err
} }
...@@ -377,7 +380,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet ...@@ -377,7 +380,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return errors.New("Server BUG: negotiated version not supported") return errors.New("Server BUG: negotiated version not supported")
} }
utils.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr) s.logger.Infof("Serving new connection: %x, version %s from %v", hdr.ConnectionID, version, remoteAddr)
session, err = s.newSession( session, err = s.newSession(
&conn{pconn: pconn, currentAddr: remoteAddr}, &conn{pconn: pconn, currentAddr: remoteAddr},
version, version,
...@@ -385,6 +388,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet ...@@ -385,6 +388,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.scfg, s.scfg,
s.tlsConf, s.tlsConf,
s.config, s.config,
s.logger,
) )
if err != nil { if err != nil {
return err return err
......
...@@ -21,9 +21,12 @@ type nullAEAD struct { ...@@ -21,9 +21,12 @@ type nullAEAD struct {
var _ quicAEAD = &nullAEAD{} var _ quicAEAD = &nullAEAD{}
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { func (n *nullAEAD) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
data, err := n.aead.Open(dst, src, packetNumber, associatedData) return n.aead.Open(dst, src, packetNumber, associatedData)
return data, protocol.EncryptionUnencrypted, err }
func (n *nullAEAD) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return nil, errors.New("no 1-RTT keys")
} }
type tlsSession struct { type tlsSession struct {
...@@ -40,6 +43,8 @@ type serverTLS struct { ...@@ -40,6 +43,8 @@ type serverTLS struct {
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
sessionChan chan<- tlsSession sessionChan chan<- tlsSession
logger utils.Logger
} }
func newServerTLS( func newServerTLS(
...@@ -47,6 +52,7 @@ func newServerTLS( ...@@ -47,6 +52,7 @@ func newServerTLS(
config *Config, config *Config,
cookieHandler *handshake.CookieHandler, cookieHandler *handshake.CookieHandler,
tlsConf *tls.Config, tlsConf *tls.Config,
logger utils.Logger,
) (*serverTLS, <-chan tlsSession, error) { ) (*serverTLS, <-chan tlsSession, error) {
mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
if err != nil { if err != nil {
...@@ -74,16 +80,17 @@ func newServerTLS( ...@@ -74,16 +80,17 @@ func newServerTLS(
MaxBidiStreams: uint16(config.MaxIncomingStreams), MaxBidiStreams: uint16(config.MaxIncomingStreams),
MaxUniStreams: uint16(config.MaxIncomingUniStreams), MaxUniStreams: uint16(config.MaxIncomingUniStreams),
}, },
logger: logger,
} }
s.newMintConn = s.newMintConnImpl s.newMintConn = s.newMintConnImpl
return s, sessionChan, nil return s, sessionChan, nil
} }
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) { func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
utils.Debugf("Received a Packet. Handling it statelessly.") s.logger.Debugf("Received a Packet. Handling it statelessly.")
sess, err := s.handleInitialImpl(remoteAddr, hdr, data) sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
if err != nil { if err != nil {
utils.Errorf("Error occurred handling initial packet: %s", err) s.logger.Errorf("Error occurred handling initial packet: %s", err)
return return
} }
if sess == nil { // a stateless reset was done if sess == nil { // a stateless reset was done
...@@ -97,7 +104,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data [] ...@@ -97,7 +104,7 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
// will be set to s.newMintConn by the constructor // will be set to s.newMintConn by the constructor
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) { func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v) extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v, s.logger)
conf := s.mintConf.Clone() conf := s.mintConf.Clone()
conf.ExtensionHandler = extHandler conf.ExtensionHandler = extHandler
return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil
...@@ -115,7 +122,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea ...@@ -115,7 +122,7 @@ func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Hea
PacketNumber: 1, // random packet number PacketNumber: 1, // random packet number
Version: clientHdr.Version, Version: clientHdr.Version,
} }
data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer) data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger)
if err != nil { if err != nil {
return err return err
} }
...@@ -129,7 +136,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat ...@@ -129,7 +136,7 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
} }
// check version, if not matching send VNP // check version, if not matching send VNP
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) { if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr) _, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.supportedVersions), remoteAddr)
return nil, err return nil, err
} }
...@@ -139,15 +146,15 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat ...@@ -139,15 +146,15 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat
if err != nil { if err != nil {
return nil, err return nil, err
} }
frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version) frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version)
if err != nil { if err != nil {
utils.Debugf("Error unpacking initial packet: %s", err) s.logger.Debugf("Error unpacking initial packet: %s", err)
return nil, nil return nil, nil
} }
sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
if err != nil { if err != nil {
if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
utils.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr) s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr)
} }
return nil, err return nil, err
} }
...@@ -177,7 +184,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, ...@@ -177,7 +184,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
StreamID: version.CryptoStreamID(), StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(), Data: bc.GetDataForWriting(),
} }
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer) data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -207,6 +214,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, ...@@ -207,6 +214,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
aead, aead,
&params, &params,
version, version,
s.logger,
) )
if err != nil { if err != nil {
return nil, err return nil, err
......
package quic
import "github.com/cheekybits/genny/generic"
// In the auto-generated streams maps, we need to be able to close the streams.
// Therefore, extend the generic.Type with the stream close method.
// This definition must be in a file that Genny doesn't process.
type item interface {
generic.Type
closeForShutdown(error)
}
...@@ -123,6 +123,9 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error { ...@@ -123,6 +123,9 @@ func (m *incomingBidiStreamsMap) DeleteStream(id protocol.StreamID) error {
func (m *incomingBidiStreamsMap) CloseWithError(err error) { func (m *incomingBidiStreamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.mutex.Unlock() m.mutex.Unlock()
m.cond.Broadcast() m.cond.Broadcast()
} }
...@@ -121,6 +121,9 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error { ...@@ -121,6 +121,9 @@ func (m *incomingItemsMap) DeleteStream(id protocol.StreamID) error {
func (m *incomingItemsMap) CloseWithError(err error) { func (m *incomingItemsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.mutex.Unlock() m.mutex.Unlock()
m.cond.Broadcast() m.cond.Broadcast()
} }
...@@ -123,6 +123,9 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error { ...@@ -123,6 +123,9 @@ func (m *incomingUniStreamsMap) DeleteStream(id protocol.StreamID) error {
func (m *incomingUniStreamsMap) CloseWithError(err error) { func (m *incomingUniStreamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.mutex.Unlock() m.mutex.Unlock()
m.cond.Broadcast() m.cond.Broadcast()
} }
...@@ -118,6 +118,9 @@ func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) { ...@@ -118,6 +118,9 @@ func (m *outgoingBidiStreamsMap) SetMaxStream(id protocol.StreamID) {
func (m *outgoingBidiStreamsMap) CloseWithError(err error) { func (m *outgoingBidiStreamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.cond.Broadcast() m.cond.Broadcast()
m.mutex.Unlock() m.mutex.Unlock()
} }
...@@ -4,14 +4,11 @@ import ( ...@@ -4,14 +4,11 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/cheekybits/genny/generic"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type item generic.Type
//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream" //go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream"
//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream" //go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream"
type outgoingItemsMap struct { type outgoingItemsMap struct {
...@@ -119,6 +116,9 @@ func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) { ...@@ -119,6 +116,9 @@ func (m *outgoingItemsMap) SetMaxStream(id protocol.StreamID) {
func (m *outgoingItemsMap) CloseWithError(err error) { func (m *outgoingItemsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.cond.Broadcast() m.cond.Broadcast()
m.mutex.Unlock() m.mutex.Unlock()
} }
...@@ -118,6 +118,9 @@ func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) { ...@@ -118,6 +118,9 @@ func (m *outgoingUniStreamsMap) SetMaxStream(id protocol.StreamID) {
func (m *outgoingUniStreamsMap) CloseWithError(err error) { func (m *outgoingUniStreamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
m.closeErr = err m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
m.cond.Broadcast() m.cond.Broadcast()
m.mutex.Unlock() m.mutex.Unlock()
} }
...@@ -265,7 +265,7 @@ func (c *Client) QueryRegistration() (*RegistrationResource, error) { ...@@ -265,7 +265,7 @@ func (c *Client) QueryRegistration() (*RegistrationResource, error) {
// your issued certificate as a bundle. // your issued certificate as a bundle.
// This function will never return a partial certificate. If one domain in the list fails, // This function will never return a partial certificate. If one domain in the list fails,
// the whole certificate will fail. // the whole certificate will fail.
func (c *Client) ObtainCertificateForCSR(csr x509.CertificateRequest, bundle bool) (CertificateResource, map[string]error) { func (c *Client) ObtainCertificateForCSR(csr x509.CertificateRequest, bundle bool) (CertificateResource, error) {
// figure out what domains it concerns // figure out what domains it concerns
// start with the common name // start with the common name
domains := []string{csr.Subject.CommonName} domains := []string{csr.Subject.CommonName}
...@@ -292,11 +292,7 @@ DNSNames: ...@@ -292,11 +292,7 @@ DNSNames:
order, err := c.createOrderForIdentifiers(domains) order, err := c.createOrderForIdentifiers(domains)
if err != nil { if err != nil {
identErrors := make(map[string]error) return CertificateResource{}, err
for _, auth := range order.Identifiers {
identErrors[auth.Value] = err
}
return CertificateResource{}, identErrors
} }
authz, failures := c.getAuthzForOrder(order) authz, failures := c.getAuthzForOrder(order)
// If any challenge fails - return. Do not generate partial SAN certificates. // If any challenge fails - return. Do not generate partial SAN certificates.
...@@ -338,7 +334,11 @@ DNSNames: ...@@ -338,7 +334,11 @@ DNSNames:
// your issued certificate as a bundle. // your issued certificate as a bundle.
// This function will never return a partial certificate. If one domain in the list fails, // This function will never return a partial certificate. If one domain in the list fails,
// the whole certificate will fail. // the whole certificate will fail.
func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (CertificateResource, map[string]error) { func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (CertificateResource, error) {
if len(domains) == 0 {
return CertificateResource{}, errors.New("Passed no domains into ObtainCertificate")
}
if bundle { if bundle {
logf("[INFO][%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", ")) logf("[INFO][%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", "))
} else { } else {
...@@ -347,11 +347,7 @@ func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto ...@@ -347,11 +347,7 @@ func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto
order, err := c.createOrderForIdentifiers(domains) order, err := c.createOrderForIdentifiers(domains)
if err != nil { if err != nil {
identErrors := make(map[string]error) return CertificateResource{}, err
for _, auth := range order.Identifiers {
identErrors[auth.Value] = err
}
return CertificateResource{}, identErrors
} }
authz, failures := c.getAuthzForOrder(order) authz, failures := c.getAuthzForOrder(order)
// If any challenge fails - return. Do not generate partial SAN certificates. // If any challenge fails - return. Do not generate partial SAN certificates.
...@@ -433,7 +429,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b ...@@ -433,7 +429,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b
return CertificateResource{}, err return CertificateResource{}, err
} }
newCert, failures := c.ObtainCertificateForCSR(*csr, bundle) newCert, failures := c.ObtainCertificateForCSR(*csr, bundle)
return newCert, failures[cert.Domain] return newCert, failures
} }
var privKey crypto.PrivateKey var privKey crypto.PrivateKey
...@@ -445,7 +441,6 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b ...@@ -445,7 +441,6 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b
} }
var domains []string var domains []string
var failures map[string]error
// check for SAN certificate // check for SAN certificate
if len(x509Cert.DNSNames) > 1 { if len(x509Cert.DNSNames) > 1 {
domains = append(domains, x509Cert.Subject.CommonName) domains = append(domains, x509Cert.Subject.CommonName)
...@@ -459,8 +454,8 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b ...@@ -459,8 +454,8 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple b
domains = append(domains, x509Cert.Subject.CommonName) domains = append(domains, x509Cert.Subject.CommonName)
} }
newCert, failures := c.ObtainCertificate(domains, bundle, privKey, mustStaple) newCert, err := c.ObtainCertificate(domains, bundle, privKey, mustStaple)
return newCert, failures[cert.Domain] return newCert, err
} }
func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, error) { func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, error) {
...@@ -482,6 +477,7 @@ func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, err ...@@ -482,6 +477,7 @@ func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, err
orderRes := orderResource{ orderRes := orderResource{
URL: hdr.Get("Location"), URL: hdr.Get("Location"),
Domains: domains,
orderMessage: response, orderMessage: response,
} }
return orderRes, nil return orderRes, nil
...@@ -489,9 +485,9 @@ func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, err ...@@ -489,9 +485,9 @@ func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, err
// Looks through the challenge combinations to find a solvable match. // Looks through the challenge combinations to find a solvable match.
// Then solves the challenges in series and returns. // Then solves the challenges in series and returns.
func (c *Client) solveChallengeForAuthz(authorizations []authorization) map[string]error { func (c *Client) solveChallengeForAuthz(authorizations []authorization) ObtainError {
// loop through the resources, basically through the domains. // loop through the resources, basically through the domains.
failures := make(map[string]error) failures := make(ObtainError)
for _, authz := range authorizations { for _, authz := range authorizations {
if authz.Status == "valid" { if authz.Status == "valid" {
// Boulder might recycle recent validated authz (see issue #267) // Boulder might recycle recent validated authz (see issue #267)
...@@ -527,7 +523,7 @@ func (c *Client) chooseSolver(auth authorization, domain string) (int, solver) { ...@@ -527,7 +523,7 @@ func (c *Client) chooseSolver(auth authorization, domain string) (int, solver) {
} }
// Get the challenges needed to proof our identifier to the ACME server. // Get the challenges needed to proof our identifier to the ACME server.
func (c *Client) getAuthzForOrder(order orderResource) ([]authorization, map[string]error) { func (c *Client) getAuthzForOrder(order orderResource) ([]authorization, ObtainError) {
resc, errc := make(chan authorization), make(chan domainError) resc, errc := make(chan authorization), make(chan domainError)
delay := time.Second / overallRequestLimit delay := time.Second / overallRequestLimit
...@@ -590,7 +586,7 @@ func (c *Client) requestCertificateForOrder(order orderResource, bundle bool, pr ...@@ -590,7 +586,7 @@ func (c *Client) requestCertificateForOrder(order orderResource, bundle bool, pr
} }
// determine certificate name(s) based on the authorization resources // determine certificate name(s) based on the authorization resources
commonName := order.Identifiers[0].Value commonName := order.Domains[0]
var san []string var san []string
for _, auth := range order.Identifiers { for _, auth := range order.Identifiers {
san = append(san, auth.Value) san = append(san, auth.Value)
...@@ -606,12 +602,7 @@ func (c *Client) requestCertificateForOrder(order orderResource, bundle bool, pr ...@@ -606,12 +602,7 @@ func (c *Client) requestCertificateForOrder(order orderResource, bundle bool, pr
} }
func (c *Client) requestCertificateForCsr(order orderResource, bundle bool, csr []byte, privateKeyPem []byte) (CertificateResource, error) { func (c *Client) requestCertificateForCsr(order orderResource, bundle bool, csr []byte, privateKeyPem []byte) (CertificateResource, error) {
commonName := order.Identifiers[0].Value commonName := order.Domains[0]
var authURLs []string
for _, auth := range order.Identifiers[1:] {
authURLs = append(authURLs, auth.Value)
}
csrString := base64.RawURLEncoding.EncodeToString(csr) csrString := base64.RawURLEncoding.EncodeToString(csr)
var retOrder orderMessage var retOrder orderMessage
......
package acme package acme
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -42,6 +43,18 @@ type domainError struct { ...@@ -42,6 +43,18 @@ type domainError struct {
Error error Error error
} }
// ObtainError is returned when there are specific errors available
// per domain. For example in ObtainCertificate
type ObtainError map[string]error
func (e ObtainError) Error() string {
buffer := bytes.NewBufferString("acme: Error -> One or more domains had a problem:\n")
for dom, err := range e {
buffer.WriteString(fmt.Sprintf("[%s] %s\n", dom, err))
}
return buffer.String()
}
func handleHTTPError(resp *http.Response) error { func handleHTTPError(resp *http.Response) error {
var errorDetail RemoteError var errorDetail RemoteError
......
...@@ -35,6 +35,7 @@ type accountMessage struct { ...@@ -35,6 +35,7 @@ type accountMessage struct {
type orderResource struct { type orderResource struct {
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
Domains []string `json:"domains,omitempty"`
orderMessage `json:"body,omitempty"` orderMessage `json:"body,omitempty"`
} }
......
...@@ -129,7 +129,7 @@ ...@@ -129,7 +129,7 @@
"importpath": "github.com/lucas-clemente/aes12", "importpath": "github.com/lucas-clemente/aes12",
"repository": "https://github.com/lucas-clemente/aes12", "repository": "https://github.com/lucas-clemente/aes12",
"vcs": "git", "vcs": "git",
"revision": "25700e67be5c860bcc999137275b9ef8b65932bd", "revision": "cd47fb39b79f867c6e4e5cd39cf7abd799f71670",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
...@@ -145,7 +145,7 @@ ...@@ -145,7 +145,7 @@
"importpath": "github.com/lucas-clemente/quic-go", "importpath": "github.com/lucas-clemente/quic-go",
"repository": "https://github.com/lucas-clemente/quic-go", "repository": "https://github.com/lucas-clemente/quic-go",
"vcs": "git", "vcs": "git",
"revision": "9fa739409e6edddbbd47c8031cb7bb3d1a209cc8", "revision": "da7708e47066ab0aff0f20f66b21c1f329db1eff",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
...@@ -193,7 +193,7 @@ ...@@ -193,7 +193,7 @@
"importpath": "github.com/xenolf/lego/acmev2", "importpath": "github.com/xenolf/lego/acmev2",
"repository": "https://github.com/xenolf/lego", "repository": "https://github.com/xenolf/lego",
"vcs": "git", "vcs": "git",
"revision": "805eec97569ff533e1b75b16eac0bdd94e67bdd6", "revision": "6e962fbfb37f9ea4a8201e32acb1b94ffb3b8398",
"branch": "acmev2", "branch": "acmev2",
"path": "/acmev2", "path": "/acmev2",
"notests": true "notests": true
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment