Commit b880543f authored by Robert Griesemer's avatar Robert Griesemer

go/ast: generalized ast filtering

R=rsc
CC=golang-dev
https://golang.org/cl/788041
parent 299cd38f
...@@ -358,6 +358,8 @@ func (x *ChanType) exprNode() {} ...@@ -358,6 +358,8 @@ func (x *ChanType) exprNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Convenience functions for Idents // Convenience functions for Idents
var noPos token.Position
// NewIdent creates a new Ident without position and minimal object // NewIdent creates a new Ident without position and minimal object
// information. Useful for ASTs generated by code other than the Go // information. Useful for ASTs generated by code other than the Go
// parser. // parser.
......
...@@ -6,8 +6,10 @@ package ast ...@@ -6,8 +6,10 @@ package ast
import "go/token" import "go/token"
// ----------------------------------------------------------------------------
// Export filtering
func filterIdentList(list []*Ident) []*Ident { func identListExports(list []*Ident) []*Ident {
j := 0 j := 0
for _, x := range list { for _, x := range list {
if x.IsExported() { if x.IsExported() {
...@@ -36,7 +38,7 @@ func isExportedType(typ Expr) bool { ...@@ -36,7 +38,7 @@ func isExportedType(typ Expr) bool {
} }
func filterFieldList(fields *FieldList, incomplete *bool) { func fieldListExports(fields *FieldList, incomplete *bool) {
if fields == nil { if fields == nil {
return return
} }
...@@ -54,14 +56,14 @@ func filterFieldList(fields *FieldList, incomplete *bool) { ...@@ -54,14 +56,14 @@ func filterFieldList(fields *FieldList, incomplete *bool) {
exported = isExportedType(f.Type) exported = isExportedType(f.Type)
} else { } else {
n := len(f.Names) n := len(f.Names)
f.Names = filterIdentList(f.Names) f.Names = identListExports(f.Names)
if len(f.Names) < n { if len(f.Names) < n {
*incomplete = true *incomplete = true
} }
exported = len(f.Names) > 0 exported = len(f.Names) > 0
} }
if exported { if exported {
filterType(f.Type) typeExports(f.Type)
list[j] = f list[j] = f
j++ j++
} }
...@@ -73,49 +75,47 @@ func filterFieldList(fields *FieldList, incomplete *bool) { ...@@ -73,49 +75,47 @@ func filterFieldList(fields *FieldList, incomplete *bool) {
} }
func filterParamList(fields *FieldList) { func paramListExports(fields *FieldList) {
if fields == nil { if fields == nil {
return return
} }
for _, f := range fields.List { for _, f := range fields.List {
filterType(f.Type) typeExports(f.Type)
} }
} }
var noPos token.Position func typeExports(typ Expr) {
func filterType(typ Expr) {
switch t := typ.(type) { switch t := typ.(type) {
case *ArrayType: case *ArrayType:
filterType(t.Elt) typeExports(t.Elt)
case *StructType: case *StructType:
filterFieldList(t.Fields, &t.Incomplete) fieldListExports(t.Fields, &t.Incomplete)
case *FuncType: case *FuncType:
filterParamList(t.Params) paramListExports(t.Params)
filterParamList(t.Results) paramListExports(t.Results)
case *InterfaceType: case *InterfaceType:
filterFieldList(t.Methods, &t.Incomplete) fieldListExports(t.Methods, &t.Incomplete)
case *MapType: case *MapType:
filterType(t.Key) typeExports(t.Key)
filterType(t.Value) typeExports(t.Value)
case *ChanType: case *ChanType:
filterType(t.Value) typeExports(t.Value)
} }
} }
func filterSpec(spec Spec) bool { func specExports(spec Spec) bool {
switch s := spec.(type) { switch s := spec.(type) {
case *ValueSpec: case *ValueSpec:
s.Names = filterIdentList(s.Names) s.Names = identListExports(s.Names)
if len(s.Names) > 0 { if len(s.Names) > 0 {
filterType(s.Type) typeExports(s.Type)
return true return true
} }
case *TypeSpec: case *TypeSpec:
if s.Name.IsExported() { if s.Name.IsExported() {
filterType(s.Type) typeExports(s.Type)
return true return true
} }
} }
...@@ -123,10 +123,10 @@ func filterSpec(spec Spec) bool { ...@@ -123,10 +123,10 @@ func filterSpec(spec Spec) bool {
} }
func filterSpecList(list []Spec) []Spec { func specListExports(list []Spec) []Spec {
j := 0 j := 0
for _, s := range list { for _, s := range list {
if filterSpec(s) { if specExports(s) {
list[j] = s list[j] = s
j++ j++
} }
...@@ -135,10 +135,10 @@ func filterSpecList(list []Spec) []Spec { ...@@ -135,10 +135,10 @@ func filterSpecList(list []Spec) []Spec {
} }
func filterDecl(decl Decl) bool { func declExports(decl Decl) bool {
switch d := decl.(type) { switch d := decl.(type) {
case *GenDecl: case *GenDecl:
d.Specs = filterSpecList(d.Specs) d.Specs = specListExports(d.Specs)
return len(d.Specs) > 0 return len(d.Specs) > 0
case *FuncDecl: case *FuncDecl:
d.Body = nil // strip body d.Body = nil // strip body
...@@ -161,7 +161,7 @@ func filterDecl(decl Decl) bool { ...@@ -161,7 +161,7 @@ func filterDecl(decl Decl) bool {
func FileExports(src *File) bool { func FileExports(src *File) bool {
j := 0 j := 0
for _, d := range src.Decls { for _, d := range src.Decls {
if filterDecl(d) { if declExports(d) {
src.Decls[j] = d src.Decls[j] = d
j++ j++
} }
...@@ -189,6 +189,107 @@ func PackageExports(pkg *Package) bool { ...@@ -189,6 +189,107 @@ func PackageExports(pkg *Package) bool {
} }
// ----------------------------------------------------------------------------
// General filtering
type Filter func(string) bool
func filterIdentList(list []*Ident, f Filter) []*Ident {
j := 0
for _, x := range list {
if f(x.Name()) {
list[j] = x
j++
}
}
return list[0:j]
}
func filterSpec(spec Spec, f Filter) bool {
switch s := spec.(type) {
case *ValueSpec:
s.Names = filterIdentList(s.Names, f)
return len(s.Names) > 0
case *TypeSpec:
return f(s.Name.Name())
}
return false
}
func filterSpecList(list []Spec, f Filter) []Spec {
j := 0
for _, s := range list {
if filterSpec(s, f) {
list[j] = s
j++
}
}
return list[0:j]
}
func filterDecl(decl Decl, f Filter) bool {
switch d := decl.(type) {
case *GenDecl:
d.Specs = filterSpecList(d.Specs, f)
return len(d.Specs) > 0
case *FuncDecl:
return f(d.Name.Name())
}
return false
}
// FilterFile trims the AST for a Go file in place by removing all
// names from top-level declarations (but not from parameter lists
// or inside types) that don't pass through the filter f. If the
// declaration is empty afterwards, the declaration is removed from
// the AST.
// The File.comments list is not changed.
//
// FilterFile returns true if there are any top-level declarations
// left after filtering; it returns false otherwise.
//
func FilterFile(src *File, f Filter) bool {
j := 0
for _, d := range src.Decls {
if filterDecl(d, f) {
src.Decls[j] = d
j++
}
}
src.Decls = src.Decls[0:j]
return j > 0
}
// FilterPackage trims the AST for a Go package in place by removing all
// names from top-level declarations (but not from parameter lists
// or inside types) that don't pass through the filter f. If the
// declaration is empty afterwards, the declaration is removed from
// the AST.
// The pkg.Files list is not changed, so that file names and top-level
// package comments don't get lost.
//
// FilterPackage returns true if there are any top-level declarations
// left after filtering; it returns false otherwise.
//
func FilterPackage(pkg *Package, f Filter) bool {
hasDecls := false
for _, src := range pkg.Files {
if FilterFile(src, f) {
hasDecls = true
}
}
return hasDecls
}
// ----------------------------------------------------------------------------
// Merging of package files
// separator is an empty //-style comment that is interspersed between // separator is an empty //-style comment that is interspersed between
// different comment groups when they are concatenated into a single group // different comment groups when they are concatenated into a single group
// //
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment