Commit cf40f383 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)

Merge branch 'accelerate' into 'master'

Accelerate multipart file uploads for all requests

Companion to https://gitlab.com/gitlab-org/gitlab-ce/merge_requests/5867

Extends the upload mechanism for CI artifacts to all multipart file uploads.

See merge request !58
parents d01ee210 9c5c1d5f
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
...@@ -32,7 +33,8 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur ...@@ -32,7 +33,8 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur
t.Fatal(err) t.Fatal(err)
} }
parsedURL := helper.URLMustParse(ts.URL) parsedURL := helper.URLMustParse(ts.URL)
a := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), badgateway.TestRoundTripper(parsedURL)) testhelper.ConfigureSecret()
a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
response := httptest.NewRecorder() response := httptest.NewRecorder()
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest) a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
...@@ -86,7 +88,8 @@ func TestPreAuthorizeJWT(t *testing.T) { ...@@ -86,7 +88,8 @@ func TestPreAuthorizeJWT(t *testing.T) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
} }
secretBytes, err := (&api.Secret{Path: testhelper.SecretPath()}).Bytes() testhelper.ConfigureSecret()
secretBytes, err := secret.Bytes()
if err != nil { if err != nil {
return nil, fmt.Errorf("read secret from file: %v", err) return nil, fmt.Errorf("read secret from file: %v", err)
} }
......
...@@ -11,8 +11,7 @@ import ( ...@@ -11,8 +11,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"github.com/dgrijalva/jwt-go"
) )
// Custom content type for API responses, to catch routing / programming mistakes // Custom content type for API responses, to catch routing / programming mistakes
...@@ -24,15 +23,13 @@ type API struct { ...@@ -24,15 +23,13 @@ type API struct {
Client *http.Client Client *http.Client
URL *url.URL URL *url.URL
Version string Version string
Secret *Secret
} }
func NewAPI(myURL *url.URL, version, secretPath string, roundTripper *badgateway.RoundTripper) *API { func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
return &API{ return &API{
Client: &http.Client{Transport: roundTripper}, Client: &http.Client{Transport: roundTripper},
URL: myURL, URL: myURL,
Version: version, Version: version,
Secret: &Secret{Path: secretPath},
} }
} }
...@@ -130,13 +127,7 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt ...@@ -130,13 +127,7 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt
// configurations (Passenger) to solve auth request routing problems. // configurations (Passenger) to solve auth request routing problems.
authReq.Header.Set("Gitlab-Workhorse", api.Version) authReq.Header.Set("Gitlab-Workhorse", api.Version)
secretBytes, err := api.Secret.Bytes() tokenString, err := secret.JWTTokenString(secret.DefaultClaims)
if err != nil {
return nil, fmt.Errorf("newRequest: %v", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{Issuer: "gitlab-workhorse"})
tokenString, err := token.SignedString(secretBytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("newRequest: sign JWT: %v", err) return nil, fmt.Errorf("newRequest: sign JWT: %v", err)
} }
......
...@@ -65,6 +65,10 @@ func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipa ...@@ -65,6 +65,10 @@ func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipa
return nil return nil
} }
func (a *artifactsUploadProcessor) Finalize() error {
return nil
}
func (a *artifactsUploadProcessor) Cleanup() { func (a *artifactsUploadProcessor) Cleanup() {
if a.metadataFile != "" { if a.metadataFile != "" {
os.Remove(a.metadataFile) os.Remove(a.metadataFile)
......
...@@ -93,7 +93,8 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h ...@@ -93,7 +93,8 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h
response := httptest.NewRecorder() response := httptest.NewRecorder()
parsedURL := helper.URLMustParse(ts.URL) parsedURL := helper.URLMustParse(ts.URL)
roundTripper := badgateway.TestRoundTripper(parsedURL) roundTripper := badgateway.TestRoundTripper(parsedURL)
apiClient := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), roundTripper) testhelper.ConfigureSecret()
apiClient := api.NewAPI(parsedURL, "123", roundTripper)
proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper) proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest) UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response return response
......
package secret
import (
"fmt"
"github.com/dgrijalva/jwt-go"
)
var (
DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"}
)
func JWTTokenString(claims jwt.Claims) (string, error) {
secretBytes, err := Bytes()
if err != nil {
return "", fmt.Errorf("secret.JWTTokenString: %v", err)
}
tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes)
if err != nil {
return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err)
}
return tokenString, nil
}
package api package secret
import ( import (
"encoding/base64" "encoding/base64"
...@@ -9,52 +9,69 @@ import ( ...@@ -9,52 +9,69 @@ import (
const numSecretBytes = 32 const numSecretBytes = 32
type Secret struct { type sec struct {
Path string path string
bytes []byte bytes []byte
sync.RWMutex sync.RWMutex
} }
var (
theSecret = &sec{}
)
func SetPath(path string) {
theSecret.Lock()
defer theSecret.Unlock()
theSecret.path = path
theSecret.bytes = nil
}
// Lazy access to the HMAC secret key. We must be lazy because if the key // Lazy access to the HMAC secret key. We must be lazy because if the key
// is not already there, it will be generated by gitlab-rails, and // is not already there, it will be generated by gitlab-rails, and
// gitlab-rails is slow. // gitlab-rails is slow.
func (s *Secret) Bytes() ([]byte, error) { func Bytes() ([]byte, error) {
if bytes := s.getBytes(); bytes != nil { if bytes := getBytes(); bytes != nil {
return bytes, nil return copyBytes(bytes), nil
} }
return s.setBytes() return setBytes()
}
func getBytes() []byte {
theSecret.RLock()
defer theSecret.RUnlock()
return theSecret.bytes
} }
func (s *Secret) getBytes() []byte { func copyBytes(bytes []byte) []byte {
s.RLock() out := make([]byte, len(bytes))
defer s.RUnlock() copy(out, bytes)
return s.bytes return out
} }
func (s *Secret) setBytes() ([]byte, error) { func setBytes() ([]byte, error) {
s.Lock() theSecret.Lock()
defer s.Unlock() defer theSecret.Unlock()
if s.bytes != nil { if theSecret.bytes != nil {
return s.bytes, nil return theSecret.bytes, nil
} }
base64Bytes, err := ioutil.ReadFile(s.Path) base64Bytes, err := ioutil.ReadFile(theSecret.path)
if err != nil { if err != nil {
return nil, fmt.Errorf("read Secret.Path: %v", err) return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err)
} }
secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes))) secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes)))
n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes) n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("decode secret: %v", err) return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err)
} }
if n != numSecretBytes { if n != numSecretBytes {
return nil, fmt.Errorf("expected %d secretBytes in %s, found %d", numSecretBytes, s.Path, n) return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n)
} }
s.bytes = secretBytes theSecret.bytes = secretBytes
return s.bytes, nil return copyBytes(theSecret.bytes), nil
} }
...@@ -16,10 +16,12 @@ import ( ...@@ -16,10 +16,12 @@ import (
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
) )
func SecretPath() string { func ConfigureSecret() {
return path.Join(RootDir(), "testdata/test-secret") secret.SetPath(path.Join(RootDir(), "testdata/test-secret"))
} }
var extractPatchSeriesMatcher = regexp.MustCompile(`^From (\w+)`) var extractPatchSeriesMatcher = regexp.MustCompile(`^From (\w+)`)
......
package upload
import (
"fmt"
"mime/multipart"
"net/http"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"github.com/dgrijalva/jwt-go"
)
const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields"
type savedFileTracker struct {
request *http.Request
rewrittenFields map[string]string
}
type MultipartClaims struct {
RewrittenFields map[string]string `json:"rewritten_fields"`
jwt.StandardClaims
}
func Accelerate(tempDir string, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s := &savedFileTracker{request: r}
HandleFileUploads(w, r, h, tempDir, s)
})
}
func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart.Writer) error {
if s.rewrittenFields == nil {
s.rewrittenFields = make(map[string]string)
}
s.rewrittenFields[fieldName] = fileName
return nil
}
func (_ *savedFileTracker) ProcessField(_ string, _ *multipart.Writer) error {
return nil
}
func (s *savedFileTracker) Finalize() error {
if s.rewrittenFields == nil {
return nil
}
claims := MultipartClaims{s.rewrittenFields, secret.DefaultClaims}
tokenString, err := secret.JWTTokenString(claims)
if err != nil {
return fmt.Errorf("savedFileTracker.Finalize: %v", err)
}
s.request.Header.Set(RewrittenFieldsHeader, tokenString)
return nil
}
...@@ -8,13 +8,17 @@ import ( ...@@ -8,13 +8,17 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os" "os"
"path"
"strings"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
) )
// These methods are allowed to have thread-unsafe implementations.
type MultipartFormProcessor interface { type MultipartFormProcessor interface {
ProcessFile(formName, fileName string, writer *multipart.Writer) error ProcessFile(formName, fileName string, writer *multipart.Writer) error
ProcessField(formName string, writer *multipart.Writer) error ProcessField(formName string, writer *multipart.Writer) error
Finalize() error
} }
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) { func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) {
...@@ -28,11 +32,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -28,11 +32,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
return nil, fmt.Errorf("get multipart reader: %v", err) return nil, fmt.Errorf("get multipart reader: %v", err)
} }
var files []string var directories []string
cleanup = func() { cleanup = func() {
for _, file := range files { for _, dir := range directories {
os.Remove(file) os.RemoveAll(dir)
} }
} }
...@@ -56,22 +60,30 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -56,22 +60,30 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
// Copy form field // Copy form field
if filename := p.FileName(); filename != "" { if filename := p.FileName(); filename != "" {
if strings.Contains(filename, "/") || filename == "." || filename == ".." {
return cleanup, fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored // Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(tempPath, 0700); err != nil { if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, fmt.Errorf("mkdir for tempfile: %v", err) return cleanup, fmt.Errorf("mkdir for tempfile: %v", err)
} }
// Create temporary file in path returned by Authorization filter tempDir, err := ioutil.TempDir(tempPath, "multipart-")
file, err := ioutil.TempFile(tempPath, "upload_") if err != nil {
return cleanup, fmt.Errorf("create tempdir: %v", err)
}
directories = append(directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil { if err != nil {
return cleanup, fmt.Errorf("create tempfile: %v", err) return cleanup, fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
} }
defer file.Close() defer file.Close()
// Add file entry // Add file entry
writer.WriteField(name+".path", file.Name()) writer.WriteField(name+".path", file.Name())
writer.WriteField(name+".name", filename) writer.WriteField(name+".name", filename)
files = append(files, file.Name())
_, err = io.Copy(file, p) _, err = io.Copy(file, p)
if err != nil { if err != nil {
...@@ -135,6 +147,11 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t ...@@ -135,6 +147,11 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t
r.ContentLength = int64(body.Len()) r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
if err := filter.Finalize(); err != nil {
helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err))
return
}
// Proxy the request // Proxy the request
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
} }
...@@ -22,8 +22,7 @@ import ( ...@@ -22,8 +22,7 @@ import (
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
type testFormProcessor struct { type testFormProcessor struct{}
}
func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error { func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error {
if formName != "file" && fileName != "my.file" { if formName != "file" && fileName != "my.file" {
...@@ -39,6 +38,10 @@ func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writ ...@@ -39,6 +38,10 @@ func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writ
return nil return nil
} }
func (a *testFormProcessor) Finalize() error {
return nil
}
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request, err := http.NewRequest("", "", nil) request, err := http.NewRequest("", "", nil)
...@@ -214,6 +217,47 @@ func TestUploadProcessingFile(t *testing.T) { ...@@ -214,6 +217,47 @@ func TestUploadProcessingFile(t *testing.T) {
testhelper.AssertResponseCode(t, response, 500) testhelper.AssertResponseCode(t, response, 500)
} }
func TestInvalidFileNames(t *testing.T) {
testhelper.ConfigureSecret()
tempPath, err := ioutil.TempDir("", "uploads")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempPath)
for _, testCase := range []struct {
filename string
code int
}{
{"foobar", 200}, // sanity check for test setup below
{"foo/bar", 500},
{"/../../foobar", 500},
{".", 500},
{"..", 500},
} {
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
file, err := writer.CreateFormFile("file", testCase.filename)
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "test")
writer.Close()
httpRequest, err := http.NewRequest("POST", "/example", buffer)
if err != nil {
t.Fatal(err)
}
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder()
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &savedFileTracker{request: httpRequest})
testhelper.AssertResponseCode(t, response, testCase.code)
}
}
func newProxy(url string) *proxy.Proxy { func newProxy(url string) *proxy.Proxy {
parsedURL := helper.URLMustParse(url) parsedURL := helper.URLMustParse(url)
return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL)) return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
......
...@@ -2,6 +2,7 @@ package upstream ...@@ -2,6 +2,7 @@ package upstream
import ( import (
"net/http" "net/http"
"path"
"regexp" "regexp"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -17,6 +18,7 @@ import ( ...@@ -17,6 +18,7 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile" "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages" "gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal" "gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
) )
type matcherFunc func(*http.Request) bool type matcherFunc func(*http.Request) bool
...@@ -91,7 +93,6 @@ func (u *Upstream) configureRoutes() { ...@@ -91,7 +93,6 @@ func (u *Upstream) configureRoutes() {
api := apipkg.NewAPI( api := apipkg.NewAPI(
u.Backend, u.Backend,
u.Version, u.Version,
u.SecretPath,
u.RoundTripper, u.RoundTripper,
) )
static := &staticpages.Static{u.DocumentRoot} static := &staticpages.Static{u.DocumentRoot}
...@@ -109,7 +110,9 @@ func (u *Upstream) configureRoutes() { ...@@ -109,7 +110,9 @@ func (u *Upstream) configureRoutes() {
git.SendPatch, git.SendPatch,
artifacts.SendEntry, artifacts.SendEntry,
) )
ciAPIProxyQueue := queueing.QueueRequests(proxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
uploadAccelerateProxy := upload.Accelerate(path.Join(u.DocumentRoot, "uploads/tmp"), proxy)
ciAPIProxyQueue := queueing.QueueRequests(uploadAccelerateProxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout)
u.Routes = []routeEntry{ u.Routes = []routeEntry{
// Git Clone // Git Clone
...@@ -153,7 +156,7 @@ func (u *Upstream) configureRoutes() { ...@@ -153,7 +156,7 @@ func (u *Upstream) configureRoutes() {
static.ServeExisting( static.ServeExisting(
u.URLPrefix, u.URLPrefix,
staticpages.CacheDisabled, staticpages.CacheDisabled,
static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, proxy)), static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, uploadAccelerateProxy)),
), ),
), ),
} }
......
...@@ -15,15 +15,20 @@ import ( ...@@ -15,15 +15,20 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix"
) )
var DefaultBackend = helper.URLMustParse("http://localhost:8080") var (
DefaultBackend = helper.URLMustParse("http://localhost:8080")
requestHeaderBlacklist = []string{
upload.RewrittenFieldsHeader,
}
)
type Config struct { type Config struct {
Backend *url.URL Backend *url.URL
Version string Version string
SecretPath string
DocumentRoot string DocumentRoot string
DevelopmentMode bool DevelopmentMode bool
Socket string Socket string
...@@ -103,5 +108,9 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -103,5 +108,9 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
return return
} }
for _, h := range requestHeaderBlacklist {
r.Header.Del(h)
}
route.handler.ServeHTTP(w, r) route.handler.ServeHTTP(w, r)
} }
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
"time" "time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream" "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
...@@ -106,11 +107,11 @@ func main() { ...@@ -106,11 +107,11 @@ func main() {
}() }()
} }
secret.SetPath(*secretPath)
upConfig := upstream.Config{ upConfig := upstream.Config{
Backend: backendURL, Backend: backendURL,
Socket: *authSocket, Socket: *authSocket,
Version: Version, Version: Version,
SecretPath: *secretPath,
DocumentRoot: *documentRoot, DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode, DevelopmentMode: *developmentMode,
ProxyHeadersTimeout: *proxyHeadersTimeout, ProxyHeadersTimeout: *proxyHeadersTimeout,
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
...@@ -371,52 +370,6 @@ func TestDeniedPublicUploadsFile(t *testing.T) { ...@@ -371,52 +370,6 @@ func TestDeniedPublicUploadsFile(t *testing.T) {
} }
} }
func TestArtifactsUpload(t *testing.T) {
reqBody := &bytes.Buffer{}
writer := multipart.NewWriter(reqBody)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
t.Fatal(err)
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
writer.Close()
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
w.WriteHeader(200)
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
var sendDataHeader = "Gitlab-Workhorse-Send-Data" var sendDataHeader = "Gitlab-Workhorse-Send-Data"
func sendDataResponder(command string, literalJSON string) *httptest.Server { func sendDataResponder(command string, literalJSON string) *httptest.Server {
...@@ -691,10 +644,10 @@ func archiveOKServer(t *testing.T, archiveName string) *httptest.Server { ...@@ -691,10 +644,10 @@ func archiveOKServer(t *testing.T, archiveName string) *httptest.Server {
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
testhelper.ConfigureSecret()
config := upstream.Config{ config := upstream.Config{
Backend: helper.URLMustParse(authBackend), Backend: helper.URLMustParse(authBackend),
Version: "123", Version: "123",
SecretPath: testhelper.SecretPath(),
DocumentRoot: testDocumentRoot, DocumentRoot: testDocumentRoot,
} }
......
package main
import (
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upload"
"github.com/dgrijalva/jwt-go"
)
func TestArtifactsUpload(t *testing.T) {
reqBody, contentType, err := multipartBodyWithFile()
if err != nil {
t.Fatal(err)
}
ts := uploadTestServer(t, nil)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/ci/api/v1/builds/123/artifacts`
resp, err := http.Post(ws.URL+resource, contentType, reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.Server {
return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/authorize") {
w.Header().Set("Content-Type", api.ResponseContentType)
if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil {
t.Fatal(err)
}
return
}
err := r.ParseMultipartForm(100000)
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
if len(r.MultipartForm.File) != 0 {
t.Error("Expected to not receive any files")
}
if extraTests != nil {
extraTests(r)
}
w.WriteHeader(200)
})
}
func TestAcceleratedUpload(t *testing.T) {
reqBody, contentType, err := multipartBodyWithFile()
if err != nil {
t.Fatal(err)
}
ts := uploadTestServer(t, func(r *http.Request) {
jwtToken, err := jwt.Parse(r.Header.Get(upload.RewrittenFieldsHeader), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
testhelper.ConfigureSecret()
secretBytes, err := secret.Bytes()
if err != nil {
return nil, fmt.Errorf("read secret from file: %v", err)
}
return secretBytes, nil
})
if err != nil {
t.Fatal(err)
}
rewrittenFields := jwtToken.Claims.(jwt.MapClaims)["rewritten_fields"].(map[string]interface{})
if len(rewrittenFields) != 1 || len(rewrittenFields["file"].(string)) == 0 {
t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields)
}
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := `/example`
resp, err := http.Post(ws.URL+resource, contentType, reqBody)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode)
}
}
func multipartBodyWithFile() (io.Reader, string, error) {
result := &bytes.Buffer{}
writer := multipart.NewWriter(result)
file, err := writer.CreateFormFile("file", "my.file")
if err != nil {
return nil, "", err
}
fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART")
return result, writer.FormDataContentType(), writer.Close()
}
func TestBlockingRewrittenFieldsHeader(t *testing.T) {
canary := "untrusted header passed by user"
testCases := []struct {
desc string
contentType string
body io.Reader
present bool
}{
{"multipart with file", "", nil, true}, // placeholder
{"no multipart", "text/plain", nil, false},
}
if b, c, err := multipartBodyWithFile(); err == nil {
testCases[0].contentType = c
testCases[0].body = b
} else {
t.Fatal(err)
}
for _, tc := range testCases {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
h := upload.RewrittenFieldsHeader
if _, ok := r.Header[h]; ok != tc.present {
t.Errorf("Expectation of presence (%v) violated", tc.present)
}
if r.Header.Get(h) == canary {
t.Errorf("Found canary %q in header %q", canary, h)
}
})
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
req, err := http.NewRequest("POST", ws.URL+"/something", tc.body)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", tc.contentType)
req.Header.Set(upload.RewrittenFieldsHeader, canary)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment