Commit 5e0bfe5b authored by Jacob Vosmaer's avatar Jacob Vosmaer

Assert that streams do not hang gitaly servers

parent b18f84f2
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path" "path"
"strings" "strings"
"testing" "testing"
"time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
...@@ -66,6 +69,41 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { ...@@ -66,6 +69,41 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) {
assert.Equal(t, expectedContent, body, "GET %q: response body", resource) assert.Equal(t, expectedContent, body, "GET %q: response body", resource)
} }
func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) {
apiResponse := gitOkBody(t)
gitalyServer, socketPath := startGitalyServer(t, codes.OK)
defer gitalyServer.Stop()
gitalyAddress := "unix://" + socketPath
apiResponse.GitalyAddress = gitalyAddress
ts := testAuthServer(nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := "/gitlab-org/gitlab-test.git/info/refs?service=git-upload-pack"
resp, err := http.Get(ws.URL + resource)
require.NoError(t, err)
// This causes the server stream to be interrupted instead of consumed entirely.
resp.Body.Close()
done := make(chan struct{})
go func() {
gitalyServer.WaitGroup.Wait()
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
}
func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) { func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
apiResponse := gitOkBody(t) apiResponse := gitOkBody(t)
...@@ -100,6 +138,45 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -100,6 +138,45 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) {
testhelper.AssertResponseHeader(t, resp, "Content-Type", "application/x-git-receive-pack-result") testhelper.AssertResponseHeader(t, resp, "Content-Type", "application/x-git-receive-pack-result")
} }
func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) {
apiResponse := gitOkBody(t)
gitalyServer, socketPath := startGitalyServer(t, codes.OK)
defer gitalyServer.Stop()
apiResponse.GitalyAddress = "unix://" + socketPath
ts := testAuthServer(nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := "/gitlab-org/gitlab-test.git/git-receive-pack"
resp, err := http.Post(
ws.URL+resource,
"application/x-git-receive-pack-request",
bytes.NewReader(testhelper.GitalyReceivePackResponseMock),
)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode, "POST %q", resource)
// This causes the server stream to be interrupted instead of consumed entirely.
resp.Body.Close()
done := make(chan struct{})
go func() {
gitalyServer.WaitGroup.Wait()
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
}
func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
apiResponse := gitOkBody(t) apiResponse := gitOkBody(t)
...@@ -137,6 +214,45 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { ...@@ -137,6 +214,45 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) {
} }
} }
func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) {
apiResponse := gitOkBody(t)
gitalyServer, socketPath := startGitalyServer(t, codes.OK)
defer gitalyServer.Stop()
apiResponse.GitalyAddress = "unix://" + socketPath
ts := testAuthServer(nil, 200, apiResponse)
defer ts.Close()
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
resource := "/gitlab-org/gitlab-test.git/git-upload-pack"
resp, err := http.Post(
ws.URL+resource,
"application/x-git-upload-pack-request",
bytes.NewReader(testhelper.GitalyUploadPackResponseMock),
)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode, "POST %q", resource)
// This causes the server stream to be interrupted instead of consumed entirely.
resp.Body.Close()
done := make(chan struct{})
go func() {
gitalyServer.WaitGroup.Wait()
close(done)
}()
select {
case <-done:
return
case <-time.After(10 * time.Second):
t.Fatal("time out waiting for gitaly handler to return")
}
}
func TestGetInfoRefsHandledLocallyDueToEmptyGitalySocketPath(t *testing.T) { func TestGetInfoRefsHandledLocallyDueToEmptyGitalySocketPath(t *testing.T) {
gitalyServer, _ := startGitalyServer(t, codes.OK) gitalyServer, _ := startGitalyServer(t, codes.OK)
defer gitalyServer.Stop() defer gitalyServer.Stop()
...@@ -199,7 +315,12 @@ func TestPostUploadPackHandledLocallyDueToEmptyGitalySocketPath(t *testing.T) { ...@@ -199,7 +315,12 @@ func TestPostUploadPackHandledLocallyDueToEmptyGitalySocketPath(t *testing.T) {
testhelper.AssertResponseHeader(t, resp, "Content-Type", "application/x-git-upload-pack-result") testhelper.AssertResponseHeader(t, resp, "Content-Type", "application/x-git-upload-pack-result")
} }
func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*grpc.Server, string) { type combinedServer struct {
*grpc.Server
*testhelper.GitalyTestServer
}
func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*combinedServer, string) {
socketPath := path.Join(scratchDir, fmt.Sprintf("gitaly-%d.sock", rand.Int())) socketPath := path.Join(scratchDir, fmt.Sprintf("gitaly-%d.sock", rand.Int()))
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
t.Fatal(err) t.Fatal(err)
...@@ -208,9 +329,10 @@ func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*grpc.Server, ...@@ -208,9 +329,10 @@ func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*grpc.Server,
listener, err := net.Listen("unix", socketPath) listener, err := net.Listen("unix", socketPath)
require.NoError(t, err) require.NoError(t, err)
pb.RegisterSmartHTTPServer(server, testhelper.NewGitalyServer(finalMessageCode)) gitalyServer := testhelper.NewGitalyServer(finalMessageCode)
pb.RegisterSmartHTTPServer(server, gitalyServer)
go server.Serve(listener) go server.Serve(listener)
return server, socketPath return &combinedServer{Server: server, GitalyTestServer: gitalyServer}, socketPath
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"log" "log"
"path" "path"
"strings" "strings"
"sync"
pb "gitlab.com/gitlab-org/gitaly-proto/go" pb "gitlab.com/gitlab-org/gitaly-proto/go"
...@@ -16,12 +17,14 @@ import ( ...@@ -16,12 +17,14 @@ import (
type GitalyTestServer struct { type GitalyTestServer struct {
finalMessageCode codes.Code finalMessageCode codes.Code
sync.WaitGroup
} }
const GitalyInfoRefsResponseMock = "Mock Gitaly InfoRefsResponse data" var (
GitalyInfoRefsResponseMock = strings.Repeat("Mock Gitaly InfoRefsResponse data", 100000)
var GitalyReceivePackResponseMock []byte GitalyReceivePackResponseMock []byte
var GitalyUploadPackResponseMock []byte GitalyUploadPackResponseMock []byte
)
func init() { func init() {
var err error var err error
...@@ -38,21 +41,30 @@ func NewGitalyServer(finalMessageCode codes.Code) *GitalyTestServer { ...@@ -38,21 +41,30 @@ func NewGitalyServer(finalMessageCode codes.Code) *GitalyTestServer {
} }
func (s *GitalyTestServer) InfoRefsUploadPack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsUploadPackServer) error { func (s *GitalyTestServer) InfoRefsUploadPack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsUploadPackServer) error {
s.WaitGroup.Add(1)
defer s.WaitGroup.Done()
if err := validateRepository(in.GetRepository()); err != nil { if err := validateRepository(in.GetRepository()); err != nil {
return err return err
} }
response := &pb.InfoRefsResponse{ nSends, err := sendBytes([]byte(GitalyInfoRefsResponseMock), 100, func(p []byte) error {
Data: []byte(GitalyInfoRefsResponseMock), return stream.Send(&pb.InfoRefsResponse{Data: p})
} })
if err := stream.Send(response); err != nil { if err != nil {
return err return err
} }
if nSends <= 1 {
panic("should have sent more than one message")
}
return s.finalError() return s.finalError()
} }
func (s *GitalyTestServer) InfoRefsReceivePack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsReceivePackServer) error { func (s *GitalyTestServer) InfoRefsReceivePack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsReceivePackServer) error {
s.WaitGroup.Add(1)
defer s.WaitGroup.Done()
if err := validateRepository(in.GetRepository()); err != nil { if err := validateRepository(in.GetRepository()); err != nil {
return err return err
} }
...@@ -68,6 +80,9 @@ func (s *GitalyTestServer) InfoRefsReceivePack(in *pb.InfoRefsRequest, stream pb ...@@ -68,6 +80,9 @@ func (s *GitalyTestServer) InfoRefsReceivePack(in *pb.InfoRefsRequest, stream pb
} }
func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackServer) error { func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackServer) error {
s.WaitGroup.Add(1)
defer s.WaitGroup.Done()
req, err := stream.Recv() req, err := stream.Recv()
if err != nil { if err != nil {
return err return err
...@@ -77,17 +92,13 @@ func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackSe ...@@ -77,17 +92,13 @@ func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackSe
if err := validateRepository(req.GetRepository()); err != nil { if err := validateRepository(req.GetRepository()); err != nil {
return err return err
} }
response := &pb.PostReceivePackResponse{
Data: []byte(strings.Join([]string{ data := []byte(strings.Join([]string{
repo.GetPath(), repo.GetPath(),
repo.GetStorageName(), repo.GetStorageName(),
repo.GetRelativePath(), repo.GetRelativePath(),
req.GlId, req.GlId,
}, "\000") + "\000"), }, "\000") + "\000")
}
if err := stream.Send(response); err != nil {
return err
}
// The body of the request starts in the second message // The body of the request starts in the second message
for { for {
...@@ -99,18 +110,25 @@ func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackSe ...@@ -99,18 +110,25 @@ func (s *GitalyTestServer) PostReceivePack(stream pb.SmartHTTP_PostReceivePackSe
break break
} }
response := &pb.PostReceivePackResponse{ // We want to echo the request data back
Data: req.GetData(), data = append(data, req.GetData()...)
} }
if err := stream.Send(response); err != nil {
return err nSends, err := sendBytes(data, 100, func(p []byte) error {
} return stream.Send(&pb.PostReceivePackResponse{Data: p})
})
if nSends <= 1 {
panic("should have sent more than one message")
} }
return s.finalError() return s.finalError()
} }
func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServer) error { func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServer) error {
s.WaitGroup.Add(1)
defer s.WaitGroup.Done()
req, err := stream.Recv() req, err := stream.Recv()
if err != nil { if err != nil {
return err return err
...@@ -120,16 +138,12 @@ func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServ ...@@ -120,16 +138,12 @@ func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServ
if err := validateRepository(req.GetRepository()); err != nil { if err := validateRepository(req.GetRepository()); err != nil {
return err return err
} }
response := &pb.PostUploadPackResponse{
Data: []byte(strings.Join([]string{ data := []byte(strings.Join([]string{
repo.GetPath(), repo.GetPath(),
repo.GetStorageName(), repo.GetStorageName(),
repo.GetRelativePath(), repo.GetRelativePath(),
}, "\000") + "\000"), }, "\000") + "\000")
}
if err := stream.Send(response); err != nil {
return err
}
// The body of the request starts in the second message // The body of the request starts in the second message
for { for {
...@@ -141,15 +155,36 @@ func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServ ...@@ -141,15 +155,36 @@ func (s *GitalyTestServer) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServ
break break
} }
response := &pb.PostUploadPackResponse{ data = append(data, req.GetData()...)
Data: req.GetData(), }
nSends, err := sendBytes(data, 100, func(p []byte) error {
return stream.Send(&pb.PostUploadPackResponse{Data: p})
})
if nSends <= 1 {
panic("should have sent more than one message")
}
return s.finalError()
}
// sendBytes returns the number of times the 'sender' function was called and an error.
func sendBytes(data []byte, chunkSize int, sender func([]byte) error) (int, error) {
i := 0
for ; len(data) > 0; i++ {
n := chunkSize
if n > len(data) {
n = len(data)
} }
if err := stream.Send(response); err != nil {
return err if err := sender(data[:n]); err != nil {
return i, err
} }
data = data[n:]
} }
return s.finalError() return i, nil
} }
func (s *GitalyTestServer) finalError() error { func (s *GitalyTestServer) finalError() error {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment