Commit 877c1892 authored by Russ Cox's avatar Russ Cox

gofix: add -diff, various fixes and helpers

  * add -diff command line option
  * use scoping information in refersTo, isPkgDot, isPtrPkgDot.
  * add new scoping-based helpers countUses, rewriteUses, assignsTo, isTopName.
  * rename rewrite to walk, add walkBeforeAfter.
  * add toy typechecker, a placeholder for go/types

R=gri
CC=golang-dev
https://golang.org/cl/4285053
parent fb175cf7
This diff is collapsed.
...@@ -41,7 +41,7 @@ func httpserver(f *ast.File) bool { ...@@ -41,7 +41,7 @@ func httpserver(f *ast.File) bool {
if !ok { if !ok {
continue continue
} }
rewrite(fn.Body, func(n interface{}) { walk(fn.Body, func(n interface{}) {
// Want to replace expression sometimes, // Want to replace expression sometimes,
// so record pointer to it for updating below. // so record pointer to it for updating below.
ptr, ok := n.(*ast.Expr) ptr, ok := n.(*ast.Expr)
......
...@@ -6,6 +6,7 @@ package main ...@@ -6,6 +6,7 @@ package main
import ( import (
"bytes" "bytes"
"exec"
"flag" "flag"
"fmt" "fmt"
"go/parser" "go/parser"
...@@ -29,8 +30,10 @@ var allowedRewrites = flag.String("r", "", ...@@ -29,8 +30,10 @@ var allowedRewrites = flag.String("r", "",
var allowed map[string]bool var allowed map[string]bool
var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
func usage() { func usage() {
fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n") fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] [path ...]\n")
flag.PrintDefaults() flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
for _, f := range fixes { for _, f := range fixes {
...@@ -85,10 +88,16 @@ const ( ...@@ -85,10 +88,16 @@ const (
printerMode = printer.TabIndent | printer.UseSpaces printerMode = printer.TabIndent | printer.UseSpaces
) )
var printConfig = &printer.Config{
printerMode,
tabWidth,
}
func processFile(filename string, useStdin bool) os.Error { func processFile(filename string, useStdin bool) os.Error {
var f *os.File var f *os.File
var err os.Error var err os.Error
var fixlog bytes.Buffer
var buf bytes.Buffer
if useStdin { if useStdin {
f = os.Stdin f = os.Stdin
...@@ -110,34 +119,77 @@ func processFile(filename string, useStdin bool) os.Error { ...@@ -110,34 +119,77 @@ func processFile(filename string, useStdin bool) os.Error {
return err return err
} }
// Apply all fixes to file.
newFile := file
fixed := false fixed := false
var buf bytes.Buffer
for _, fix := range fixes { for _, fix := range fixes {
if allowed != nil && !allowed[fix.desc] { if allowed != nil && !allowed[fix.desc] {
continue continue
} }
if fix.f(file) { if fix.f(newFile) {
fixed = true fixed = true
fmt.Fprintf(&buf, " %s", fix.name) fmt.Fprintf(&fixlog, " %s", fix.name)
// AST changed.
// Print and parse, to update any missing scoping
// or position information for subsequent fixers.
buf.Reset()
_, err = printConfig.Fprint(&buf, fset, newFile)
if err != nil {
return err
}
newSrc := buf.Bytes()
newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
if err != nil {
return err
}
} }
} }
if !fixed { if !fixed {
return nil return nil
} }
fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, buf.String()[1:]) fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
// Print AST. We did that after each fix, so this appears
// redundant, but it is necessary to generate gofmt-compatible
// source code in a few cases. The official gofmt style is the
// output of the printer run on a standard AST generated by the parser,
// but the source we generated inside the loop above is the
// output of the printer run on a mangled AST generated by a fixer.
buf.Reset() buf.Reset()
_, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) _, err = printConfig.Fprint(&buf, fset, newFile)
if err != nil { if err != nil {
return err return err
} }
newSrc := buf.Bytes()
if *doDiff {
data, err := diff(src, newSrc)
if err != nil {
return fmt.Errorf("computing diff: %s", err)
}
fmt.Printf("diff %s fixed/%s\n", filename, filename)
os.Stdout.Write(data)
return nil
}
if useStdin { if useStdin {
os.Stdout.Write(buf.Bytes()) os.Stdout.Write(newSrc)
return nil return nil
} }
return ioutil.WriteFile(f.Name(), buf.Bytes(), 0) return ioutil.WriteFile(f.Name(), newSrc, 0)
}
var gofmtBuf bytes.Buffer
func gofmt(n interface{}) string {
gofmtBuf.Reset()
_, err := printConfig.Fprint(&gofmtBuf, fset, n)
if err != nil {
return "<" + err.String() + ">"
}
return gofmtBuf.String()
} }
func report(err os.Error) { func report(err os.Error) {
...@@ -177,3 +229,36 @@ func isGoFile(f *os.FileInfo) bool { ...@@ -177,3 +229,36 @@ func isGoFile(f *os.FileInfo) bool {
// ignore non-Go files // ignore non-Go files
return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go") return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go")
} }
func diff(b1, b2 []byte) (data []byte, err os.Error) {
f1, err := ioutil.TempFile("", "gofix")
if err != nil {
return nil, err
}
defer os.Remove(f1.Name())
defer f1.Close()
f2, err := ioutil.TempFile("", "gofix")
if err != nil {
return nil, err
}
defer os.Remove(f2.Name())
defer f2.Close()
f1.Write(b1)
f2.Write(b2)
diffcmd, err := exec.LookPath("diff")
if err != nil {
return nil, err
}
c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "",
exec.DevNull, exec.Pipe, exec.MergeWithStdout)
if err != nil {
return nil, err
}
defer c.Close()
return ioutil.ReadAll(c.Stdout)
}
...@@ -6,12 +6,10 @@ package main ...@@ -6,12 +6,10 @@ package main
import ( import (
"bytes" "bytes"
"exec"
"go/ast" "go/ast"
"go/parser" "go/parser"
"go/printer" "go/printer"
"io/ioutil" "strings"
"os"
"testing" "testing"
) )
...@@ -28,6 +26,8 @@ func addTestCases(t []testCase) { ...@@ -28,6 +26,8 @@ func addTestCases(t []testCase) {
testCases = append(testCases, t...) testCases = append(testCases, t...)
} }
func fnop(*ast.File) bool { return false }
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) { func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode) file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil { if err != nil {
...@@ -42,7 +42,7 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out ...@@ -42,7 +42,7 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out
t.Errorf("%s: printing: %v", desc, err) t.Errorf("%s: printing: %v", desc, err)
return return
} }
if s := buf.String(); in != s { if s := buf.String(); in != s && fn != fnop {
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, desc, in, desc, s) desc, desc, in, desc, s)
tdiff(t, in, s) tdiff(t, in, s)
...@@ -77,8 +77,17 @@ func TestRewrite(t *testing.T) { ...@@ -77,8 +77,17 @@ func TestRewrite(t *testing.T) {
continue continue
} }
// reformat to get printing right
out, _, ok = parseFixPrint(t, fnop, tt.Name, out)
if !ok {
continue
}
if out != tt.Out { if out != tt.Out {
t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out) t.Errorf("%s: incorrect output.\n", tt.Name)
if !strings.HasPrefix(tt.Name, "testdata/") {
t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
}
tdiff(t, out, tt.Out) tdiff(t, out, tt.Out)
continue continue
} }
...@@ -108,44 +117,10 @@ func TestRewrite(t *testing.T) { ...@@ -108,44 +117,10 @@ func TestRewrite(t *testing.T) {
} }
func tdiff(t *testing.T, a, b string) { func tdiff(t *testing.T, a, b string) {
f1, err := ioutil.TempFile("", "gofix") data, err := diff([]byte(a), []byte(b))
if err != nil {
t.Error(err)
return
}
defer os.Remove(f1.Name())
defer f1.Close()
f2, err := ioutil.TempFile("", "gofix")
if err != nil {
t.Error(err)
return
}
defer os.Remove(f2.Name())
defer f2.Close()
f1.Write([]byte(a))
f2.Write([]byte(b))
diffcmd, err := exec.LookPath("diff")
if err != nil {
t.Error(err)
return
}
c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "",
exec.DevNull, exec.Pipe, exec.MergeWithStdout)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer c.Close()
data, err := ioutil.ReadAll(c.Stdout)
if err != nil {
t.Error(err)
return
}
t.Error(string(data)) t.Error(string(data))
} }
...@@ -47,7 +47,7 @@ func netdial(f *ast.File) bool { ...@@ -47,7 +47,7 @@ func netdial(f *ast.File) bool {
} }
fixed := false fixed := false
rewrite(f, func(n interface{}) { walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr) call, ok := n.(*ast.CallExpr)
if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 { if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 {
return return
...@@ -70,7 +70,7 @@ func tlsdial(f *ast.File) bool { ...@@ -70,7 +70,7 @@ func tlsdial(f *ast.File) bool {
} }
fixed := false fixed := false
rewrite(f, func(n interface{}) { walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr) call, ok := n.(*ast.CallExpr)
if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 { if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 {
return return
...@@ -94,7 +94,7 @@ func netlookup(f *ast.File) bool { ...@@ -94,7 +94,7 @@ func netlookup(f *ast.File) bool {
} }
fixed := false fixed := false
rewrite(f, func(n interface{}) { walk(f, func(n interface{}) {
as, ok := n.(*ast.AssignStmt) as, ok := n.(*ast.AssignStmt)
if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 { if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 {
return return
......
...@@ -27,7 +27,7 @@ func osopen(f *ast.File) bool { ...@@ -27,7 +27,7 @@ func osopen(f *ast.File) bool {
} }
fixed := false fixed := false
rewrite(f, func(n interface{}) { walk(f, func(n interface{}) {
// Rename O_CREAT to O_CREATE. // Rename O_CREAT to O_CREATE.
if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") { if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") {
expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE" expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE"
......
...@@ -28,7 +28,7 @@ func procattr(f *ast.File) bool { ...@@ -28,7 +28,7 @@ func procattr(f *ast.File) bool {
} }
fixed := false fixed := false
rewrite(f, func(n interface{}) { walk(f, func(n interface{}) {
call, ok := n.(*ast.CallExpr) call, ok := n.(*ast.CallExpr)
if !ok || len(call.Args) != 5 { if !ok || len(call.Args) != 5 {
return return
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment