Commit 1913c7d9 authored by Nick Thomas's avatar Nick Thomas

Avoid setting multiple values for certain headers

parent b598d290
...@@ -138,14 +138,14 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string ...@@ -138,14 +138,14 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string
func setArchiveHeaders(w http.ResponseWriter, format string, archiveFilename string) { func setArchiveHeaders(w http.ResponseWriter, format string, archiveFilename string) {
w.Header().Del("Content-Length") w.Header().Del("Content-Length")
w.Header().Add("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename)) w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename))
if format == "zip" { if format == "zip" {
w.Header().Add("Content-Type", "application/zip") w.Header().Set("Content-Type", "application/zip")
} else { } else {
w.Header().Add("Content-Type", "application/octet-stream") w.Header().Set("Content-Type", "application/octet-stream")
} }
w.Header().Add("Content-Transfer-Encoding", "binary") w.Header().Set("Content-Transfer-Encoding", "binary")
w.Header().Add("Cache-Control", "private") w.Header().Set("Cache-Control", "private")
} }
func parseArchiveFormat(format string) (*exec.Cmd, string) { func parseArchiveFormat(format string) (*exec.Cmd, string) {
......
...@@ -2,7 +2,10 @@ package git ...@@ -2,7 +2,10 @@ package git
import ( import (
"io/ioutil" "io/ioutil"
"net/http/httptest"
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
) )
func TestParseBasename(t *testing.T) { func TestParseBasename(t *testing.T) {
...@@ -42,3 +45,27 @@ func TestFinalizeArchive(t *testing.T) { ...@@ -42,3 +45,27 @@ func TestFinalizeArchive(t *testing.T) {
t.Fatalf("expected nil from finalizeCachedArchive, received %v", err) t.Fatalf("expected nil from finalizeCachedArchive, received %v", err)
} }
} }
func TestSetArchiveHeaders(t *testing.T) {
for _, testCase := range []struct{ in, out string }{
{"zip", "application/zip"},
{"zippy", "application/octet-stream"},
{"rezip", "application/octet-stream"},
{"_anything_", "application/octet-stream"},
} {
w := httptest.NewRecorder()
// These should be replaced, not appended to
w.Header().Set("Content-Type", "test")
w.Header().Set("Content-Length", "test")
w.Header().Set("Content-Disposition", "test")
w.Header().Set("Cache-Control", "test")
setArchiveHeaders(w, testCase.in, "filename")
testhelper.AssertResponseHeader(t, w, "Content-Type", testCase.out)
testhelper.AssertResponseHeader(t, w, "Content-Length")
testhelper.AssertResponseHeader(t, w, "Content-Disposition", `attachment; filename="filename"`)
testhelper.AssertResponseHeader(t, w, "Cache-Control", "private")
}
}
...@@ -77,8 +77,8 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) ...@@ -77,8 +77,8 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response)
defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc))
w.Header().Add("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
helper.LogError(r, fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) helper.LogError(r, fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
...@@ -164,8 +164,8 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) { ...@@ -164,8 +164,8 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) {
r.Body.Close() r.Body.Close()
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", action)) w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", action))
w.Header().Add("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
// This io.Copy may take a long time, both for Git push and pull. // This io.Copy may take a long time, both for Git push and pull.
......
...@@ -116,7 +116,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -116,7 +116,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
} }
} else { } else {
testhelper.AssertResponseCode(t, w, 200) testhelper.AssertResponseCode(t, w, 200)
testhelper.AssertResponseHeader(t, w, "Content-Encoding", "") testhelper.AssertResponseHeader(t, w, "Content-Encoding")
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
} }
......
...@@ -63,9 +63,17 @@ func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expec ...@@ -63,9 +63,17 @@ func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expec
} }
} }
func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) { func AssertResponseHeader(t *testing.T, w http.ResponseWriter, header string, expected ...string) {
if response.Header().Get(header) != expectedValue { actual := w.Header()[http.CanonicalHeaderKey(header)]
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
if len(expected) != len(actual) {
t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual)
}
for i, value := range expected {
if value != actual[i] {
t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment