Commit 5d1752d1 authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Add support for Gitaly feature flags

parent efd1d567
...@@ -52,11 +52,11 @@ func realGitalyOkBody(t *testing.T) *api.Response { ...@@ -52,11 +52,11 @@ func realGitalyOkBody(t *testing.T) *api.Response {
} }
func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error { func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
namespace, err := gitaly.NewNamespaceClient(apiResponse.GitalyServer) ctx, namespace, err := gitaly.NewNamespaceClient(context.Background(), apiResponse.GitalyServer)
if err != nil { if err != nil {
return err return err
} }
repository, err := gitaly.NewRepositoryClient(apiResponse.GitalyServer) ctx, repository, err := gitaly.NewRepositoryClient(ctx, apiResponse.GitalyServer)
if err != nil { if err != nil {
return err return err
} }
...@@ -66,7 +66,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error { ...@@ -66,7 +66,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
StorageName: apiResponse.Repository.StorageName, StorageName: apiResponse.Repository.StorageName,
Name: apiResponse.Repository.RelativePath, Name: apiResponse.Repository.RelativePath,
} }
_, err = namespace.RemoveNamespace(context.Background(), rmNsReq) _, err = namespace.RemoveNamespace(ctx, rmNsReq)
if err != nil { if err != nil {
return err return err
} }
...@@ -76,7 +76,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error { ...@@ -76,7 +76,7 @@ func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
Url: "https://gitlab.com/gitlab-org/gitlab-test.git", Url: "https://gitlab.com/gitlab-org/gitlab-test.git",
} }
_, err = repository.CreateRepositoryFromURL(context.Background(), createReq) _, err = repository.CreateRepositoryFromURL(ctx, createReq)
return err return err
} }
......
...@@ -64,6 +64,23 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { ...@@ -64,6 +64,23 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
apiResponse := gitOkBody(t) apiResponse := gitOkBody(t)
apiResponse.GitalyServer.Address = gitalyAddress apiResponse.GitalyServer.Address = gitalyAddress
goodMetadata := map[string]string{
"gitaly-feature-foobar": "true",
"gitaly-feature-bazqux": "false",
}
badMetadata := map[string]string{
"bad-metadata": "is blocked",
}
features := make(map[string]string)
for k, v := range goodMetadata {
features[k] = v
}
for k, v := range badMetadata {
features[k] = v
}
apiResponse.GitalyServer.Features = features
testCases := []struct { testCases := []struct {
showAllRefs bool showAllRefs bool
gitRpc string gitRpc string
...@@ -106,6 +123,18 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { ...@@ -106,6 +123,18 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
require.Equal(t, tc.gitRpc, bodySplit[1]) require.Equal(t, tc.gitRpc, bodySplit[1])
require.Equal(t, string(testhelper.GitalyInfoRefsResponseMock), bodySplit[2], "GET %q: response body", resource) require.Equal(t, string(testhelper.GitalyInfoRefsResponseMock), bodySplit[2], "GET %q: response body", resource)
md := gitalyServer.LastIncomingMetadata
for k, v := range goodMetadata {
actual := md[k]
require.Len(t, actual, 1, "number of metadata values for %v", k)
require.Equal(t, v, actual[0], "value for %v", k)
}
for k := range badMetadata {
actual := md[k]
require.Empty(t, actual, "metadata for bad key %v", k)
}
}) })
} }
} }
......
...@@ -134,7 +134,7 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string ...@@ -134,7 +134,7 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string
func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gitalypb.GetArchiveRequest_Format) (io.Reader, error) { func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gitalypb.GetArchiveRequest_Format) (io.Reader, error) {
var request *gitalypb.GetArchiveRequest var request *gitalypb.GetArchiveRequest
c, err := gitaly.NewRepositoryClient(params.GitalyServer) ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -154,7 +154,7 @@ func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gital ...@@ -154,7 +154,7 @@ func handleArchiveWithGitaly(r *http.Request, params archiveParams, format gital
} }
} }
return c.ArchiveReader(r.Context(), request) return c.ArchiveReader(ctx, request)
} }
func setArchiveHeaders(w http.ResponseWriter, format gitalypb.GetArchiveRequest_Format, archiveFilename string) { func setArchiveHeaders(w http.ResponseWriter, format gitalypb.GetArchiveRequest_Format, archiveFilename string) {
......
...@@ -26,13 +26,13 @@ func (b *blob) Inject(w http.ResponseWriter, r *http.Request, sendData string) { ...@@ -26,13 +26,13 @@ func (b *blob) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
return return
} }
blobClient, err := gitaly.NewBlobClient(params.GitalyServer) ctx, blobClient, err := gitaly.NewBlobClient(r.Context(), params.GitalyServer)
if err != nil { if err != nil {
helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err)) helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
return return
} }
if err := blobClient.SendBlob(r.Context(), w, &params.GetBlobRequest); err != nil { if err := blobClient.SendBlob(ctx, w, &params.GetBlobRequest); err != nil {
helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err)) helper.Fail500(w, r, fmt.Errorf("blob.GetBlob: %v", err))
return return
} }
......
...@@ -32,13 +32,13 @@ func (d *diff) Inject(w http.ResponseWriter, r *http.Request, sendData string) { ...@@ -32,13 +32,13 @@ func (d *diff) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
return return
} }
diffClient, err := gitaly.NewDiffClient(params.GitalyServer) ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
if err != nil { if err != nil {
helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err)) helper.Fail500(w, r, fmt.Errorf("diff.RawDiff: %v", err))
return return
} }
if err := diffClient.SendRawDiff(r.Context(), w, request); err != nil { if err := diffClient.SendRawDiff(ctx, w, request); err != nil {
helper.LogError( helper.LogError(
r, r,
&copyError{fmt.Errorf("diff.RawDiff: request=%v, err=%v", request, err)}, &copyError{fmt.Errorf("diff.RawDiff: request=%v, err=%v", request, err)},
......
...@@ -32,13 +32,13 @@ func (p *patch) Inject(w http.ResponseWriter, r *http.Request, sendData string) ...@@ -32,13 +32,13 @@ func (p *patch) Inject(w http.ResponseWriter, r *http.Request, sendData string)
return return
} }
diffClient, err := gitaly.NewDiffClient(params.GitalyServer) ctx, diffClient, err := gitaly.NewDiffClient(r.Context(), params.GitalyServer)
if err != nil { if err != nil {
helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err)) helper.Fail500(w, r, fmt.Errorf("diff.RawPatch: %v", err))
return return
} }
if err := diffClient.SendRawPatch(r.Context(), w, request); err != nil { if err := diffClient.SendRawPatch(ctx, w, request); err != nil {
helper.LogError( helper.LogError(
r, r,
&copyError{fmt.Errorf("diff.RawPatch: request=%v, err=%v", request, err)}, &copyError{fmt.Errorf("diff.RawPatch: request=%v, err=%v", request, err)},
......
...@@ -46,7 +46,7 @@ func handleGetInfoRefs(rw http.ResponseWriter, r *http.Request, a *api.Response) ...@@ -46,7 +46,7 @@ func handleGetInfoRefs(rw http.ResponseWriter, r *http.Request, a *api.Response)
} }
func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpResponseWriter, a *api.Response, rpc, gitProtocol, encoding string) error { func handleGetInfoRefsWithGitaly(ctx context.Context, responseWriter *HttpResponseWriter, a *api.Response, rpc, gitProtocol, encoding string) error {
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer) ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
if err != nil { if err != nil {
return fmt.Errorf("GetInfoRefsHandler: %v", err) return fmt.Errorf("GetInfoRefsHandler: %v", err)
} }
......
...@@ -20,12 +20,12 @@ func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response) ...@@ -20,12 +20,12 @@ func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response)
gitProtocol := r.Header.Get("Git-Protocol") gitProtocol := r.Header.Get("Git-Protocol")
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer) ctx, smarthttp, err := gitaly.NewSmartHTTPClient(r.Context(), a.GitalyServer)
if err != nil { if err != nil {
return fmt.Errorf("smarthttp.ReceivePack: %v", err) return fmt.Errorf("smarthttp.ReceivePack: %v", err)
} }
if err := smarthttp.ReceivePack(r.Context(), &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil { if err := smarthttp.ReceivePack(ctx, &a.Repository, a.GL_ID, a.GL_USERNAME, a.GL_REPOSITORY, a.GitConfigOptions, cr, cw, gitProtocol); err != nil {
return fmt.Errorf("smarthttp.ReceivePack: %v", err) return fmt.Errorf("smarthttp.ReceivePack: %v", err)
} }
......
...@@ -39,13 +39,13 @@ func (s *snapshot) Inject(w http.ResponseWriter, r *http.Request, sendData strin ...@@ -39,13 +39,13 @@ func (s *snapshot) Inject(w http.ResponseWriter, r *http.Request, sendData strin
return return
} }
c, err := gitaly.NewRepositoryClient(params.GitalyServer) ctx, c, err := gitaly.NewRepositoryClient(r.Context(), params.GitalyServer)
if err != nil { if err != nil {
helper.Fail500(w, r, fmt.Errorf("SendSnapshot: gitaly.NewRepositoryClient: %v", err)) helper.Fail500(w, r, fmt.Errorf("SendSnapshot: gitaly.NewRepositoryClient: %v", err))
return return
} }
reader, err := c.SnapshotReader(r.Context(), request) reader, err := c.SnapshotReader(ctx, request)
if err != nil { if err != nil {
helper.Fail500(w, r, fmt.Errorf("SendSnapshot: client.SnapshotReader: %v", err)) helper.Fail500(w, r, fmt.Errorf("SendSnapshot: client.SnapshotReader: %v", err))
return return
......
...@@ -33,7 +33,7 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e ...@@ -33,7 +33,7 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e
} }
func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error { func handleUploadPackWithGitaly(ctx context.Context, a *api.Response, clientRequest io.Reader, clientResponse io.Writer, gitProtocol string) error {
smarthttp, err := gitaly.NewSmartHTTPClient(a.GitalyServer) ctx, smarthttp, err := gitaly.NewSmartHTTPClient(ctx, a.GitalyServer)
if err != nil { if err != nil {
return fmt.Errorf("smarthttp.UploadPack: %v", err) return fmt.Errorf("smarthttp.UploadPack: %v", err)
} }
......
package gitaly package gitaly
import ( import (
"context"
"strings" "strings"
"sync" "sync"
...@@ -13,6 +14,7 @@ import ( ...@@ -13,6 +14,7 @@ import (
gitalyclient "gitlab.com/gitlab-org/gitaly/client" gitalyclient "gitlab.com/gitlab-org/gitaly/client"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
...@@ -21,17 +23,24 @@ import ( ...@@ -21,17 +23,24 @@ import (
type Server struct { type Server struct {
Address string `json:"address"` Address string `json:"address"`
Token string `json:"token"` Token string `json:"token"`
Features map[string]string `json:"features"`
}
type cacheKey struct{ address, token string }
func (server Server) cacheKey() cacheKey {
return cacheKey{address: server.Address, token: server.Token}
} }
type connectionsCache struct { type connectionsCache struct {
sync.RWMutex sync.RWMutex
connections map[Server]*grpc.ClientConn connections map[cacheKey]*grpc.ClientConn
} }
var ( var (
jsonUnMarshaler = jsonpb.Unmarshaler{AllowUnknownFields: true} jsonUnMarshaler = jsonpb.Unmarshaler{AllowUnknownFields: true}
cache = connectionsCache{ cache = connectionsCache{
connections: make(map[Server]*grpc.ClientConn), connections: make(map[cacheKey]*grpc.ClientConn),
} }
connectionsTotal = prometheus.NewCounterVec( connectionsTotal = prometheus.NewCounterVec(
...@@ -47,55 +56,69 @@ func init() { ...@@ -47,55 +56,69 @@ func init() {
prometheus.MustRegister(connectionsTotal) prometheus.MustRegister(connectionsTotal)
} }
func NewSmartHTTPClient(server Server) (*SmartHTTPClient, error) { func withOutgoingMetadata(ctx context.Context, features map[string]string) context.Context {
md := metadata.New(nil)
for k, v := range features {
if !strings.HasPrefix(k, "gitaly-feature-") {
continue
}
md.Append(k, v)
}
return metadata.NewOutgoingContext(ctx, md)
}
func NewSmartHTTPClient(ctx context.Context, server Server) (context.Context, *SmartHTTPClient, error) {
conn, err := getOrCreateConnection(server) conn, err := getOrCreateConnection(server)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
grpcClient := gitalypb.NewSmartHTTPServiceClient(conn) grpcClient := gitalypb.NewSmartHTTPServiceClient(conn)
return &SmartHTTPClient{grpcClient}, nil return withOutgoingMetadata(ctx, server.Features), &SmartHTTPClient{grpcClient}, nil
} }
func NewBlobClient(server Server) (*BlobClient, error) { func NewBlobClient(ctx context.Context, server Server) (context.Context, *BlobClient, error) {
conn, err := getOrCreateConnection(server) conn, err := getOrCreateConnection(server)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
grpcClient := gitalypb.NewBlobServiceClient(conn) grpcClient := gitalypb.NewBlobServiceClient(conn)
return &BlobClient{grpcClient}, nil return withOutgoingMetadata(ctx, server.Features), &BlobClient{grpcClient}, nil
} }
func NewRepositoryClient(server Server) (*RepositoryClient, error) { func NewRepositoryClient(ctx context.Context, server Server) (context.Context, *RepositoryClient, error) {
conn, err := getOrCreateConnection(server) conn, err := getOrCreateConnection(server)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
grpcClient := gitalypb.NewRepositoryServiceClient(conn) grpcClient := gitalypb.NewRepositoryServiceClient(conn)
return &RepositoryClient{grpcClient}, nil return withOutgoingMetadata(ctx, server.Features), &RepositoryClient{grpcClient}, nil
} }
// NewNamespaceClient is only used by the Gitaly integration tests at present // NewNamespaceClient is only used by the Gitaly integration tests at present
func NewNamespaceClient(server Server) (*NamespaceClient, error) { func NewNamespaceClient(ctx context.Context, server Server) (context.Context, *NamespaceClient, error) {
conn, err := getOrCreateConnection(server) conn, err := getOrCreateConnection(server)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
grpcClient := gitalypb.NewNamespaceServiceClient(conn) grpcClient := gitalypb.NewNamespaceServiceClient(conn)
return &NamespaceClient{grpcClient}, nil return withOutgoingMetadata(ctx, server.Features), &NamespaceClient{grpcClient}, nil
} }
func NewDiffClient(server Server) (*DiffClient, error) { func NewDiffClient(ctx context.Context, server Server) (context.Context, *DiffClient, error) {
conn, err := getOrCreateConnection(server) conn, err := getOrCreateConnection(server)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
grpcClient := gitalypb.NewDiffServiceClient(conn) grpcClient := gitalypb.NewDiffServiceClient(conn)
return &DiffClient{grpcClient}, nil return withOutgoingMetadata(ctx, server.Features), &DiffClient{grpcClient}, nil
} }
func getOrCreateConnection(server Server) (*grpc.ClientConn, error) { func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
key := server.cacheKey()
cache.RLock() cache.RLock()
conn := cache.connections[server] conn := cache.connections[key]
cache.RUnlock() cache.RUnlock()
if conn != nil { if conn != nil {
...@@ -105,7 +128,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) { ...@@ -105,7 +128,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
cache.Lock() cache.Lock()
defer cache.Unlock() defer cache.Unlock()
if conn := cache.connections[server]; conn != nil { if conn := cache.connections[key]; conn != nil {
return conn, nil return conn, nil
} }
...@@ -114,7 +137,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) { ...@@ -114,7 +137,7 @@ func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
return nil, err return nil, err
} }
cache.connections[server] = conn cache.connections[key] = conn
return conn, nil return conn, nil
} }
......
package gitaly
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)
func TestNewSmartHTTPClient(t *testing.T) {
ctx, _, err := NewSmartHTTPClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewBlobClient(t *testing.T) {
ctx, _, err := NewBlobClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewRepositoryClient(t *testing.T) {
ctx, _, err := NewRepositoryClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewNamespaceClient(t *testing.T) {
ctx, _, err := NewNamespaceClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func TestNewDiffClient(t *testing.T) {
ctx, _, err := NewDiffClient(context.Background(), serverFixture())
require.NoError(t, err)
testOutgoingMetadata(t, ctx)
}
func testOutgoingMetadata(t *testing.T, ctx context.Context) {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok, "get metadata from context")
for k, v := range allowedFeatures() {
actual := md[k]
require.Len(t, actual, 1, "expect one value for %v", k)
require.Equal(t, v, actual[0], "value for %v", k)
}
for k := range badFeatureMetadata() {
require.Empty(t, md[k], "value for bad key %v", k)
}
}
func serverFixture() Server {
features := make(map[string]string)
for k, v := range allowedFeatures() {
features[k] = v
}
for k, v := range badFeatureMetadata() {
features[k] = v
}
return Server{Address: "tcp://localhost:123", Features: features}
}
func allowedFeatures() map[string]string {
return map[string]string{
"gitaly-feature-foo": "bar",
"gitaly-feature-qux": "baz",
}
}
func badFeatureMetadata() map[string]string {
return map[string]string{
"bad-metadata-1": "bad-value-1",
"bad-metadata-2": "bad-value-2",
}
}
...@@ -14,12 +14,14 @@ import ( ...@@ -14,12 +14,14 @@ import (
"gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/labkit/log"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
type GitalyTestServer struct { type GitalyTestServer struct {
finalMessageCode codes.Code finalMessageCode codes.Code
sync.WaitGroup sync.WaitGroup
LastIncomingMetadata metadata.MD
} }
var ( var (
...@@ -71,6 +73,11 @@ func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stre ...@@ -71,6 +73,11 @@ func (s *GitalyTestServer) InfoRefsUploadPack(in *gitalypb.InfoRefsRequest, stre
GitalyInfoRefsResponseMock, GitalyInfoRefsResponseMock,
}, "\000")) }, "\000"))
s.LastIncomingMetadata = nil
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
s.LastIncomingMetadata = md
}
return s.sendInfoRefs(stream, data) return s.sendInfoRefs(stream, data)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment