Commit 304988d8 authored by Alessio Caiazza's avatar Alessio Caiazza

Unify uploads handling under filestore package

parent e0f0ad10
package filestore
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
)
// FileHandler represent a file that has been processed for upload
// it may be either uploaded to an ObjectStore and/or saved on local path.
// Remote upload is not yet implemented
type FileHandler struct {
// LocalPath is the path on the disk where file has been stored
LocalPath string
// Size is the persisted file size
Size int64
// Name is the resource name to send back to GitLab rails.
// It differ from the real file name in order to avoid file collisions
Name string
// a map containing different hashes
hashes map[string]string
}
// SHA256 hash of the handled file
func (fh *FileHandler) SHA256() string {
return fh.hashes["sha256"]
}
// MD5 hash of the handled file
func (fh *FileHandler) MD5() string {
return fh.hashes["md5"]
}
// GitLabFinalizeFields returns a map with all the fields GitLab Rails needs in order to finalize the upload.
func (fh *FileHandler) GitLabFinalizeFields(prefix string) map[string]string {
data := make(map[string]string)
key := func(field string) string {
if prefix == "" {
return field
}
return fmt.Sprintf("%s.%s", prefix, field)
}
if fh.Name != "" {
data[key("name")] = fh.Name
}
if fh.LocalPath != "" {
data[key("path")] = fh.LocalPath
}
data[key("size")] = strconv.FormatInt(fh.Size, 10)
for hashName, hash := range fh.hashes {
data[key(hashName)] = hash
}
return data
}
// SaveFileFromReader persists the provided reader content to all the location specified in opts. A cleanup will be performed once ctx is Done
// Make sure the provided context will not expire before finalizing upload with GitLab Rails.
func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts *SaveFileOpts) (fh *FileHandler, err error) {
fh = &FileHandler{Name: opts.TempFilePrefix}
hashes := newMultiHash()
writers := []io.Writer{hashes.Writer}
defer func() {
for _, w := range writers {
if closer, ok := w.(io.WriteCloser); ok {
closer.Close()
}
}
}()
if opts.IsLocal() {
fileWriter, err := fh.uploadLocalFile(ctx, opts)
if err != nil {
return nil, err
}
writers = append(writers, fileWriter)
}
if len(writers) == 1 {
return nil, errors.New("Missing upload destination")
}
multiWriter := io.MultiWriter(writers...)
fh.Size, err = io.Copy(multiWriter, reader)
if err != nil {
return nil, err
}
if size != -1 && size != fh.Size {
return nil, fmt.Errorf("Expected %d bytes but got only %d", size, fh.Size)
}
fh.hashes = hashes.finish()
return fh, err
}
func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (io.WriteCloser, error) {
// make sure TempFolder exists
err := os.MkdirAll(opts.LocalTempPath, 0700)
if err != nil {
return nil, fmt.Errorf("uploadLocalFile: mkdir %q: %v", opts.LocalTempPath, err)
}
file, err := ioutil.TempFile(opts.LocalTempPath, opts.TempFilePrefix)
if err != nil {
return nil, fmt.Errorf("uploadLocalFile: create file: %v", err)
}
go func() {
<-ctx.Done()
os.Remove(file.Name())
}()
fh.LocalPath = file.Name()
return file, nil
}
package filestore_test
import (
"context"
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
)
// Some usefull const for testing purpose
const (
// testContent an example textual content
testContent = "TEST OBJECT CONTENT"
// testSize is the testContent size
testSize = int64(len(testContent))
// testMD5 is testContent MD5 hash
testMD5 = "42d000eea026ee0760677e506189cb33"
// testSHA256 is testContent SHA256 hash
testSHA256 = "b0257e9e657ef19b15eed4fbba975bd5238d651977564035ef91cb45693647aa"
)
func TestSaveFileFromReader(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
require.NoError(err)
defer os.RemoveAll(tmpFolder)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file"}
fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(testContent), testSize, opts)
assert.NoError(err)
require.NotNil(fh)
assert.NotEmpty(fh.LocalPath, "File hasn't been persisted on disk")
_, err = os.Stat(fh.LocalPath)
assert.NoError(err)
assert.Equal(testMD5, fh.MD5())
assert.Equal(testSHA256, fh.SHA256())
cancel()
time.Sleep(100 * time.Millisecond)
_, err = os.Stat(fh.LocalPath)
assert.Error(err)
assert.True(os.IsNotExist(err), "File hasn't been deleted during cleanup")
}
func TestSaveFileWrongSize(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tmpFolder, err := ioutil.TempDir("", "workhorse-test-tmp")
require.NoError(err)
defer os.RemoveAll(tmpFolder)
opts := &filestore.SaveFileOpts{LocalTempPath: tmpFolder, TempFilePrefix: "test-file"}
fh, err := filestore.SaveFileFromReader(ctx, strings.NewReader(testContent), testSize+1, opts)
assert.Error(err)
assert.EqualError(err, fmt.Sprintf("Expected %d bytes but got only %d", testSize+1, testSize))
assert.Nil(fh)
}
package filestore
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"hash"
"io"
)
var hashFactories = map[string](func() hash.Hash){
"md5": md5.New,
"sha1": sha1.New,
"sha256": sha256.New,
"sha512": sha512.New,
}
type multiHash struct {
io.Writer
hashes map[string]hash.Hash
}
func newMultiHash() (m *multiHash) {
m = &multiHash{}
m.hashes = make(map[string]hash.Hash)
var writers []io.Writer
for hash, hashFactory := range hashFactories {
writer := hashFactory()
m.hashes[hash] = writer
writers = append(writers, writer)
}
m.Writer = io.MultiWriter(writers...)
return m
}
func (m *multiHash) finish() map[string]string {
h := make(map[string]string)
for hashName, hash := range m.hashes {
checksum := hash.Sum(nil)
h[hashName] = hex.EncodeToString(checksum)
}
return h
}
package filestore
// SaveFileOpts represents all the options available for saving a file to object store
type SaveFileOpts struct {
// TempFilePrefix is the prefix used to create temporary local file
TempFilePrefix string
// LocalTempPath is the directory where to write a local copy of the file
LocalTempPath string
}
// IsLocal checks if the options require the writing of the file on disk
func (s *SaveFileOpts) IsLocal() bool {
return s.LocalTempPath != ""
}
......@@ -5,83 +5,58 @@ In this file we handle git lfs objects downloads and uploads
package lfs
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"net/url"
"path/filepath"
"strings"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)
func PutStore(a *api.API, h http.Handler) http.Handler {
return lfsAuthorizeHandler(a, handleStoreLfsObject(h))
return handleStoreLFSObject(a, h)
}
func lfsAuthorizeHandler(myAPI *api.API, handleFunc api.HandleFunc) http.Handler {
func handleStoreLFSObject(myAPI *api.API, h http.Handler) http.Handler {
return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
if a.StoreLFSPath == "" {
helper.Fail500(w, r, fmt.Errorf("lfsAuthorizeHandler: StoreLFSPath empty"))
return
}
if a.LfsOid == "" {
helper.Fail500(w, r, fmt.Errorf("lfsAuthorizeHandler: LfsOid empty"))
return
}
if err := os.MkdirAll(a.StoreLFSPath, 0700); err != nil {
helper.Fail500(w, r, fmt.Errorf("lfsAuthorizeHandler: mkdir StoreLFSPath: %v", err))
return
opts := &filestore.SaveFileOpts{
LocalTempPath: a.StoreLFSPath,
TempFilePrefix: a.LfsOid,
}
handleFunc(w, r, a)
}, "/authorize")
}
func handleStoreLfsObject(h http.Handler) api.HandleFunc {
return func(w http.ResponseWriter, r *http.Request, a *api.Response) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
fh, err := filestore.SaveFileFromReader(r.Context(), r.Body, r.ContentLength, opts)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
helper.Fail500(w, r, fmt.Errorf("handleStoreLFSObject: copy body to tempfile: %v", err))
return
}
defer os.Remove(file.Name())
defer file.Close()
hash := sha256.New()
hw := io.MultiWriter(hash, file)
written, err := io.Copy(hw, r.Body)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("handleStoreLfsObject: copy body to tempfile: %v", err))
if fh.Size != a.LfsSize {
helper.Fail500(w, r, fmt.Errorf("handleStoreLFSObject: expected size %d, wrote %d", a.LfsSize, fh.Size))
return
}
file.Close()
if written != a.LfsSize {
helper.Fail500(w, r, fmt.Errorf("handleStoreLfsObject: expected size %d, wrote %d", a.LfsSize, written))
if fh.SHA256() != a.LfsOid {
helper.Fail500(w, r, fmt.Errorf("handleStoreLFSObject: expected sha256 %s, got %s", a.LfsOid, fh.SHA256()))
return
}
shaStr := hex.EncodeToString(hash.Sum(nil))
if shaStr != a.LfsOid {
helper.Fail500(w, r, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", a.LfsOid, shaStr))
return
data := url.Values{}
for k, v := range fh.GitLabFinalizeFields("file") {
data.Set(k, v)
}
// Inject header and body
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name()))
r.Body = ioutil.NopCloser(&bytes.Buffer{})
r.ContentLength = 0
// Hijack body
body := data.Encode()
r.Body = ioutil.NopCloser(strings.NewReader(body))
r.ContentLength = int64(len(body))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(fh.LocalPath))
// And proxy the request
h.ServeHTTP(w, r)
}
}, "/authorize")
}
......@@ -4,14 +4,13 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"os"
"path"
"strings"
"github.com/prometheus/client_golang/prometheus"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/filestore"
)
var (
......@@ -42,10 +41,9 @@ var (
)
type rewriter struct {
writer *multipart.Writer
tempPath string
filter MultipartFormProcessor
directories []string
writer *multipart.Writer
tempPath string
filter MultipartFormProcessor
}
func init() {
......@@ -54,15 +52,15 @@ func init() {
prometheus.MustRegister(multipartFiles)
}
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) {
func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) error {
// Create multipart reader
reader, err := r.MultipartReader()
if err != nil {
if err == http.ErrNotMultipart {
// We want to be able to recognize http.ErrNotMultipart elsewhere so no fmt.Errorf
return nil, http.ErrNotMultipart
return http.ErrNotMultipart
}
return nil, fmt.Errorf("get multipart reader: %v", err)
return fmt.Errorf("get multipart reader: %v", err)
}
multipartUploadRequests.WithLabelValues(filter.Name()).Inc()
......@@ -73,26 +71,13 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
filter: filter,
}
cleanup = func() {
for _, dir := range rew.directories {
os.RemoveAll(dir)
}
}
// Execute cleanup in case of failure
defer func() {
if err != nil {
cleanup()
}
}()
for {
p, err := reader.NextPart()
if err != nil {
if err == io.EOF {
break
}
return cleanup, err
return err
}
name := p.FormName()
......@@ -109,11 +94,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
}
if err != nil {
return cleanup, err
return err
}
}
return cleanup, nil
return nil
}
func (rew *rewriter) handleFilePart(ctx context.Context, name string, p *multipart.Part) error {
......@@ -125,40 +110,23 @@ func (rew *rewriter) handleFilePart(ctx context.Context, name string, p *multipa
return fmt.Errorf("illegal filename: %q", filename)
}
// Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(rew.tempPath, 0700); err != nil {
return fmt.Errorf("mkdir for tempfile: %v", err)
opts := &filestore.SaveFileOpts{
LocalTempPath: rew.tempPath,
TempFilePrefix: filename,
}
tempDir, err := ioutil.TempDir(rew.tempPath, "multipart-")
fh, err := filestore.SaveFileFromReader(ctx, p, -1, opts)
if err != nil {
return fmt.Errorf("create tempdir: %v", err)
return fmt.Errorf("Persisting multipart file: %v", err)
}
rew.directories = append(rew.directories, tempDir)
file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err)
}
defer file.Close()
// Add file entry
rew.writer.WriteField(name+".path", file.Name())
rew.writer.WriteField(name+".name", filename)
written, err := io.Copy(file, p)
if err != nil {
return fmt.Errorf("copy from multipart to tempfile: %v", err)
for key, value := range fh.GitLabFinalizeFields(name) {
rew.writer.WriteField(key, value)
}
multipartFileUploadBytes.WithLabelValues(rew.filter.Name()).Add(float64(written))
file.Close()
if err := rew.filter.ProcessFile(ctx, name, file.Name(), rew.writer); err != nil {
return err
}
multipartFileUploadBytes.WithLabelValues(rew.filter.Name()).Add(float64(fh.Size))
return nil
return rew.filter.ProcessFile(ctx, name, fh.LocalPath, rew.writer)
}
func (rew *rewriter) copyPart(ctx context.Context, name string, p *multipart.Part) error {
......
......@@ -30,7 +30,7 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t
defer writer.Close()
// Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath, filter)
err := rewriteFormFilesFromMultipart(r, writer, tempPath, filter)
if err != nil {
if err == http.ErrNotMultipart {
h.ServeHTTP(w, r)
......@@ -40,10 +40,6 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t
return
}
if cleanup != nil {
defer cleanup()
}
// Close writer
writer.Close()
......
......@@ -14,6 +14,7 @@ import (
"regexp"
"strings"
"testing"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
......@@ -114,8 +115,8 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
t.Fatal(err)
}
if len(r.MultipartForm.Value) != 3 {
t.Fatal("Expected to receive exactly 3 values")
if len(r.MultipartForm.Value) != 8 {
t.Fatal("Expected to receive exactly 8 values")
}
if len(r.MultipartForm.File) != 0 {
......@@ -136,6 +137,23 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
t.Fatal("Expected to the file to be in tempPath")
}
if r.FormValue("file.size") != "4" {
t.Fatal("Expected to receive the file size")
}
hashes := map[string]string{
"md5": "098f6bcd4621d373cade4e832627b4f6",
"sha1": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3",
"sha256": "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
"sha512": "ee26b0dd4af7e749aa1a8ee3c10ae9923f618980772e473f8819a5d4940e0db27ac185f8a0e1d5f84f88bc887fd67b143732c304cc5fa9ad8e6f57f50028a8ff",
}
for algo, hash := range hashes {
if r.FormValue("file."+algo) != hash {
t.Fatalf("Expected to receive file %s hash", algo)
}
}
w.WriteHeader(202)
fmt.Fprint(w, "RESPONSE")
})
......@@ -156,6 +174,8 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
httpRequest = httpRequest.WithContext(ctx)
httpRequest.Body = ioutil.NopCloser(&buffer)
httpRequest.ContentLength = int64(buffer.Len())
httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
......@@ -165,7 +185,18 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) {
cancel() // this will trigger an async cleanup
// Poll because the file removal is async
for i := 0; i < 100; i++ {
_, err = os.Stat(filePath)
if err != nil {
break
}
time.Sleep(100 * time.Millisecond)
}
if !os.IsNotExist(err) {
t.Fatal("expected the file to be deleted")
}
}
......
......@@ -70,7 +70,7 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.
if err != nil {
t.Fatal(err)
}
nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file)
nValues := 7 // file name, path, size, md5, sha1, sha256, sha512 for just the upload (no metadata because we are not POSTing a valid zip file)
if len(r.MultipartForm.Value) != nValues {
t.Errorf("Expected to receive exactly %d values", nValues)
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment