Commit c609e4fa authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab) Committed by Nick Thomas

Ban context.Background

parent bf1d9467
...@@ -30,6 +30,7 @@ ${BUILD_DIR}/_build: ...@@ -30,6 +30,7 @@ ${BUILD_DIR}/_build:
.PHONY: test .PHONY: test
test: clean-build clean-workhorse all govendor test: clean-build clean-workhorse all govendor
go fmt ${PKG_ALL} | awk '{ print } END { if (NR > 0) { print "Please run go fmt"; exit 1 } }' go fmt ${PKG_ALL} | awk '{ print } END { if (NR > 0) { print "Please run go fmt"; exit 1 } }'
_support/detect-context.sh
cd ${GOPATH}/src/${PKG} && govendor sync cd ${GOPATH}/src/${PKG} && govendor sync
go test ${PKG_ALL} go test ${PKG_ALL}
@echo SUCCESS @echo SUCCESS
......
#!/bin/sh
git grep 'context.\(Background\|TODO\)' | grep -v -e '^[^:]*_test.go:' -e '^vendor/' -e '^_support/' | awk '{
print "Found disallowed use of context.Background or TODO"
print
exit 1
}'
package artifacts package artifacts
import ( import (
"context"
"fmt" "fmt"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os" "os"
"time" "time"
"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
...@@ -60,7 +58,7 @@ func init() { ...@@ -60,7 +58,7 @@ func init() {
objectStorageUploadBytes) objectStorageUploadBytes)
} }
func (a *artifactsUploadProcessor) storeFile(formName, fileName string, writer *multipart.Writer) error { func (a *artifactsUploadProcessor) storeFile(ctx context.Context, formName, fileName string, writer *multipart.Writer) error {
if a.ObjectStore.StoreURL == "" { if a.ObjectStore.StoreURL == "" {
return nil return nil
} }
...@@ -104,10 +102,11 @@ func (a *artifactsUploadProcessor) storeFile(formName, fileName string, writer * ...@@ -104,10 +102,11 @@ func (a *artifactsUploadProcessor) storeFile(formName, fileName string, writer *
timeout = a.ObjectStore.Timeout timeout = a.ObjectStore.Timeout
} }
ctx, cancelFn := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) ctx2, cancelFn := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancelFn() defer cancelFn()
req = req.WithContext(ctx2)
resp, err := ctxhttp.Do(ctx, http.DefaultClient, req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
objectStorageUploadRequestsRequestFailed.Inc() objectStorageUploadRequestsRequestFailed.Inc()
return fmt.Errorf("PUT request %q: %v", a.ObjectStore.StoreURL, err) return fmt.Errorf("PUT request %q: %v", a.ObjectStore.StoreURL, err)
......
package artifacts package artifacts
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -44,7 +45,7 @@ func (a *artifactsUploadProcessor) generateMetadataFromZip(fileName string, meta ...@@ -44,7 +45,7 @@ func (a *artifactsUploadProcessor) generateMetadataFromZip(fileName string, meta
return true, nil return true, nil
} }
func (a *artifactsUploadProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error { func (a *artifactsUploadProcessor) ProcessFile(ctx context.Context, formName, fileName string, writer *multipart.Writer) error {
// ProcessFile for artifacts requires file form-data field name to eq `file` // ProcessFile for artifacts requires file form-data field name to eq `file`
if formName != "file" { if formName != "file" {
...@@ -74,17 +75,17 @@ func (a *artifactsUploadProcessor) ProcessFile(formName, fileName string, writer ...@@ -74,17 +75,17 @@ func (a *artifactsUploadProcessor) ProcessFile(formName, fileName string, writer
writer.WriteField("metadata.name", "metadata.gz") writer.WriteField("metadata.name", "metadata.gz")
} }
if err := a.storeFile(formName, fileName, writer); err != nil { if err := a.storeFile(ctx, formName, fileName, writer); err != nil {
return fmt.Errorf("storeFile: %v", err) return fmt.Errorf("storeFile: %v", err)
} }
return nil return nil
} }
func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipart.Writer) error { func (a *artifactsUploadProcessor) ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error {
return nil return nil
} }
func (a *artifactsUploadProcessor) Finalize() error { func (a *artifactsUploadProcessor) Finalize(ctx context.Context) error {
return nil return nil
} }
......
package upload package upload
import ( import (
"context"
"fmt" "fmt"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
"github.com/dgrijalva/jwt-go" jwt "github.com/dgrijalva/jwt-go"
) )
const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields" const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields"
...@@ -29,7 +30,7 @@ func Accelerate(tempDir string, h http.Handler) http.Handler { ...@@ -29,7 +30,7 @@ func Accelerate(tempDir string, h http.Handler) http.Handler {
}) })
} }
func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart.Writer) error { func (s *savedFileTracker) ProcessFile(_ context.Context, fieldName, fileName string, _ *multipart.Writer) error {
if s.rewrittenFields == nil { if s.rewrittenFields == nil {
s.rewrittenFields = make(map[string]string) s.rewrittenFields = make(map[string]string)
} }
...@@ -37,11 +38,11 @@ func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart. ...@@ -37,11 +38,11 @@ func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart.
return nil return nil
} }
func (_ *savedFileTracker) ProcessField(_ string, _ *multipart.Writer) error { func (_ *savedFileTracker) ProcessField(_ context.Context, _ string, _ *multipart.Writer) error {
return nil return nil
} }
func (s *savedFileTracker) Finalize() error { func (s *savedFileTracker) Finalize(_ context.Context) error {
if s.rewrittenFields == nil { if s.rewrittenFields == nil {
return nil return nil
} }
......
package upload package upload
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -101,10 +102,10 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -101,10 +102,10 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
// Copy form field // Copy form field
if p.FileName() != "" { if p.FileName() != "" {
err = rew.handleFilePart(name, p) err = rew.handleFilePart(r.Context(), name, p)
} else { } else {
err = rew.copyPart(name, p) err = rew.copyPart(r.Context(), name, p)
} }
if err != nil { if err != nil {
...@@ -115,7 +116,7 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -115,7 +116,7 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
return cleanup, nil return cleanup, nil
} }
func (rew *rewriter) handleFilePart(name string, p *multipart.Part) error { func (rew *rewriter) handleFilePart(ctx context.Context, name string, p *multipart.Part) error {
multipartFiles.WithLabelValues(rew.filter.Name()).Inc() multipartFiles.WithLabelValues(rew.filter.Name()).Inc()
filename := p.FileName() filename := p.FileName()
...@@ -153,14 +154,14 @@ func (rew *rewriter) handleFilePart(name string, p *multipart.Part) error { ...@@ -153,14 +154,14 @@ func (rew *rewriter) handleFilePart(name string, p *multipart.Part) error {
file.Close() file.Close()
if err := rew.filter.ProcessFile(name, file.Name(), rew.writer); err != nil { if err := rew.filter.ProcessFile(ctx, name, file.Name(), rew.writer); err != nil {
return err return err
} }
return nil return nil
} }
func (rew *rewriter) copyPart(name string, p *multipart.Part) error { func (rew *rewriter) copyPart(ctx context.Context, name string, p *multipart.Part) error {
np, err := rew.writer.CreatePart(p.Header) np, err := rew.writer.CreatePart(p.Header)
if err != nil { if err != nil {
return fmt.Errorf("create multipart field: %v", err) return fmt.Errorf("create multipart field: %v", err)
...@@ -170,7 +171,7 @@ func (rew *rewriter) copyPart(name string, p *multipart.Part) error { ...@@ -170,7 +171,7 @@ func (rew *rewriter) copyPart(name string, p *multipart.Part) error {
return fmt.Errorf("duplicate multipart field: %v", err) return fmt.Errorf("duplicate multipart field: %v", err)
} }
if err := rew.filter.ProcessField(name, rew.writer); err != nil { if err := rew.filter.ProcessField(ctx, name, rew.writer); err != nil {
return fmt.Errorf("process multipart field: %v", err) return fmt.Errorf("process multipart field: %v", err)
} }
......
...@@ -2,6 +2,7 @@ package upload ...@@ -2,6 +2,7 @@ package upload
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"mime/multipart" "mime/multipart"
...@@ -12,9 +13,9 @@ import ( ...@@ -12,9 +13,9 @@ import (
// These methods are allowed to have thread-unsafe implementations. // These methods are allowed to have thread-unsafe implementations.
type MultipartFormProcessor interface { type MultipartFormProcessor interface {
ProcessFile(formName, fileName string, writer *multipart.Writer) error ProcessFile(ctx context.Context, formName, fileName string, writer *multipart.Writer) error
ProcessField(formName string, writer *multipart.Writer) error ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error
Finalize() error Finalize(ctx context.Context) error
Name() string Name() string
} }
...@@ -51,7 +52,7 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t ...@@ -51,7 +52,7 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t
r.ContentLength = int64(body.Len()) r.ContentLength = int64(body.Len())
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
if err := filter.Finalize(); err != nil { if err := filter.Finalize(r.Context()); err != nil {
helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err)) helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err))
return return
} }
......
...@@ -2,6 +2,7 @@ package upload ...@@ -2,6 +2,7 @@ package upload
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -24,21 +25,21 @@ var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) ...@@ -24,21 +25,21 @@ var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
type testFormProcessor struct{} type testFormProcessor struct{}
func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error { func (a *testFormProcessor) ProcessFile(ctx context.Context, formName, fileName string, writer *multipart.Writer) error {
if formName != "file" && fileName != "my.file" { if formName != "file" && fileName != "my.file" {
return errors.New("illegal file") return errors.New("illegal file")
} }
return nil return nil
} }
func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writer) error { func (a *testFormProcessor) ProcessField(ctx context.Context, formName string, writer *multipart.Writer) error {
if formName != "token" { if formName != "token" {
return errors.New("illegal field") return errors.New("illegal field")
} }
return nil return nil
} }
func (a *testFormProcessor) Finalize() error { func (a *testFormProcessor) Finalize(ctx context.Context) error {
return nil return nil
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment