Commit cf40f383 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)

Merge branch 'accelerate' into 'master'

Accelerate multipart file uploads for all requests

Companion to https://gitlab.com/gitlab-org/gitlab-ce/merge_requests/5867

Extends the upload mechanism for CI artifacts to all multipart file uploads.

See merge request !58
parents d01ee210 9c5c1d5f
......@@ -10,6 +10,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/dgrijalva/jwt-go"
......@@ -32,7 +33,8 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur
t.Fatal(err)
}
parsedURL := helper.URLMustParse(ts.URL)
a := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), badgateway.TestRoundTripper(parsedURL))
testhelper.ConfigureSecret()
a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
response := httptest.NewRecorder()
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
......@@ -86,7 +88,8 @@ func TestPreAuthorizeJWT(t *testing.T) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
secretBytes, err := (&api.Secret{Path: testhelper.SecretPath()}).Bytes()
testhelper.ConfigureSecret()
secretBytes, err := secret.Bytes()
if err != nil {
return nil, fmt.Errorf("read secret from file: %v", err)
}
......
......@@ -11,8 +11,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"github.com/dgrijalva/jwt-go"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
)
// Custom content type for API responses, to catch routing / programming mistakes
......@@ -24,15 +23,13 @@ type API struct {
Client *http.Client
URL *url.URL
Version string
Secret *Secret
}
func NewAPI(myURL *url.URL, version, secretPath string, roundTripper *badgateway.RoundTripper) *API {
func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
return &API{
Client: &http.Client{Transport: roundTripper},
URL: myURL,
Version: version,
Secret: &Secret{Path: secretPath},
}
}
......@@ -130,13 +127,7 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt
// configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", api.Version)
secretBytes, err := api.Secret.Bytes()
if err != nil {
return nil, fmt.Errorf("newRequest: %v", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{Issuer: "gitlab-workhorse"})
tokenString, err := token.SignedString(secretBytes)
tokenString, err := secret.JWTTokenString(secret.DefaultClaims)
if err != nil {
return nil, fmt.Errorf("newRequest: sign JWT: %v", err)
}
......
......@@ -65,6 +65,10 @@ func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipa
return nil
}
func (a *artifactsUploadProcessor) Finalize() error {
return nil
}
func (a *artifactsUploadProcessor) Cleanup() {
if a.metadataFile != "" {
os.Remove(a.metadataFile)
......
......@@ -93,7 +93,8 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h
response := httptest.NewRecorder()
parsedURL := helper.URLMustParse(ts.URL)
roundTripper := badgateway.TestRoundTripper(parsedURL)
apiClient := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), roundTripper)
testhelper.ConfigureSecret()
apiClient := api.NewAPI(parsedURL, "123", roundTripper)
proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response
......
package secret
import (
"fmt"
"github.com/dgrijalva/jwt-go"
)
var (
DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"}
)
func JWTTokenString(claims jwt.Claims) (string, error) {
secretBytes, err := Bytes()
if err != nil {
return "", fmt.Errorf("secret.JWTTokenString: %v", err)
}
tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes)
if err != nil {
return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err)
}
return tokenString, nil
}
package api
package secret
import (
"encoding/base64"
......@@ -9,52 +9,69 @@ import (
const numSecretBytes = 32
type Secret struct {
Path string
type sec struct {
path string
bytes []byte
sync.RWMutex
}
var (
theSecret = &sec{}
)
func SetPath(path string) {
theSecret.Lock()
defer theSecret.Unlock()
theSecret.path = path
theSecret.bytes = nil
}
// Lazy access to the HMAC secret key. We must be lazy because if the key
// is not already there, it will be generated by gitlab-rails, and
// gitlab-rails is slow.
func (s *Secret) Bytes() ([]byte, error) {
if bytes := s.getBytes(); bytes != nil {
return bytes, nil
func Bytes() ([]byte, error) {
if bytes := getBytes(); bytes != nil {
return copyBytes(bytes), nil
}
return s.setBytes()
return setBytes()
}
func getBytes() []byte {
theSecret.RLock()
defer theSecret.RUnlock()
return theSecret.bytes
}
func (s *Secret) getBytes() []byte {
s.RLock()
defer s.RUnlock()
return s.bytes
func copyBytes(bytes []byte) []byte {
out := make([]byte, len(bytes))
copy(out, bytes)
return out
}
func (s *Secret) setBytes() ([]byte, error) {
s.Lock()
defer s.Unlock()
func setBytes() ([]byte, error) {
theSecret.Lock()
defer theSecret.Unlock()
if s.bytes != nil {
return s.bytes, nil
if theSecret.bytes != nil {
return theSecret.bytes, nil
}
base64Bytes, err := ioutil.ReadFile(s.Path)
base64Bytes, err := ioutil.ReadFile(theSecret.path)
if err != nil {
return nil, fmt.Errorf("read Secret.Path: %v", err)
return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err)
}
secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes)))
n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes)
if err != nil {
return nil, fmt.Errorf("decode secret: %v", err)
return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err)
}
if n != numSecretBytes {
return nil, fmt.Errorf("expected %d secretBytes in %s, found %d", numSecretBytes, s.Path, n)
return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n)
}
s.bytes = secretBytes
return s.bytes, nil
theSecret.bytes = secretBytes
return copyBytes(theSecret.bytes), nil
}
......@@ -16,10 +16,12 @@ import (
"runtime"
"strings"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
)
func SecretPath() string {
return path.Join(RootDir(), "testdata/test-secret")
func ConfigureSecret() {
secret.SetPath(path.Join(RootDir(), "testdata/test-secret"))
}
var extractPatchSeriesMatcher = regexp.MustCompile(`^From (\w+)`)
......
package upload
import (
"fmt"
"mime/multipart"
"net/http"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"github.com/dgrijalva/jwt-go"
)
const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields"
type savedFileTracker struct {
request *http.Request
rewrittenFields map[string]string
}
type MultipartClaims struct {
RewrittenFields map[string]string `json:"rewritten_fields"`
jwt.StandardClaims
}
func Accelerate(tempDir string, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s := &savedFileTracker{request: r}
HandleFileUploads(w, r, h, tempDir, s)
})
}
func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart.Writer) error {
if s.rewrittenFields == nil {
s.rewrittenFields = make(map[string]string)
}
s.rewrittenFields[fieldName] = fileName
return nil
}
func (_ *savedFileTracker) ProcessField(_ string, _ *multipart.Writer) error {
return nil
}
func (s *savedFileTracker) Finalize() error {
if s.rewrittenFields == nil {
return nil
}
claims := MultipartClaims{s.rewrittenFields, secret.DefaultClaims}
tokenString, err := secret.JWTTokenString(claims)
if err != nil {
return fmt.Errorf("savedFileTracker.Finalize: %v", err)
}
s.request.Header.Set(RewrittenFieldsHeader, tokenString)
return nil
}
......@@ -8,13 +8,17 @@ import (
"mime/multipart"
"net/http"
"os"
"path"
"strings"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
// These methods are allowed to have thread-unsafe implementations.
type MultipartFormProcessor interface {
ProcessFile(formName, fileName string, writer *multipart.Writer) error
ProcessField(formName string, writer *multipart.Writer) error
Finalize() error
}
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) {
......@@ -28,11 +32,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
return nil, fmt.Errorf("get multipart reader: %v", err)
}
var files []string
var directories []string
cleanup = func() {
for _, file := range files {
os.Remove(file)
for _, dir := range directories {
os.RemoveAll(dir)
}
}
......@@ -56,22 +60,30 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
// Copy form field
if filename := p.FileName(); filename != "" {
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return cleanup, fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, fmt.Errorf("mkdir for tempfile: %v", err)
}
// Create temporary file in path returned by Authorization filter
file, err := ioutil.TempFile(tempPath, "upload_")
tempDir, err := ioutil.TempDir(tempPath, "multipart-")
if err != nil {
return cleanup, fmt.Errorf("create tempdir: %v", err)
}
directories = append(directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return cleanup, fmt.Errorf("create tempfile: %v", err)
return cleanup, fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
writer.WriteField(name+".path", file.Name())
writer.WriteField(name+".name", filename)
files = append(files, file.Name())
_, err = io.Copy(file, p)
if err != nil {
......@@ -135,6 +147,11 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t
r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType())
if err := filter.Finalize(); err != nil {
helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err))
return
}
// Proxy the request
h.ServeHTTP(w, r)
}
......@@ -22,8 +22,7 @@ import (
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
type testFormProcessor struct {
}
type testFormProcessor struct{}
func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error {
if formName != "file" && fileName != "my.file" {
......@@ -39,6 +38,10 @@ func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writ
return nil
}
func (a *testFormProcessor) Finalize() error {
return nil
}
func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder()
request, err := http.NewRequest("", "", nil)
......@@ -214,6 +217,47 @@ func TestUploadProcessingFile(t *testing.T) {
testhelper.AssertResponseCode(t, response, 500)
}
func TestInvalidFileNames(t *testing.T) {
testhelper.ConfigureSecret()
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
for _, testCase := range []struct {
filename string
code int
}{
{"foobar", 200}, // sanity check for test setup below
{"foo/bar", 500},
{"/../../foobar", 500},
{".", 500},
{"..", 500},
} {
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
file, err := writer.CreateFormFile("file", testCase.filename)
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("POST", "/example", buffer)
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &savedFileTracker{request: httpRequest})
testhelper.AssertResponseCode(t, response, testCase.code)
}
}
func newProxy(url string) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
......
......@@ -2,6 +2,7 @@ package upstream
import (
"net/http"
"path"
"regexp"
"github.com/gorilla/websocket"
......@@ -17,6 +18,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
)
type matcherFunc func(*http.Request) bool
......@@ -91,7 +93,6 @@ func (u *Upstream) configureRoutes() {
api := apipkg.NewAPI(
u.Backend,
u.Version,
u.SecretPath,
u.RoundTripper,
)
static := &staticpages.Static{u.DocumentRoot}
......@@ -109,7 +110,9 @@ func (u *Upstream) configureRoutes() {
git.SendPatch,
artifacts.SendEntry,
)
ciAPIProxyQueue := queueing.QueueRequests(proxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
uploadAccelerateProxy := upload.Accelerate(path.Join(u.DocumentRoot, "uploads/tmp"), proxy)
ciAPIProxyQueue := queueing.QueueRequests(uploadAccelerateProxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
u.Routes = []routeEntry{
// Git Clone
......@@ -153,7 +156,7 @@ func (u *Upstream) configureRoutes() {
static.ServeExisting(
u.URLPrefix,
staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, proxy)),
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, uploadAccelerateProxy)),
),
),
}
......
......@@ -15,15 +15,20 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix"
)
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
var (
DefaultBackend = helper.URLMustParse("http://localhost:8080")
requestHeaderBlacklist = []string{
upload.RewrittenFieldsHeader,
}
)
type Config struct {
Backend *url.URL
Version string
SecretPath string
DocumentRoot string
DevelopmentMode bool
Socket string
......@@ -103,5 +108,9 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
return
}
for _, h := range requestHeaderBlacklist {
r.Header.Del(h)
}
route.handler.ServeHTTP(w, r)
}
......@@ -25,6 +25,7 @@ import (
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
"github.com/prometheus/client_golang/prometheus/promhttp"
......@@ -106,11 +107,11 @@ func main() {
}()
}
secret.SetPath(*secretPath)
upConfig := upstream.Config{
Backend: backendURL,
Socket: *authSocket,
Version: Version,
SecretPath: *secretPath,
DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode,
ProxyHeadersTimeout: *proxyHeadersTimeout,
......
......@@ -8,7 +8,6 @@ import (
"io"
"io/ioutil"
"log"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
......@@ -371,52 +370,6 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
}
}
func TestArtifactsUpload(t *testing.T) {
reqBody := &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
writer.Close()
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
var sendDataHeader = "Gitlab-Workhorse-Send-Data"
func sendDataResponder(command string, literalJSON string) *httptest.Server {
......@@ -691,10 +644,10 @@ func archiveOKServer(t *testing.T, archiveName string) *httptest.Server {
}
func startWorkhorseServer(authBackend string) *httptest.Server {
testhelper.ConfigureSecret()
config := upstream.Config{
Backend: helper.URLMustParse(authBackend),
Version: "123",
SecretPath: testhelper.SecretPath(),
DocumentRoot: testDocumentRoot,
}
......
package main
import (
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
"github.com/dgrijalva/jwt-go"
)
func TestArtifactsUpload(t *testing.T) {
reqBody, contentType, err := multipartBodyWithFile()
if err != nil {
t.Fatal(err)
}
ts := uploadTestServer(t, nil)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, contentType, reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.Server {
return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
if extraTests != nil {
extraTests(r)
}
w.WriteHeader(200)
})
}
func TestAcceleratedUpload(t *testing.T) {
reqBody, contentType, err := multipartBodyWithFile()
if err != nil {
t.Fatal(err)
}
ts := uploadTestServer(t, func(r *http.Request) {
jwtToken, err := jwt.Parse(r.Header.Get(upload.RewrittenFieldsHeader), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
testhelper.ConfigureSecret()
secretBytes, err := secret.Bytes()
if err != nil {
return nil, fmt.Errorf("read secret from file: %v", err)
}
return secretBytes, nil
})
if err != nil {
t.Fatal(err)
}
rewrittenFields := jwtToken.Claims.(jwt.MapClaims)["rewritten_fields"].(map[string]interface{})
if len(rewrittenFields) != 1 || len(rewrittenFields["file"].(string)) == 0 {
t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields)
}
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/example`
resp, err := http.Post(ws.URL+resource, contentType, reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func multipartBodyWithFile() (io.Reader, string, error) {
result := &bytes.Buffer{}
writer := multipart.NewWriter(result)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
return nil, "", err
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
return result, writer.FormDataContentType(), writer.Close()
}
func TestBlockingRewrittenFieldsHeader(t *testing.T) {
canary := "untrusted header passed by user"
testCases := []struct {
desc string
contentType string
body io.Reader
present bool
}{
{"multipart with file", "", nil, true}, // placeholder
{"no multipart", "text/plain", nil, false},
}
if b, c, err := multipartBodyWithFile(); err == nil {
testCases[0].contentType = c
testCases[0].body = b
} else {
t.Fatal(err)
}
for _, tc := range testCases {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
h := upload.RewrittenFieldsHeader
if _, ok := r.Header[h]; ok != tc.present {
t.Errorf("Expectation of presence (%v) violated", tc.present)
}
if r.Header.Get(h) == canary {
t.Errorf("Found canary %q in header %q", canary, h)
}
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
req, err := http.NewRequest("POST", ws.URL+"/something", tc.body)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", tc.contentType)
req.Header.Set(upload.RewrittenFieldsHeader, canary)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment