Commit b90ed477 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Re-use client.Dial from gitaly

parent 3aef5f9a
package gitaly
import (
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"gitlab.com/gitlab-org/gitaly/auth"
gitalyclient "gitlab.com/gitlab-org/gitaly/client"
"google.golang.org/grpc"
)
......@@ -46,6 +42,14 @@ func NewBlobClient(server Server) (*BlobClient, error) {
}
func getOrCreateConnection(server Server) (*grpc.ClientConn, error) {
cache.RLock()
conn := cache.connections[server]
cache.RUnlock()
if conn != nil {
return conn, nil
}
cache.Lock()
defer cache.Unlock()
......@@ -73,48 +77,9 @@ func CloseConnections() {
}
func newConnection(server Server) (*grpc.ClientConn, error) {
network, addr, err := parseAddress(server.Address)
if err != nil {
return nil, err
}
connOpts := []grpc.DialOption{
grpc.WithInsecure(), // Since we're connecting to Gitaly over UNIX, we don't need to use TLS credentials.
grpc.WithDialer(func(a string, _ time.Duration) (net.Conn, error) {
return net.Dial(network, a)
}),
connOpts := append(gitalyclient.DefaultDialOpts,
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentials(server.Token)),
}
conn, err := grpc.Dial(addr, connOpts...)
if err != nil {
return nil, err
}
)
return conn, nil
}
func parseAddress(rawAddress string) (network, addr string, err error) {
// Parsing unix:// URL's with url.Parse does not give the result we want
// so we do it manually.
for _, prefix := range []string{"unix://", "unix:"} {
if strings.HasPrefix(rawAddress, prefix) {
return "unix", strings.TrimPrefix(rawAddress, prefix), nil
}
}
u, err := url.Parse(rawAddress)
if err != nil {
return "", "", err
}
if u.Scheme != "tcp" {
return "", "", fmt.Errorf("unknown scheme: %q", rawAddress)
}
if u.Host == "" {
return "", "", fmt.Errorf("network tcp requires host: %q", rawAddress)
}
if u.Path != "" {
return "", "", fmt.Errorf("network tcp should have no path: %q", rawAddress)
}
return "tcp", u.Host, nil
return gitalyclient.Dial(server.Address, connOpts)
}
package gitaly
import (
"testing"
)
func TestParseAddress(t *testing.T) {
testCases := []struct {
raw string
network string
addr string
invalid bool
}{
{raw: "unix:/foo/bar.socket", network: "unix", addr: "/foo/bar.socket"},
{raw: "unix:///foo/bar.socket", network: "unix", addr: "/foo/bar.socket"},
// Mainly for test purposes we explicitly want to support relative paths
{raw: "unix://foo/bar.socket", network: "unix", addr: "foo/bar.socket"},
{raw: "unix:foo/bar.socket", network: "unix", addr: "foo/bar.socket"},
{raw: "tcp://1.2.3.4", network: "tcp", addr: "1.2.3.4"},
{raw: "tcp://1.2.3.4:567", network: "tcp", addr: "1.2.3.4:567"},
{raw: "tcp://foobar", network: "tcp", addr: "foobar"},
{raw: "tcp://foobar:567", network: "tcp", addr: "foobar:567"},
{raw: "tcp://1.2.3.4/foo/bar.socket", invalid: true},
{raw: "tcp:///foo/bar.socket", invalid: true},
{raw: "tcp:/foo/bar.socket", invalid: true},
}
for _, tc := range testCases {
network, addr, err := parseAddress(tc.raw)
if err == nil && tc.invalid {
t.Errorf("%v: expected error, got none", tc)
} else if err != nil && !tc.invalid {
t.Errorf("%v: parse error: %v", tc, err)
continue
}
if tc.invalid {
continue
}
if tc.network != network {
t.Errorf("%v: expected %q, got %q", tc, tc.network, network)
}
if tc.addr != addr {
t.Errorf("%v: expected %q, got %q", tc, tc.addr, addr)
}
}
}
package client
import (
"fmt"
"net"
"net/url"
"strings"
"time"
"google.golang.org/grpc"
)
// DefaultDialOpts hold the default DialOptions for connection to Gitaly over UNIX-socket
var DefaultDialOpts = []grpc.DialOption{
grpc.WithInsecure(),
}
// Dial gitaly
func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, error) {
network, addr, err := parseAddress(rawAddress)
if err != nil {
return nil, err
}
connOpts = append(connOpts,
grpc.WithDialer(func(a string, _ time.Duration) (net.Conn, error) {
return net.Dial(network, a)
}))
conn, err := grpc.Dial(addr, connOpts...)
if err != nil {
return nil, err
}
return conn, nil
}
func parseAddress(rawAddress string) (network, addr string, err error) {
// Parsing unix:// URL's with url.Parse does not give the result we want
// so we do it manually.
for _, prefix := range []string{"unix://", "unix:"} {
if strings.HasPrefix(rawAddress, prefix) {
return "unix", strings.TrimPrefix(rawAddress, prefix), nil
}
}
u, err := url.Parse(rawAddress)
if err != nil {
return "", "", err
}
if u.Scheme != "tcp" {
return "", "", fmt.Errorf("unknown scheme: %q", rawAddress)
}
if u.Host == "" {
return "", "", fmt.Errorf("network tcp requires host: %q", rawAddress)
}
if u.Path != "" {
return "", "", fmt.Errorf("network tcp should have no path: %q", rawAddress)
}
return "tcp", u.Host, nil
}
package client
import (
"io"
"gitlab.com/gitlab-org/gitaly/streamio"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// ReceivePack proxies an SSH git-receive-pack (git push) session to Gitaly
func ReceivePack(ctx context.Context, conn *grpc.ClientConn, stdin io.Reader, stdout, stderr io.Writer, req *pb.SSHReceivePackRequest) (int32, error) {
ctx2, cancel := context.WithCancel(ctx)
defer cancel()
ssh := pb.NewSSHServiceClient(conn)
stream, err := ssh.SSHReceivePack(ctx2)
if err != nil {
return 0, err
}
if err = stream.Send(req); err != nil {
return 0, err
}
inWriter := streamio.NewWriter(func(p []byte) error {
return stream.Send(&pb.SSHReceivePackRequest{Stdin: p})
})
return streamHandler(func() (stdoutStderrResponse, error) {
return stream.Recv()
}, func(errC chan error) {
_, errRecv := io.Copy(inWriter, stdin)
stream.CloseSend()
errC <- errRecv
}, stdout, stderr)
}
package client
import (
"fmt"
"io"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
)
type stdoutStderrResponse interface {
GetExitStatus() *pb.ExitStatus
GetStderr() []byte
GetStdout() []byte
}
func streamHandler(recv func() (stdoutStderrResponse, error), send func(chan error), stdout, stderr io.Writer) (int32, error) {
var (
exitStatus int32
err error
resp stdoutStderrResponse
)
errC := make(chan error, 1)
go send(errC)
for {
resp, err = recv()
if err != nil {
break
}
if resp.GetExitStatus() != nil {
exitStatus = resp.GetExitStatus().GetValue()
}
if len(resp.GetStderr()) > 0 {
if _, err = stderr.Write(resp.GetStderr()); err != nil {
break
}
}
if len(resp.GetStdout()) > 0 {
if _, err = stdout.Write(resp.GetStdout()); err != nil {
break
}
}
}
if err == io.EOF {
err = nil
}
if err != nil {
return exitStatus, err
}
select {
case errSend := <-errC:
if errSend != nil {
// This should not happen
errSend = fmt.Errorf("stdin send error: %v", errSend)
}
return exitStatus, errSend
default:
return exitStatus, nil
}
}
package client
import (
"io"
"gitlab.com/gitlab-org/gitaly/streamio"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// UploadPack proxies an SSH git-upload-pack (git fetch) session to Gitaly
func UploadPack(ctx context.Context, conn *grpc.ClientConn, stdin io.Reader, stdout, stderr io.Writer, req *pb.SSHUploadPackRequest) (int32, error) {
ctx2, cancel := context.WithCancel(ctx)
defer cancel()
ssh := pb.NewSSHServiceClient(conn)
stream, err := ssh.SSHUploadPack(ctx2)
if err != nil {
return 0, err
}
if err = stream.Send(req); err != nil {
return 0, err
}
inWriter := streamio.NewWriter(func(p []byte) error {
return stream.Send(&pb.SSHUploadPackRequest{Stdin: p})
})
return streamHandler(func() (stdoutStderrResponse, error) {
return stream.Recv()
}, func(errC chan error) {
_, errRecv := io.Copy(inWriter, stdin)
stream.CloseSend()
errC <- errRecv
}, stdout, stderr)
}
......@@ -161,6 +161,12 @@
"version": "v0.14.0",
"versionExact": "v0.14.0"
},
{
"checksumSHA1": "qlzYmQ21XX/voiKWHDWUZ3lybGQ=",
"path": "gitlab.com/gitlab-org/gitaly/client",
"revision": "871c07b758cf13892972089e1357c5a28bfc40d9",
"revisionTime": "2017-09-22T12:56:31Z"
},
{
"checksumSHA1": "sdUF3j5MaQ9Tjc2dGHqc/toQxyk=",
"path": "gitlab.com/gitlab-org/gitaly/streamio",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment