Commit 8a1dcaa6 authored by Erick Bajao's avatar Erick Bajao

Improve readability of content type override test

parent 68de0104
...@@ -177,35 +177,38 @@ func TestSetProperContentTypeAndDisposition(t *testing.T) { ...@@ -177,35 +177,38 @@ func TestSetProperContentTypeAndDisposition(t *testing.T) {
func TestFailOverrideContentType(t *testing.T) { func TestFailOverrideContentType(t *testing.T) {
testCases := []struct { testCases := []struct {
contentTypeOverride string desc string
contentType string overrideFromUpstream string
body string responseContentType string
body string
}{ }{
{ {
contentType: "text/plain; charset=utf-8", desc: "Force text/html into text/plain",
contentTypeOverride: "text/html; charset=utf-8", responseContentType: "text/plain; charset=utf-8",
body: "<html><body>Hello world!</body></html>", overrideFromUpstream: "text/html; charset=utf-8",
body: "<html><body>Hello world!</body></html>",
}, },
{ {
contentType: "text/plain; charset=utf-8", desc: "Force application/javascript into text/plain",
contentTypeOverride: "application/javascript; charset=utf-8", responseContentType: "text/plain; charset=utf-8",
body: "alert(1);", overrideFromUpstream: "application/javascript; charset=utf-8",
body: "alert(1);",
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.contentTypeOverride, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// We are pretending to be upstream or an inner layer of the ResponseWriter chain // We are pretending to be upstream or an inner layer of the ResponseWriter chain
w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true")
w.Header().Set(headers.ContentTypeHeader, tc.contentTypeOverride) w.Header().Set(headers.ContentTypeHeader, tc.overrideFromUpstream)
_, err := io.WriteString(w, tc.body) _, err := io.WriteString(w, tc.body)
require.NoError(t, err) require.NoError(t, err)
}) })
resp := makeRequest(t, h, tc.body, "") resp := makeRequest(t, h, tc.body, "")
require.Equal(t, tc.contentType, resp.Header.Get(headers.ContentTypeHeader)) require.Equal(t, tc.responseContentType, resp.Header.Get(headers.ContentTypeHeader))
}) })
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment