Commit a3fc1506 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Assert internal header removal

parent b6dcb7c8
......@@ -14,6 +14,8 @@ import (
"strings"
)
const sendDataHeader = "Gitlab-Workhorse-Send-Data"
type sendFileResponseWriter struct {
rw http.ResponseWriter
status int
......@@ -45,13 +47,6 @@ func (s *sendFileResponseWriter) Write(data []byte) (n int, err error) {
}
func (s *sendFileResponseWriter) WriteHeader(status int) {
// Never pass these headers to the client
defer func() {
s.Header().Del("X-Sendfile")
s.Header().Del("Gitlab-Workhorse-Send-Data")
}()
if s.status != 0 {
return
}
......@@ -63,6 +58,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
}
if file := s.Header().Get("X-Sendfile"); file != "" {
s.Header().Del("X-Sendfile")
// Mark this connection as hijacked
s.hijacked = true
......@@ -70,7 +66,8 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
sendFileFromDisk(s.rw, s.req, file)
return
}
if sendData := s.Header().Get("Gitlab-Workhorse-Send-Data"); strings.HasPrefix(sendData, git.SendBlobPrefix) {
if sendData := s.Header().Get(sendDataHeader); strings.HasPrefix(sendData, git.SendBlobPrefix) {
s.Header().Del(sendDataHeader)
s.hijacked = true
git.SendBlob(s.rw, s.req, sendData)
return
......
......@@ -558,10 +558,11 @@ func TestArtifactsGetSingleFile(t *testing.T) {
func TestGetGitBlob(t *testing.T) {
blobId := "50b27c6518be44c42c4d87966ae2481ce895624c" // the LICENSE file in the test repository
blobLength := 1075
headerKey := http.CanonicalHeaderKey("Gitlab-Workhorse-Send-Data")
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
responseJSON := fmt.Sprintf(`{"RepoPath":"%s","BlobId":"%s"}`, path.Join(testRepoRoot, testRepo), blobId)
encodedJSON := base64.StdEncoding.EncodeToString([]byte(responseJSON))
w.Header().Set("Gitlab-Workhorse-Send-Data", "git-blob:"+encodedJSON)
w.Header().Set(headerKey, "git-blob:"+encodedJSON)
// Prevent the Go HTTP server from setting the Content-Length to 0.
w.Header().Set("Transfer-Encoding", "chunked")
if _, err := fmt.Fprintf(w, "GNU General Public License"); err != nil {
......@@ -581,6 +582,9 @@ func TestGetGitBlob(t *testing.T) {
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resourcePath, resp.StatusCode)
}
if len(resp.Header[headerKey]) != 0 {
t.Fatalf("Unexpected response header: %s: %q", headerKey, resp.Header.Get(headerKey))
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment