Commit 590862a9 authored by Tw's avatar Tw

replacer: capture request body normally

fix issue #1015
Signed-off-by: default avatarTw <tw19881113@gmail.com>
parent 40c09d67
...@@ -44,6 +44,36 @@ type replacer struct { ...@@ -44,6 +44,36 @@ type replacer struct {
emptyValue string emptyValue string
responseRecorder *ResponseRecorder responseRecorder *ResponseRecorder
request *http.Request request *http.Request
requestBody *limitWriter
}
type limitWriter struct {
w bytes.Buffer
remain int
}
func newLimitWriter(max int) *limitWriter {
return &limitWriter{
w: bytes.Buffer{},
remain: max,
}
}
func (lw *limitWriter) Write(p []byte) (int, error) {
// skip if we are full
if lw.remain <= 0 {
return len(p), nil
}
if n := len(p); n > lw.remain {
p = p[:lw.remain]
}
n, err := lw.w.Write(p)
lw.remain -= n
return n, err
}
func (lw *limitWriter) String() string {
return lw.w.String()
} }
// NewReplacer makes a new replacer based on r and rr which // NewReplacer makes a new replacer based on r and rr which
...@@ -54,8 +84,16 @@ type replacer struct { ...@@ -54,8 +84,16 @@ type replacer struct {
// emptyValue should be the string that is used in place // emptyValue should be the string that is used in place
// of empty string (can still be empty string). // of empty string (can still be empty string).
func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer { func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
rb := newLimitWriter(MaxLogBodySize)
if r.Body != nil {
r.Body = struct {
io.Reader
io.Closer
}{io.TeeReader(r.Body, rb), io.Closer(r.Body)}
}
rep := &replacer{ rep := &replacer{
request: r, request: r,
requestBody: rb,
responseRecorder: rr, responseRecorder: rr,
customReplacements: make(map[string]string), customReplacements: make(map[string]string),
emptyValue: emptyValue, emptyValue: emptyValue,
...@@ -81,27 +119,6 @@ func canLogRequest(r *http.Request) bool { ...@@ -81,27 +119,6 @@ func canLogRequest(r *http.Request) bool {
return false return false
} }
// readRequestBody reads the request body and sets a
// new io.ReadCloser that has not yet been read.
func readRequestBody(r *http.Request, n int64) ([]byte, error) {
defer r.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(r.Body, n))
if err != nil {
return nil, err
}
// Read the remaining bytes
remaining, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(append(body, remaining...))
r.Body = ioutil.NopCloser(buf)
return body, nil
}
// Replace performs a replacement of values on s and returns // Replace performs a replacement of values on s and returns
// the string with the replaced values. // the string with the replaced values.
func (r *replacer) Replace(s string) string { func (r *replacer) Replace(s string) string {
...@@ -249,11 +266,11 @@ func (r *replacer) getSubstitution(key string) string { ...@@ -249,11 +266,11 @@ func (r *replacer) getSubstitution(key string) string {
if !canLogRequest(r.request) { if !canLogRequest(r.request) {
return r.emptyValue return r.emptyValue
} }
body, err := readRequestBody(r.request, maxLogBodySize) _, err := ioutil.ReadAll(r.request.Body)
if err != nil { if err != nil {
return r.emptyValue return r.emptyValue
} }
return requestReplacer.Replace(string(body)) return requestReplacer.Replace(r.requestBody.String())
case "{status}": case "{status}":
if r.responseRecorder == nil { if r.responseRecorder == nil {
return r.emptyValue return r.emptyValue
......
package httpserver package httpserver
import ( import (
"bytes"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
...@@ -149,28 +147,3 @@ func TestRound(t *testing.T) { ...@@ -149,28 +147,3 @@ func TestRound(t *testing.T) {
} }
} }
} }
func TestReadRequestBody(t *testing.T) {
payload := []byte(`{ "foo": "bar" }`)
var readSize int64 = 5
r, err := http.NewRequest("POST", "/", bytes.NewReader(payload))
if err != nil {
t.Error(err)
}
defer r.Body.Close()
logBody, err := readRequestBody(r, readSize)
if err != nil {
t.Error("readRequestBody failed", err)
} else if !bytes.EqualFold(payload[0:readSize], logBody) {
t.Error("Expected log comparison failed")
}
// Ensure the Request body is the same as the original.
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error("Unable to read request body", err)
} else if !bytes.EqualFold(payload, reqBody) {
t.Error("Expected request body comparison failed")
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment