Commit 1913c7d9 authored by Nick Thomas's avatar Nick Thomas

Avoid setting multiple values for certain headers

parent b598d290
......@@ -138,14 +138,14 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string
func setArchiveHeaders(w http.ResponseWriter, format string, archiveFilename string) {
w.Header().Del("Content-Length")
w.Header().Add("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename))
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename))
if format == "zip" {
w.Header().Add("Content-Type", "application/zip")
w.Header().Set("Content-Type", "application/zip")
} else {
w.Header().Add("Content-Type", "application/octet-stream")
w.Header().Set("Content-Type", "application/octet-stream")
}
w.Header().Add("Content-Transfer-Encoding", "binary")
w.Header().Add("Cache-Control", "private")
w.Header().Set("Content-Transfer-Encoding", "binary")
w.Header().Set("Cache-Control", "private")
}
func parseArchiveFormat(format string) (*exec.Cmd, string) {
......
......@@ -2,7 +2,10 @@ package git
import (
"io/ioutil"
"net/http/httptest"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
)
func TestParseBasename(t *testing.T) {
......@@ -42,3 +45,27 @@ func TestFinalizeArchive(t *testing.T) {
t.Fatalf("expected nil from finalizeCachedArchive, received %v", err)
}
}
func TestSetArchiveHeaders(t *testing.T) {
for _, testCase := range []struct{ in, out string }{
{"zip", "application/zip"},
{"zippy", "application/octet-stream"},
{"rezip", "application/octet-stream"},
{"_anything_", "application/octet-stream"},
} {
w := httptest.NewRecorder()
// These should be replaced, not appended to
w.Header().Set("Content-Type", "test")
w.Header().Set("Content-Length", "test")
w.Header().Set("Content-Disposition", "test")
w.Header().Set("Cache-Control", "test")
setArchiveHeaders(w, testCase.in, "filename")
testhelper.AssertResponseHeader(t, w, "Content-Type", testCase.out)
testhelper.AssertResponseHeader(t, w, "Content-Length")
testhelper.AssertResponseHeader(t, w, "Content-Disposition", `attachment; filename="filename"`)
testhelper.AssertResponseHeader(t, w, "Cache-Control", "private")
}
}
......@@ -77,8 +77,8 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response)
defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
w.Header().Add("Cache-Control", "no-cache")
w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
helper.LogError(r, fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
......@@ -164,8 +164,8 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
r.Body.Close()
// Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", action))
w.Header().Add("Cache-Control", "no-cache")
w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", action))
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
// This io.Copy may take a long time, both for Git push and pull.
......
......@@ -116,7 +116,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
}
} else {
testhelper.AssertResponseCode(t, w, 200)
testhelper.AssertResponseHeader(t, w, "Content-Encoding", "")
testhelper.AssertResponseHeader(t, w, "Content-Encoding")
if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String())
}
......
......@@ -63,9 +63,17 @@ func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expec
}
}
func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) {
if response.Header().Get(header) != expectedValue {
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
func AssertResponseHeader(t *testing.T, w http.ResponseWriter, header string, expected ...string) {
actual := w.Header()[http.CanonicalHeaderKey(header)]
if len(expected) != len(actual) {
t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual)
}
for i, value := range expected {
if value != actual[i] {
t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual)
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment