// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
func initRewrite() {
|
|
if *rewriteRule == "" {
|
|
rewrite = nil // disable any previous rewrite
|
|
return
|
|
}
|
|
f := strings.Split(*rewriteRule, "->")
|
|
if len(f) != 2 {
|
|
fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
|
|
os.Exit(2)
|
|
}
|
|
pattern := parseExpr(f[0], "pattern")
|
|
replace := parseExpr(f[1], "replacement")
|
|
rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
|
|
return rewriteFile(fset, pattern, replace, p)
|
|
}
|
|
}
|
|
|
|
// parseExpr parses s as a Go expression and returns the resulting AST.
// If s does not parse, a diagnostic naming what (the role of s, used only in
// the message) is written to standard error and the process exits with
// status 2.
//
// It might make sense to expand this to allow statement patterns,
// but there are problems with preserving formatting and also
// with what a wildcard for a statement looks like.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return expr // not reached: os.Exit does not return
}
|
|
|
|
// Keep this function for debugging.
/*
func dump(msg string, val reflect.Value) {
	fmt.Printf("%s:\n", msg)
	ast.Print(fileSet, val.Interface())
	fmt.Println()
}
*/
|
|
|
|
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
|
|
func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
|
|
cmap := ast.NewCommentMap(fileSet, p, p.Comments)
|
|
m := make(map[string]reflect.Value)
|
|
pat := reflect.ValueOf(pattern)
|
|
repl := reflect.ValueOf(replace)
|
|
|
|
var rewriteVal func(val reflect.Value) reflect.Value
|
|
rewriteVal = func(val reflect.Value) reflect.Value {
|
|
// don't bother if val is invalid to start with
|
|
if !val.IsValid() {
|
|
return reflect.Value{}
|
|
}
|
|
val = apply(rewriteVal, val)
|
|
for k := range m {
|
|
delete(m, k)
|
|
}
|
|
if match(m, pat, val) {
|
|
val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
|
|
}
|
|
return val
|
|
}
|
|
|
|
r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
|
|
r.Comments = cmap.Filter(r).Comments() // recreate comments list
|
|
return r
|
|
}
|
|
|
|
// set is a wrapper for x.Set(y); it protects the caller from panics if x
// cannot be changed to y.
//
// The call is a no-op when x is not settable, when y is the zero
// reflect.Value, or when y's type is not assignable to x's type. The
// assignability check is made up front with reflect.Type.AssignableTo rather
// than by recovering from Set and inspecting the panic message, so it does not
// depend on the wording of reflect's panics. Any other panic raised by Set
// propagates to the caller.
func set(x, y reflect.Value) {
	// don't bother if x cannot be set or y is invalid
	if !x.CanSet() || !y.IsValid() {
		return
	}
	// x cannot be set to y - ignore this rewrite
	if !y.Type().AssignableTo(x.Type()) {
		return
	}
	x.Set(y)
}
|
|
|
|
// Values/types for special cases.

// reflect types that apply, match and subst compare against.
var (
	identType     = reflect.TypeOf((*ast.Ident)(nil))
	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
	positionType  = reflect.TypeOf(token.NoPos)
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
)

// Typed nil pointers that apply returns in place of *ast.Object and
// *ast.Scope values.
var (
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
)
|
|
|
|
// apply replaces each AST field x in val with f(x), returning val.
|
|
// To avoid extra conversions, f operates on the reflect.Value form.
|
|
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
|
|
if !val.IsValid() {
|
|
return reflect.Value{}
|
|
}
|
|
|
|
// *ast.Objects introduce cycles and are likely incorrect after
|
|
// rewrite; don't follow them but replace with nil instead
|
|
if val.Type() == objectPtrType {
|
|
return objectPtrNil
|
|
}
|
|
|
|
// similarly for scopes: they are likely incorrect after a rewrite;
|
|
// replace them with nil
|
|
if val.Type() == scopePtrType {
|
|
return scopePtrNil
|
|
}
|
|
|
|
switch v := reflect.Indirect(val); v.Kind() {
|
|
case reflect.Slice:
|
|
for i := 0; i < v.Len(); i++ {
|
|
e := v.Index(i)
|
|
set(e, f(e))
|
|
}
|
|
case reflect.Struct:
|
|
for i := 0; i < v.NumField(); i++ {
|
|
e := v.Field(i)
|
|
set(e, f(e))
|
|
}
|
|
case reflect.Interface:
|
|
e := v.Elem()
|
|
set(v, f(e))
|
|
}
|
|
return val
|
|
}
|
|
|
|
// isWildcard reports whether s is a wildcard name: a string that decodes as
// exactly one rune for which unicode.IsLower is true.
func isWildcard(s string) bool {
	r, n := utf8.DecodeRuneInString(s)
	if n != len(s) {
		// empty remainder check failed: s holds more than one rune
		return false
	}
	return unicode.IsLower(r)
}
|
|
|
|
// match reports whether pattern matches val,
|
|
// recording wildcard submatches in m.
|
|
// If m == nil, match checks whether pattern == val.
|
|
func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
|
|
// Wildcard matches any expression. If it appears multiple
|
|
// times in the pattern, it must match the same expression
|
|
// each time.
|
|
if m != nil && pattern.IsValid() && pattern.Type() == identType {
|
|
name := pattern.Interface().(*ast.Ident).Name
|
|
if isWildcard(name) && val.IsValid() {
|
|
// wildcards only match valid (non-nil) expressions.
|
|
if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
|
|
if old, ok := m[name]; ok {
|
|
return match(nil, old, val)
|
|
}
|
|
m[name] = val
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Otherwise, pattern and val must match recursively.
|
|
if !pattern.IsValid() || !val.IsValid() {
|
|
return !pattern.IsValid() && !val.IsValid()
|
|
}
|
|
if pattern.Type() != val.Type() {
|
|
return false
|
|
}
|
|
|
|
// Special cases.
|
|
switch pattern.Type() {
|
|
case identType:
|
|
// For identifiers, only the names need to match
|
|
// (and none of the other *ast.Object information).
|
|
// This is a common case, handle it all here instead
|
|
// of recursing down any further via reflection.
|
|
p := pattern.Interface().(*ast.Ident)
|
|
v := val.Interface().(*ast.Ident)
|
|
return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
|
|
case objectPtrType, positionType:
|
|
// object pointers and token positions always match
|
|
return true
|
|
case callExprType:
|
|
// For calls, the Ellipsis fields (token.Position) must
|
|
// match since that is how f(x) and f(x...) are different.
|
|
// Check them here but fall through for the remaining fields.
|
|
p := pattern.Interface().(*ast.CallExpr)
|
|
v := val.Interface().(*ast.CallExpr)
|
|
if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
|
|
return false
|
|
}
|
|
}
|
|
|
|
p := reflect.Indirect(pattern)
|
|
v := reflect.Indirect(val)
|
|
if !p.IsValid() || !v.IsValid() {
|
|
return !p.IsValid() && !v.IsValid()
|
|
}
|
|
|
|
switch p.Kind() {
|
|
case reflect.Slice:
|
|
if p.Len() != v.Len() {
|
|
return false
|
|
}
|
|
for i := 0; i < p.Len(); i++ {
|
|
if !match(m, p.Index(i), v.Index(i)) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
|
|
case reflect.Struct:
|
|
for i := 0; i < p.NumField(); i++ {
|
|
if !match(m, p.Field(i), v.Field(i)) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
|
|
case reflect.Interface:
|
|
return match(m, p.Elem(), v.Elem())
|
|
}
|
|
|
|
// Handle token integers, etc.
|
|
return p.Interface() == v.Interface()
|
|
}
|
|
|
|
// subst returns a copy of pattern with values from m substituted in place
|
|
// of wildcards and pos used as the position of tokens from the pattern.
|
|
// if m == nil, subst returns a copy of pattern and doesn't change the line
|
|
// number information.
|
|
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
|
|
if !pattern.IsValid() {
|
|
return reflect.Value{}
|
|
}
|
|
|
|
// Wildcard gets replaced with map value.
|
|
if m != nil && pattern.Type() == identType {
|
|
name := pattern.Interface().(*ast.Ident).Name
|
|
if isWildcard(name) {
|
|
if old, ok := m[name]; ok {
|
|
return subst(nil, old, reflect.Value{})
|
|
}
|
|
}
|
|
}
|
|
|
|
if pos.IsValid() && pattern.Type() == positionType {
|
|
// use new position only if old position was valid in the first place
|
|
if old := pattern.Interface().(token.Pos); !old.IsValid() {
|
|
return pattern
|
|
}
|
|
return pos
|
|
}
|
|
|
|
// Otherwise copy.
|
|
switch p := pattern; p.Kind() {
|
|
case reflect.Slice:
|
|
if p.IsNil() {
|
|
// Do not turn nil slices into empty slices. go/ast
|
|
// guarantees that certain lists will be nil if not
|
|
// populated.
|
|
return reflect.Zero(p.Type())
|
|
}
|
|
v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
|
|
for i := 0; i < p.Len(); i++ {
|
|
v.Index(i).Set(subst(m, p.Index(i), pos))
|
|
}
|
|
return v
|
|
|
|
case reflect.Struct:
|
|
v := reflect.New(p.Type()).Elem()
|
|
for i := 0; i < p.NumField(); i++ {
|
|
v.Field(i).Set(subst(m, p.Field(i), pos))
|
|
}
|
|
return v
|
|
|
|
case reflect.Pointer:
|
|
v := reflect.New(p.Type()).Elem()
|
|
if elem := p.Elem(); elem.IsValid() {
|
|
v.Set(subst(m, elem, pos).Addr())
|
|
}
|
|
return v
|
|
|
|
case reflect.Interface:
|
|
v := reflect.New(p.Type()).Elem()
|
|
if elem := p.Elem(); elem.IsValid() {
|
|
v.Set(subst(m, elem, pos))
|
|
}
|
|
return v
|
|
}
|
|
|
|
return pattern
|
|
}
|