Commit 2019eec5 authored by Matthew Holt's avatar Matthew Holt

Fix lint warnings; group methods for same type together

parent 33d10339
...@@ -144,7 +144,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -144,7 +144,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
return rp return rp
} }
// InsecureTransport is used to facilitate HTTPS proxying // UseInsecureTransport is used to facilitate HTTPS proxying
// when it is OK for upstream to be using a bad certificate, // when it is OK for upstream to be using a bad certificate,
// since this transport skips verification. // since this transport skips verification.
func (rp *ReverseProxy) UseInsecureTransport() { func (rp *ReverseProxy) UseInsecureTransport() {
...@@ -163,6 +163,95 @@ func (rp *ReverseProxy) UseInsecureTransport() { ...@@ -163,6 +163,95 @@ func (rp *ReverseProxy) UseInsecureTransport() {
} }
} }
// ServeHTTP serves the proxied request to the upstream by performing a roundtrip.
// It is designed to handle websocket connection upgrades as well.
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := rp.Transport
if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport
}
rp.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
}
if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
res.Body.Close()
hj, ok := rw.(http.Hijacker)
if !ok {
return nil
}
conn, _, err := hj.Hijack()
if err != nil {
return err
}
defer conn.Close()
var backendConn net.Conn
if hj, ok := transport.(*connHijackerTransport); ok {
backendConn = hj.Conn
if _, err := conn.Write(hj.Replay); err != nil {
return err
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
if err != nil {
return err
}
outreq.Write(backendConn)
}
defer backendConn.Close()
go func() {
io.Copy(backendConn, conn) // write tcp stream to backend.
}()
io.Copy(conn, backendConn) // read tcp stream from backend.
} else {
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
rp.copyResponse(rw, res.Body)
}
return nil
}
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if rp.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: rp.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.CopyBuffer(dst, src, buf.([]byte))
}
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {
for k, vv := range src { for k, vv := range src {
for _, v := range vv { for _, v := range vv {
...@@ -255,93 +344,6 @@ func requestIsWebsocket(req *http.Request) bool { ...@@ -255,93 +344,6 @@ func requestIsWebsocket(req *http.Request) bool {
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport
if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport
}
p.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
res, err := transport.RoundTrip(outreq)
if err != nil {
return err
}
if respUpdateFn != nil {
respUpdateFn(res)
}
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
res.Body.Close()
hj, ok := rw.(http.Hijacker)
if !ok {
return nil
}
conn, _, err := hj.Hijack()
if err != nil {
return err
}
defer conn.Close()
var backendConn net.Conn
if hj, ok := transport.(*connHijackerTransport); ok {
backendConn = hj.Conn
if _, err := conn.Write(hj.Replay); err != nil {
return err
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
if err != nil {
return err
}
outreq.Write(backendConn)
}
defer backendConn.Close()
go func() {
io.Copy(backendConn, conn) // write tcp stream to backend.
}()
io.Copy(conn, backendConn) // read tcp stream from backend.
} else {
defer res.Body.Close()
for _, h := range hopHeaders {
res.Header.Del(h)
}
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
p.copyResponse(rw, res.Body)
}
return nil
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: p.FlushInterval,
done: make(chan bool),
}
go mlw.flushLoop()
defer mlw.stop()
dst = mlw
}
}
io.CopyBuffer(dst, src, buf.([]byte))
}
type writeFlusher interface { type writeFlusher interface {
io.Writer io.Writer
http.Flusher http.Flusher
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment