Commit ac865e89 authored by Mohammad Gufran's avatar Mohammad Gufran Committed by Matt Holt

fastcgi: Add support for SRV upstreams (#1870)

parent b7167803
...@@ -20,6 +20,7 @@ package fastcgi ...@@ -20,6 +20,7 @@ package fastcgi
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
...@@ -107,7 +108,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -107,7 +108,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
// Connect to FastCGI gateway // Connect to FastCGI gateway
network, address := parseAddress(rule.Address()) address, err := rule.Address()
if err != nil {
return http.StatusBadGateway, err
}
network, address := parseAddress(address)
ctx := context.Background() ctx := context.Background()
if rule.ConnectTimeout > 0 { if rule.ConnectTimeout > 0 {
...@@ -381,7 +386,7 @@ type Rule struct { ...@@ -381,7 +386,7 @@ type Rule struct {
type balancer interface { type balancer interface {
// Address picks an upstream address from the // Address picks an upstream address from the
// underlying load balancer. // underlying load balancer.
Address() string Address() (string, error)
} }
// roundRobin is a round robin balancer for fastcgi upstreams. // roundRobin is a round robin balancer for fastcgi upstreams.
...@@ -393,9 +398,34 @@ type roundRobin struct { ...@@ -393,9 +398,34 @@ type roundRobin struct {
addresses []string addresses []string
} }
func (r *roundRobin) Address() string { func (r *roundRobin) Address() (string, error) {
index := atomic.AddInt64(&r.index, 1) % int64(len(r.addresses)) index := atomic.AddInt64(&r.index, 1) % int64(len(r.addresses))
return r.addresses[index] return r.addresses[index], nil
}
// srvResolver is a private interface used to abstract
// the DNS resolver. It is mainly used to facilitate testing.
type srvResolver interface {
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
}
// srv is a service locator for fastcgi upstreams
type srv struct {
resolver srvResolver
service string
}
// Address looks up the service and returns the address:port
// from first result in resolved list.
// No explicit balancing is required because net.LookupSRV
// sorts the results by priority and randomizes within priority.
func (s *srv) Address() (string, error) {
_, addrs, err := s.resolver.LookupSRV(context.Background(), "", "", s.service)
if err != nil {
return "", err
}
return fmt.Sprintf("%s:%d", strings.TrimRight(addrs[0].Target, "."), addrs[0].Port), nil
} }
// canSplit checks if path can split into two based on rule.SplitPath. // canSplit checks if path can split into two based on rule.SplitPath.
......
...@@ -84,11 +84,15 @@ func TestRuleParseAddress(t *testing.T) { ...@@ -84,11 +84,15 @@ func TestRuleParseAddress(t *testing.T) {
} }
for _, entry := range getClientTestTable { for _, entry := range getClientTestTable {
if actualnetwork, _ := parseAddress(entry.rule.Address()); actualnetwork != entry.expectednetwork { addr, err := entry.rule.Address()
t.Errorf("Unexpected network for address string %v. Got %v, expected %v", entry.rule.Address(), actualnetwork, entry.expectednetwork) if err != nil {
t.Errorf("Unexpected error in retrieving address: %s", err.Error())
}
if actualnetwork, _ := parseAddress(addr); actualnetwork != entry.expectednetwork {
t.Errorf("Unexpected network for address string %v. Got %v, expected %v", addr, actualnetwork, entry.expectednetwork)
} }
if _, actualaddress := parseAddress(entry.rule.Address()); actualaddress != entry.expectedaddress { if _, actualaddress := parseAddress(addr); actualaddress != entry.expectedaddress {
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address(), actualaddress, entry.expectedaddress) t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", addr, actualaddress, entry.expectedaddress)
} }
} }
} }
...@@ -365,7 +369,10 @@ func TestBalancer(t *testing.T) { ...@@ -365,7 +369,10 @@ func TestBalancer(t *testing.T) {
for i, test := range tests { for i, test := range tests {
b := address(test...) b := address(test...)
for _, host := range test { for _, host := range test {
a := b.Address() a, err := b.Address()
if err != nil {
t.Errorf("Unexpected error in trying to retrieve address: %s", err.Error())
}
if a != host { if a != host {
t.Errorf("Test %d: expected %s, found %s", i, host, a) t.Errorf("Test %d: expected %s, found %s", i, host, a)
} }
......
...@@ -16,8 +16,11 @@ package fastcgi ...@@ -16,8 +16,11 @@ package fastcgi
import ( import (
"errors" "errors"
"fmt"
"net"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -76,8 +79,14 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -76,8 +79,14 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
Root: absRoot, Root: absRoot,
Path: args[0], Path: args[0],
} }
upstreams := []string{args[1]} upstreams := []string{args[1]}
srvUpstream := false
if strings.HasPrefix(upstreams[0], "srv://") {
srvUpstream = true
}
if len(args) == 3 { if len(args) == 3 {
if err := fastcgiPreset(args[2], &rule); err != nil { if err := fastcgiPreset(args[2], &rule); err != nil {
return rules, err return rules, err
...@@ -112,6 +121,10 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -112,6 +121,10 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
rule.IndexFiles = args rule.IndexFiles = args
case "upstream": case "upstream":
if srvUpstream {
return rules, c.Err("additional upstreams are not supported with SRV upstream")
}
args := c.RemainingArgs() args := c.RemainingArgs()
if len(args) != 1 { if len(args) != 1 {
...@@ -161,13 +174,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) { ...@@ -161,13 +174,32 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
} }
} }
rule.balancer = &roundRobin{addresses: upstreams, index: -1} if srvUpstream {
balancer, err := parseSRV(upstreams[0])
if err != nil {
return rules, c.Err("malformed service locator string: " + err.Error())
}
rule.balancer = balancer
} else {
rule.balancer = &roundRobin{addresses: upstreams, index: -1}
}
rules = append(rules, rule) rules = append(rules, rule)
} }
return rules, nil return rules, nil
} }
func parseSRV(locator string) (*srv, error) {
if locator[6:] == "" {
return nil, fmt.Errorf("%s does not include the host", locator)
}
return &srv{
service: locator[6:],
resolver: &net.Resolver{},
}, nil
}
// fastcgiPreset configures rule according to name. It returns an error if // fastcgiPreset configures rule according to name. It returns an error if
// name is not a recognized preset name. // name is not a recognized preset name.
func fastcgiPreset(name string, rule *Rule) error { func fastcgiPreset(name string, rule *Rule) error {
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
package fastcgi package fastcgi
import ( import (
"context"
"fmt" "fmt"
"net"
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -43,10 +45,14 @@ func TestSetup(t *testing.T) { ...@@ -43,10 +45,14 @@ func TestSetup(t *testing.T) {
if myHandler.Rules[0].Path != "/" { if myHandler.Rules[0].Path != "/" {
t.Errorf("Expected / as the Path") t.Errorf("Expected / as the Path")
} }
if myHandler.Rules[0].Address() != "127.0.0.1:9000" { addr, err := myHandler.Rules[0].Address()
t.Errorf("Expected 127.0.0.1:9000 as the Address") if err != nil {
t.Errorf("Unexpected error in trying to retrieve address: %s", err.Error())
} }
if addr != "127.0.0.1:9000" {
t.Errorf("Expected 127.0.0.1:9000 as the Address")
}
} }
func TestFastcgiParse(t *testing.T) { func TestFastcgiParse(t *testing.T) {
...@@ -106,9 +112,19 @@ func TestFastcgiParse(t *testing.T) { ...@@ -106,9 +112,19 @@ func TestFastcgiParse(t *testing.T) {
i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path) i, j, test.expectedFastcgiConfig[j].Path, actualFastcgiConfig.Path)
} }
if actualFastcgiConfig.Address() != test.expectedFastcgiConfig[j].Address() { actualAddr, err := actualFastcgiConfig.Address()
if err != nil {
t.Errorf("Test %d unexpected error in trying to retrieve %dth actual address: %s", i, j, err.Error())
}
expectedAddr, err := test.expectedFastcgiConfig[j].Address()
if err != nil {
t.Errorf("Test %d unexpected error in trying to retrieve %dth expected address: %s", i, j, err.Error())
}
if actualAddr != expectedAddr {
t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s", t.Errorf("Test %d expected %dth FastCGI Address to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].Address(), actualFastcgiConfig.Address()) i, j, expectedAddr, actualAddr)
} }
if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext { if actualFastcgiConfig.Ext != test.expectedFastcgiConfig[j].Ext {
...@@ -134,3 +150,75 @@ func TestFastcgiParse(t *testing.T) { ...@@ -134,3 +150,75 @@ func TestFastcgiParse(t *testing.T) {
} }
} }
func TestFastCGIResolveSRV(t *testing.T) {
tests := []struct {
inputFastcgiConfig string
locator string
target string
port uint16
shouldErr bool
}{
{
`fastcgi / srv://fpm.tcp.service.consul {
upstream yolo
}`,
"fpm.tcp.service.consul",
"127.0.0.1",
9000,
true,
},
{
`fastcgi / srv://fpm.tcp.service.consul`,
"fpm.tcp.service.consul",
"127.0.0.1",
9000,
false,
},
}
for i, test := range tests {
actualFastcgiConfigs, err := fastcgiParse(caddy.NewTestController("http", test.inputFastcgiConfig))
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
}
for _, actualFastcgiConfig := range actualFastcgiConfigs {
resolver, ok := (actualFastcgiConfig.balancer).(*srv)
if !ok {
t.Errorf("Test %d upstream balancer is not srv", i)
}
resolver.resolver = buildTestResolver(test.target, test.port)
addr, err := actualFastcgiConfig.Address()
if err != nil {
t.Errorf("Test %d failed to retrieve upstream address. %s", i, err.Error())
}
expectedAddr := fmt.Sprintf("%s:%d", test.target, test.port)
if addr != expectedAddr {
t.Errorf("Test %d expected upstream address to be %s, got %s", i, expectedAddr, addr)
}
}
}
}
func buildTestResolver(target string, port uint16) srvResolver {
return &testSRVResolver{target, port}
}
type testSRVResolver struct {
target string
port uint16
}
func (r *testSRVResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
return "", []*net.SRV{
{Target: r.target,
Port: r.port,
Priority: 1,
Weight: 1}}, nil
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment