Commit b36a17b8 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'allow-to-access-remote-archive' into 'master'

Allow to access remote archive

See merge request !180
parents dea4edff 12035895
...@@ -5,8 +5,15 @@ import ( ...@@ -5,8 +5,15 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net"
"net/http"
"os" "os"
"strings"
"time"
"github.com/jfbus/httprs"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts" "gitlab.com/gitlab-org/gitlab-workhorse/internal/zipartifacts"
) )
...@@ -16,6 +23,67 @@ var Version = "unknown" ...@@ -16,6 +23,67 @@ var Version = "unknown"
var printVersion = flag.Bool("version", false, "Print version and exit") var printVersion = flag.Bool("version", false, "Print version and exit")
var httpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 10 * time.Second,
}).DialContext,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
},
}
func isURL(path string) bool {
return strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://")
}
func openHTTPArchive(archivePath string) (*zip.Reader, func()) {
scrubbedArchivePath := helper.ScrubURLParams(archivePath)
resp, err := httpClient.Get(archivePath)
if err != nil {
fatalError(fmt.Errorf("HTTP GET %q: %v", scrubbedArchivePath, err))
} else if resp.StatusCode == http.StatusNotFound {
notFoundError(fmt.Errorf("HTTP GET %q: not found", scrubbedArchivePath))
} else if resp.StatusCode != http.StatusOK {
fatalError(fmt.Errorf("HTTP GET %q: %d: %v", scrubbedArchivePath, resp.StatusCode, resp.Status))
}
rs := httprs.NewHttpReadSeeker(resp, httpClient)
archive, err := zip.NewReader(rs, resp.ContentLength)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", scrubbedArchivePath, err))
}
return archive, func() {
resp.Body.Close()
rs.Close()
}
}
func openFileArchive(archivePath string) (*zip.Reader, func()) {
archive, err := zip.OpenReader(archivePath)
if err != nil {
notFoundError(fmt.Errorf("open %q: %v", archivePath, err))
}
return &archive.Reader, func() {
archive.Close()
}
}
func openArchive(archivePath string) (*zip.Reader, func()) {
if isURL(archivePath) {
return openHTTPArchive(archivePath)
}
return openFileArchive(archivePath)
}
func main() { func main() {
flag.Parse() flag.Parse()
...@@ -25,32 +93,34 @@ func main() { ...@@ -25,32 +93,34 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if len(os.Args) != 3 { archivePath := os.Getenv("ARCHIVE_PATH")
fmt.Fprintf(os.Stderr, "Usage: %s FILE.ZIP ENTRY\n", progName) encodedFileName := os.Getenv("ENCODED_FILE_NAME")
if len(os.Args) != 1 || archivePath == "" || encodedFileName == "" {
fmt.Fprintf(os.Stderr, "Usage: %s\n", progName)
fmt.Fprintf(os.Stderr, "Env: ARCHIVE_PATH=https://path.to/archive.zip or /path/to/archive.zip\n")
fmt.Fprintf(os.Stderr, "Env: ENCODED_FILE_NAME=base64-encoded-file-name\n")
os.Exit(1) os.Exit(1)
} }
archiveFileName := os.Args[1] scrubbedArchivePath := helper.ScrubURLParams(archivePath)
fileName, err := zipartifacts.DecodeFileEntry(os.Args[2]) fileName, err := zipartifacts.DecodeFileEntry(encodedFileName)
if err != nil { if err != nil {
fatalError(fmt.Errorf("decode entry %q: %v", os.Args[2], err)) fatalError(fmt.Errorf("decode entry %q: %v", encodedFileName, err))
} }
archive, err := zip.OpenReader(archiveFileName) archive, cleanFn := openArchive(archivePath)
if err != nil { defer cleanFn()
notFoundError(fmt.Errorf("open %q: %v", archiveFileName, err))
}
defer archive.Close()
file := findFileInZip(fileName, &archive.Reader) file := findFileInZip(fileName, archive)
if file == nil { if file == nil {
notFoundError(fmt.Errorf("find %q in %q: not found", fileName, archiveFileName)) notFoundError(fmt.Errorf("find %q in %q: not found", fileName, scrubbedArchivePath))
} }
// Start decompressing the file // Start decompressing the file
reader, err := file.Open() reader, err := file.Open()
if err != nil { if err != nil {
fatalError(fmt.Errorf("open %q in %q: %v", fileName, archiveFileName, err)) fatalError(fmt.Errorf("open %q in %q: %v", fileName, scrubbedArchivePath, err))
} }
defer reader.Close() defer reader.Close()
...@@ -59,7 +129,7 @@ func main() { ...@@ -59,7 +129,7 @@ func main() {
} }
if _, err := io.Copy(os.Stdout, reader); err != nil { if _, err := io.Copy(os.Stdout, reader); err != nil {
fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, archiveFileName, err)) fatalError(fmt.Errorf("write %q from %q to stdout: %v", fileName, scrubbedArchivePath, err))
} }
} }
......
...@@ -55,13 +55,17 @@ func detectFileContentType(fileName string) string { ...@@ -55,13 +55,17 @@ func detectFileContentType(fileName string) string {
return contentType return contentType
} }
func unpackFileFromZip(archiveFileName, encodedFilename string, headers http.Header, output io.Writer) error { func unpackFileFromZip(archivePath, encodedFilename string, headers http.Header, output io.Writer) error {
fileName, err := zipartifacts.DecodeFileEntry(encodedFilename) fileName, err := zipartifacts.DecodeFileEntry(encodedFilename)
if err != nil { if err != nil {
return err return err
} }
catFile := exec.Command("gitlab-zip-cat", archiveFileName, encodedFilename) catFile := exec.Command("gitlab-zip-cat")
catFile.Env = []string{
"ARCHIVE_PATH=" + archivePath,
"ENCODED_FILE_NAME=" + encodedFilename,
}
catFile.Stderr = os.Stderr catFile.Stderr = os.Stderr
catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} catFile.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
stdout, err := catFile.StdoutPipe() stdout, err := catFile.StdoutPipe()
......
...@@ -8,17 +8,18 @@ import ( ...@@ -8,17 +8,18 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
) )
func testEntryServer(t *testing.T, archive string, entry string) *httptest.ResponseRecorder { func testEntryServer(t *testing.T, archive string, entry string) *httptest.ResponseRecorder {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/url/path", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" { require.Equal(t, "GET", r.Method)
t.Fatal("Expected GET request")
}
encodedEntry := base64.StdEncoding.EncodeToString([]byte(entry)) encodedEntry := base64.StdEncoding.EncodeToString([]byte(entry))
jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archive, encodedEntry) jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archive, encodedEntry)
...@@ -28,9 +29,7 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo ...@@ -28,9 +29,7 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo
}) })
httpRequest, err := http.NewRequest("GET", "/url/path", nil) httpRequest, err := http.NewRequest("GET", "/url/path", nil)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
response := httptest.NewRecorder() response := httptest.NewRecorder()
mux.ServeHTTP(response, httpRequest) mux.ServeHTTP(response, httpRequest)
return response return response
...@@ -38,18 +37,14 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo ...@@ -38,18 +37,14 @@ func testEntryServer(t *testing.T, archive string, entry string) *httptest.Respo
func TestDownloadingFromValidArchive(t *testing.T) { func TestDownloadingFromValidArchive(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads") tempFile, err := ioutil.TempFile("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer tempFile.Close() defer tempFile.Close()
defer os.Remove(tempFile.Name()) defer os.Remove(tempFile.Name())
archive := zip.NewWriter(tempFile) archive := zip.NewWriter(tempFile)
defer archive.Close() defer archive.Close()
fileInArchive, err := archive.Create("test.txt") fileInArchive, err := archive.Create("test.txt")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
fmt.Fprint(fileInArchive, "testtest") fmt.Fprint(fileInArchive, "testtest")
archive.Close() archive.Close()
...@@ -67,11 +62,43 @@ func TestDownloadingFromValidArchive(t *testing.T) { ...@@ -67,11 +62,43 @@ func TestDownloadingFromValidArchive(t *testing.T) {
testhelper.AssertResponseBody(t, response, "testtest") testhelper.AssertResponseBody(t, response, "testtest")
} }
func TestDownloadingFromValidHTTPArchive(t *testing.T) {
tempDir, err := ioutil.TempDir("", "uploads")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
f, err := os.Create(filepath.Join(tempDir, "archive.zip"))
require.NoError(t, err)
defer f.Close()
archive := zip.NewWriter(f)
defer archive.Close()
fileInArchive, err := archive.Create("test.txt")
require.NoError(t, err)
fmt.Fprint(fileInArchive, "testtest")
archive.Close()
f.Close()
fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
defer fileServer.Close()
response := testEntryServer(t, fileServer.URL+"/archive.zip", "test.txt")
testhelper.AssertResponseCode(t, response, 200)
testhelper.AssertResponseWriterHeader(t, response,
"Content-Type",
"text/plain; charset=utf-8")
testhelper.AssertResponseWriterHeader(t, response,
"Content-Disposition",
"attachment; filename=\"test.txt\"")
testhelper.AssertResponseBody(t, response, "testtest")
}
func TestDownloadingNonExistingFile(t *testing.T) { func TestDownloadingNonExistingFile(t *testing.T) {
tempFile, err := ioutil.TempFile("", "uploads") tempFile, err := ioutil.TempFile("", "uploads")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer tempFile.Close() defer tempFile.Close()
defer os.Remove(tempFile.Name()) defer os.Remove(tempFile.Name())
...@@ -92,3 +119,16 @@ func TestIncompleteApiResponse(t *testing.T) { ...@@ -92,3 +119,16 @@ func TestIncompleteApiResponse(t *testing.T) {
response := testEntryServer(t, "", "") response := testEntryServer(t, "", "")
testhelper.AssertResponseCode(t, response, 500) testhelper.AssertResponseCode(t, response, 500)
} }
func TestDownloadingFromNonExistingHTTPArchive(t *testing.T) {
tempDir, err := ioutil.TempDir("", "uploads")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
fileServer := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
defer fileServer.Close()
response := testEntryServer(t, fileServer.URL+"/not-existing-archive-file.zip", "test.txt")
testhelper.AssertResponseCode(t, response, 404)
}
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
const NginxResponseBufferHeader = "X-Accel-Buffering" const NginxResponseBufferHeader = "X-Accel-Buffering"
var scrubRegexp = regexp.MustCompile(`([\?&](?:private|authenticity|rss)[\-_]token)=[^&]*`) var scrubRegexp = regexp.MustCompile(`(?i)([\?&]((?:private|authenticity|rss)[\-_]token)|X-AMZ-Signature)=[^&]*`)
func Fail500(w http.ResponseWriter, r *http.Request, err error) { func Fail500(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Internal server error", 500) http.Error(w, "Internal server error", 500)
......
...@@ -127,6 +127,9 @@ func TestScrubURLParams(t *testing.T) { ...@@ -127,6 +127,9 @@ func TestScrubURLParams(t *testing.T) {
"?private-token=&authenticity_token=&bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]&bar", "?private-token=&authenticity_token=&bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]&bar",
"?private-token=foo&authenticity_token=bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]", "?private-token=foo&authenticity_token=bar": "?private-token=[FILTERED]&authenticity_token=[FILTERED]",
"?private_token=foo&authenticity-token=bar": "?private_token=[FILTERED]&authenticity-token=[FILTERED]", "?private_token=foo&authenticity-token=bar": "?private_token=[FILTERED]&authenticity-token=[FILTERED]",
"?X-AMZ-Signature=foo": "?X-AMZ-Signature=[FILTERED]",
"&X-AMZ-Signature=foo": "&X-AMZ-Signature=[FILTERED]",
"?x-amz-signature=foo": "?x-amz-signature=[FILTERED]",
} { } {
after := ScrubURLParams(before) after := ScrubURLParams(before)
assert.Equal(t, expected, after, "Scrubbing %q", before) assert.Equal(t, expected, after, "Scrubbing %q", before)
......
Copyright (c) 2015 Jean-François Bustarret
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
\ No newline at end of file
# httprs
A ReadSeeker for http.Response.Body
[![wercker status](https://app.wercker.com/status/b8ab18faefae7d1f88f9f23d642f0847/s/master "wercker status")](https://app.wercker.com/project/bykey/b8ab18faefae7d1f88f9f23d642f0847)
## Usage
```
import "github.com/jfbus/httprs"
resp, err := http.Get(url)
rs := httprs.NewHttpReadSeeker(resp)
defer rs.Close()
io.ReadFull(rs, buf) // reads the first bytes from the response
rs.Seek(1024, 0) // moves the position
io.ReadFull(rs, buf) // does an additional range request and reads the first bytes from the second response
```
if you use a specific http.Client :
```
rs := httprs.NewHttpReadSeeker(resp, client)
```
## Doc
See http://godoc.org/github.com/jfbus/httprs
## LICENSE
MIT - See LICENSE
\ No newline at end of file
/*
Package httprs provides a ReadSeeker for http.Response.Body.
Usage :
resp, err := http.Get(url)
rs := httprs.NewHttpReadSeeker(resp)
defer rs.Close()
io.ReadFull(rs, buf) // reads the first bytes from the response body
rs.Seek(1024, 0) // moves the position, but does no range request
io.ReadFull(rs, buf) // does a range request and reads from the response body
If you want use a specific http.Client for additional range requests :
rs := httprs.NewHttpReadSeeker(resp, client)
*/
package httprs
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/mitchellh/copystructure"
)
const shortSeekBytes = 1024
// A HttpReadSeeker reads from a http.Response.Body. It can Seek
// by doing range requests.
type HttpReadSeeker struct {
c *http.Client
req *http.Request
res *http.Response
r io.ReadCloser
pos int64
canSeek bool
Requests int
}
var _ io.ReadCloser = (*HttpReadSeeker)(nil)
var _ io.Seeker = (*HttpReadSeeker)(nil)
var (
// ErrNoContentLength is returned by Seek when the initial http response did not include a Content-Length header
ErrNoContentLength = errors.New("Content-Length was not set")
// ErrRangeRequestsNotSupported is returned by Seek and Read
// when the remote server does not allow range requests (Accept-Ranges was not set)
ErrRangeRequestsNotSupported = errors.New("Range requests are not supported by the remote server")
// ErrInvalidRange is returned by Read when trying to read past the end of the file
ErrInvalidRange = errors.New("Invalid range")
// ErrContentHasChanged is returned by Read when the content has changed since the first request
ErrContentHasChanged = errors.New("Content has changed since first request")
)
// NewHttpReadSeeker returns a HttpReadSeeker, using the http.Response and, optionaly, the http.Client
// that needs to be used for future range requests. If no http.Client is given, http.DefaultClient will
// be used.
//
// res.Request will be reused for range requests, headers may be added/removed
func NewHttpReadSeeker(res *http.Response, client ...*http.Client) *HttpReadSeeker {
r := &HttpReadSeeker{
req: res.Request,
res: res,
r: res.Body,
canSeek: (res.Header.Get("Accept-Ranges") == "bytes"),
}
if len(client) > 0 {
r.c = client[0]
} else {
r.c = http.DefaultClient
}
return r
}
// Clone clones the reader to enable parallel downloads of ranges
func (r *HttpReadSeeker) Clone() (*HttpReadSeeker, error) {
req, err := copystructure.Copy(r.req)
if err != nil {
return nil, err
}
return &HttpReadSeeker{
req: req.(*http.Request),
res: r.res,
r: nil,
canSeek: r.canSeek,
c: r.c,
}, nil
}
// Read reads from the response body. It does a range request if Seek was called before.
//
// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
func (r *HttpReadSeeker) Read(p []byte) (n int, err error) {
if r.r == nil {
err = r.rangeRequest()
}
if r.r != nil {
n, err = r.r.Read(p)
r.pos += int64(n)
}
return
}
// ReadAt reads from the response body starting at offset off.
//
// May return ErrRangeRequestsNotSupported, ErrInvalidRange or ErrContentHasChanged
func (r *HttpReadSeeker) ReadAt(p []byte, off int64) (n int, err error) {
r.Seek(off, 0)
return r.Read(p)
}
// Close closes the response body
func (r *HttpReadSeeker) Close() error {
if r.r != nil {
return r.r.Close()
}
return nil
}
// Seek moves the reader position to a new offset.
//
// It does not send http requests, allowing for multiple seeks without overhead.
// The http request will be sent by the next Read call.
//
// May return ErrNoContentLength or ErrRangeRequestsNotSupported
func (r *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
if !r.canSeek {
return 0, ErrRangeRequestsNotSupported
}
var err error
switch whence {
case 0:
case 1:
offset += r.pos
case 2:
if r.res.ContentLength <= 0 {
return 0, ErrNoContentLength
}
offset = r.res.ContentLength - offset
}
if r.r != nil {
// Try to read, which is cheaper than doing a request
if r.pos < offset && offset-r.pos <= shortSeekBytes {
_, err := io.CopyN(ioutil.Discard, r, offset-r.pos)
if err != nil {
return 0, err
}
}
if r.pos != offset {
err = r.r.Close()
r.r = nil
}
}
r.pos = offset
return r.pos, err
}
func (r *HttpReadSeeker) rangeRequest() error {
r.req.Header.Set("Range", fmt.Sprintf("bytes=%d-", r.pos))
etag, last := r.res.Header.Get("ETag"), r.res.Header.Get("Last-Modified")
switch {
case last != "":
r.req.Header.Set("If-Range", last)
case etag != "":
r.req.Header.Set("If-Range", etag)
}
r.Requests++
res, err := r.c.Do(r.req)
if err != nil {
return err
}
switch res.StatusCode {
case http.StatusRequestedRangeNotSatisfiable:
return ErrInvalidRange
case http.StatusOK:
return ErrContentHasChanged
case http.StatusPartialContent:
r.r = res.Body
return nil
}
return ErrRangeRequestsNotSupported
}
box: wercker/golang
build:
steps:
- setup-go-workspace
- script:
name: Install goconvey
code: go get github.com/smartystreets/goconvey/convey
- script:
name: Go get
code: go get -v ./...
- script:
name: Go test
code: go test -p 1 -v ./...
The MIT License (MIT)
Copyright (c) 2014 Mitchell Hashimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# copystructure
copystructure is a Go library for deep copying values in Go.
This allows you to copy Go values that may contain reference values
such as maps, slices, or pointers, and copy their data as well instead
of just their references.
## Installation
Standard `go get`:
```
$ go get github.com/mitchellh/copystructure
```
## Usage & Example
For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/copystructure).
The `Copy` function has examples associated with it there.
package copystructure
import (
"reflect"
"time"
)
func init() {
Copiers[reflect.TypeOf(time.Time{})] = timeCopier
}
func timeCopier(v interface{}) (interface{}, error) {
// Just... copy it.
return v.(time.Time), nil
}
This diff is collapsed.
The MIT License (MIT)
Copyright (c) 2013 Mitchell Hashimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# reflectwalk
reflectwalk is a Go library for "walking" a value in Go using reflection,
in the same way a directory tree can be "walked" on the filesystem. Walking
a complex structure can allow you to do manipulations on unknown structures
such as those decoded from JSON.
package reflectwalk
//go:generate stringer -type=Location location.go
type Location uint
const (
None Location = iota
Map
MapKey
MapValue
Slice
SliceElem
Array
ArrayElem
Struct
StructField
WalkLoc
)
// generated by stringer -type=Location location.go; DO NOT EDIT
package reflectwalk
import "fmt"
const _Location_name = "NoneMapMapKeyMapValueSliceSliceElemStructStructFieldWalkLoc"
var _Location_index = [...]uint8{0, 4, 7, 13, 21, 26, 35, 41, 52, 59}
func (i Location) String() string {
if i+1 >= Location(len(_Location_index)) {
return fmt.Sprintf("Location(%d)", i)
}
return _Location_name[_Location_index[i]:_Location_index[i+1]]
}
// reflectwalk is a package that allows you to "walk" complex structures
// similar to how you may "walk" a filesystem: visiting every element one
// by one and calling callback functions allowing you to handle and manipulate
// those elements.
package reflectwalk
import (
"errors"
"reflect"
)
// PrimitiveWalker implementations are able to handle primitive values
// within complex structures. Primitive values are numbers, strings,
// booleans, funcs, chans.
//
// These primitive values are often members of more complex
// structures (slices, maps, etc.) that are walkable by other interfaces.
type PrimitiveWalker interface {
Primitive(reflect.Value) error
}
// InterfaceWalker implementations are able to handle interface values as they
// are encountered during the walk.
type InterfaceWalker interface {
Interface(reflect.Value) error
}
// MapWalker implementations are able to handle individual elements
// found within a map structure.
type MapWalker interface {
Map(m reflect.Value) error
MapElem(m, k, v reflect.Value) error
}
// SliceWalker implementations are able to handle slice elements found
// within complex structures.
type SliceWalker interface {
Slice(reflect.Value) error
SliceElem(int, reflect.Value) error
}
// ArrayWalker implementations are able to handle array elements found
// within complex structures.
type ArrayWalker interface {
Array(reflect.Value) error
ArrayElem(int, reflect.Value) error
}
// StructWalker is an interface that has methods that are called for
// structs when a Walk is done.
type StructWalker interface {
Struct(reflect.Value) error
StructField(reflect.StructField, reflect.Value) error
}
// EnterExitWalker implementations are notified before and after
// they walk deeper into complex structures (into struct fields,
// into slice elements, etc.)
type EnterExitWalker interface {
Enter(Location) error
Exit(Location) error
}
// PointerWalker implementations are notified when the value they're
// walking is a pointer or not. Pointer is called for _every_ value whether
// it is a pointer or not.
type PointerWalker interface {
PointerEnter(bool) error
PointerExit(bool) error
}
// SkipEntry can be returned from walk functions to skip walking
// the value of this field. This is only valid in the following functions:
//
// - Struct: skips all fields from being walked
// - StructField: skips walking the struct value
//
var SkipEntry = errors.New("skip this entry")
// Walk takes an arbitrary value and an interface and traverses the
// value, calling callbacks on the interface if they are supported.
// The interface should implement one or more of the walker interfaces
// in this package, such as PrimitiveWalker, StructWalker, etc.
func Walk(data, walker interface{}) (err error) {
v := reflect.ValueOf(data)
ew, ok := walker.(EnterExitWalker)
if ok {
err = ew.Enter(WalkLoc)
}
if err == nil {
err = walk(v, walker)
}
if ok && err == nil {
err = ew.Exit(WalkLoc)
}
return
}
func walk(v reflect.Value, w interface{}) (err error) {
// Determine if we're receiving a pointer and if so notify the walker.
// The logic here is convoluted but very important (tests will fail if
// almost any part is changed). I will try to explain here.
//
// First, we check if the value is an interface, if so, we really need
// to check the interface's VALUE to see whether it is a pointer.
//
// Check whether the value is then a pointer. If so, then set pointer
// to true to notify the user.
//
// If we still have a pointer or an interface after the indirections, then
// we unwrap another level
//
// At this time, we also set "v" to be the dereferenced value. This is
// because once we've unwrapped the pointer we want to use that value.
pointer := false
pointerV := v
for {
if pointerV.Kind() == reflect.Interface {
if iw, ok := w.(InterfaceWalker); ok {
if err = iw.Interface(pointerV); err != nil {
return
}
}
pointerV = pointerV.Elem()
}
if pointerV.Kind() == reflect.Ptr {
pointer = true
v = reflect.Indirect(pointerV)
}
if pw, ok := w.(PointerWalker); ok {
if err = pw.PointerEnter(pointer); err != nil {
return
}
defer func(pointer bool) {
if err != nil {
return
}
err = pw.PointerExit(pointer)
}(pointer)
}
if pointer {
pointerV = v
}
pointer = false
// If we still have a pointer or interface we have to indirect another level.
switch pointerV.Kind() {
case reflect.Ptr, reflect.Interface:
continue
}
break
}
// We preserve the original value here because if it is an interface
// type, we want to pass that directly into the walkPrimitive, so that
// we can set it.
originalV := v
if v.Kind() == reflect.Interface {
v = v.Elem()
}
k := v.Kind()
if k >= reflect.Int && k <= reflect.Complex128 {
k = reflect.Int
}
switch k {
// Primitives
case reflect.Bool, reflect.Chan, reflect.Func, reflect.Int, reflect.String, reflect.Invalid:
err = walkPrimitive(originalV, w)
return
case reflect.Map:
err = walkMap(v, w)
return
case reflect.Slice:
err = walkSlice(v, w)
return
case reflect.Struct:
err = walkStruct(v, w)
return
case reflect.Array:
err = walkArray(v, w)
return
default:
panic("unsupported type: " + k.String())
}
}
func walkMap(v reflect.Value, w interface{}) error {
ew, ewok := w.(EnterExitWalker)
if ewok {
ew.Enter(Map)
}
if mw, ok := w.(MapWalker); ok {
if err := mw.Map(v); err != nil {
return err
}
}
for _, k := range v.MapKeys() {
kv := v.MapIndex(k)
if mw, ok := w.(MapWalker); ok {
if err := mw.MapElem(v, k, kv); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(MapKey)
}
if err := walk(k, w); err != nil {
return err
}
if ok {
ew.Exit(MapKey)
ew.Enter(MapValue)
}
if err := walk(kv, w); err != nil {
return err
}
if ok {
ew.Exit(MapValue)
}
}
if ewok {
ew.Exit(Map)
}
return nil
}
func walkPrimitive(v reflect.Value, w interface{}) error {
if pw, ok := w.(PrimitiveWalker); ok {
return pw.Primitive(v)
}
return nil
}
func walkSlice(v reflect.Value, w interface{}) (err error) {
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(Slice)
}
if sw, ok := w.(SliceWalker); ok {
if err := sw.Slice(v); err != nil {
return err
}
}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if sw, ok := w.(SliceWalker); ok {
if err := sw.SliceElem(i, elem); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(SliceElem)
}
if err := walk(elem, w); err != nil {
return err
}
if ok {
ew.Exit(SliceElem)
}
}
ew, ok = w.(EnterExitWalker)
if ok {
ew.Exit(Slice)
}
return nil
}
func walkArray(v reflect.Value, w interface{}) (err error) {
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(Array)
}
if aw, ok := w.(ArrayWalker); ok {
if err := aw.Array(v); err != nil {
return err
}
}
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if aw, ok := w.(ArrayWalker); ok {
if err := aw.ArrayElem(i, elem); err != nil {
return err
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(ArrayElem)
}
if err := walk(elem, w); err != nil {
return err
}
if ok {
ew.Exit(ArrayElem)
}
}
ew, ok = w.(EnterExitWalker)
if ok {
ew.Exit(Array)
}
return nil
}
func walkStruct(v reflect.Value, w interface{}) (err error) {
ew, ewok := w.(EnterExitWalker)
if ewok {
ew.Enter(Struct)
}
skip := false
if sw, ok := w.(StructWalker); ok {
err = sw.Struct(v)
if err == SkipEntry {
skip = true
err = nil
}
if err != nil {
return
}
}
if !skip {
vt := v.Type()
for i := 0; i < vt.NumField(); i++ {
sf := vt.Field(i)
f := v.FieldByIndex([]int{i})
if sw, ok := w.(StructWalker); ok {
err = sw.StructField(sf, f)
// SkipEntry just pretends this field doesn't even exist
if err == SkipEntry {
continue
}
if err != nil {
return
}
}
ew, ok := w.(EnterExitWalker)
if ok {
ew.Enter(StructField)
}
err = walk(f, w)
if err != nil {
return
}
if ok {
ew.Exit(StructField)
}
}
}
if ewok {
ew.Exit(Struct)
}
return nil
}
...@@ -53,6 +53,12 @@ ...@@ -53,6 +53,12 @@
"path": "github.com/gorilla/websocket", "path": "github.com/gorilla/websocket",
"revision": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a" "revision": "e8f0f8aaa98dfb6586cbdf2978d511e3199a960a"
}, },
{
"checksumSHA1": "g46WnPAlsmkyNDUFdObDgP6DP+s=",
"path": "github.com/jfbus/httprs",
"revision": "e879ae6bf984ca6c009eb44696e9f400ec3ed2d9",
"revisionTime": "2017-07-03T13:54:35Z"
},
{ {
"checksumSHA1": "oIkoHb8+rM5Etur5HhZVY/sDQKQ=", "checksumSHA1": "oIkoHb8+rM5Etur5HhZVY/sDQKQ=",
"path": "github.com/jpillora/backoff", "path": "github.com/jpillora/backoff",
...@@ -65,6 +71,18 @@ ...@@ -65,6 +71,18 @@
"path": "github.com/matttproud/golang_protobuf_extensions/pbutil", "path": "github.com/matttproud/golang_protobuf_extensions/pbutil",
"revision": "c12348ce28de40eed0136aa2b644d0ee0650e56c" "revision": "c12348ce28de40eed0136aa2b644d0ee0650e56c"
}, },
{
"checksumSHA1": "kSmDazz+cokgcHQT7q56Na+IBe0=",
"path": "github.com/mitchellh/copystructure",
"revision": "f81071c9d77b7931f78c90b416a074ecdc50e959",
"revisionTime": "2017-01-16T00:44:49Z"
},
{
"checksumSHA1": "KqsMqI+Y+3EFYPhyzafpIneaVCM=",
"path": "github.com/mitchellh/reflectwalk",
"revision": "8d802ff4ae93611b807597f639c19f76074df5c6",
"revisionTime": "2017-05-08T17:38:06Z"
},
{ {
"checksumSHA1": "LuFv4/jlrmFNnDb/5SCSEPAM9vU=", "checksumSHA1": "LuFv4/jlrmFNnDb/5SCSEPAM9vU=",
"path": "github.com/pmezard/go-difflib/difflib", "path": "github.com/pmezard/go-difflib/difflib",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment