【问题标题】:How can I compare two source code files/ ast trees?如何比较两个源代码文件/ ast 树?
【发布时间】:2015-05-12 16:10:37
【问题描述】:

我正在使用templates 包生成一些源代码(有更好的方法吗?),部分测试我需要检查输出是否与预期的源代码匹配。

  • 我尝试了字符串比较,但由于模板包生成的额外空格/新行而失败。我也尝试过format.Source,但没有成功。 (失败)
  • 我尝试解析两个来源的 ast(见下文),但即使代码除了新行/空格之外基本相同,ast 也不匹配。 (失败)

    主包

    import (
        "fmt"
        "go/parser"
        "go/token"
        "reflect"
    )
    
    func main() {
        stub1 := `package main
         func myfunc(s string) error {
            return nil  
        }`
        stub2 := `package main
    
         func myfunc(s string) error {
    
            return nil
    
        }`
        fset := token.NewFileSet()
        r1, err := parser.ParseFile(fset, "", stub1, parser.AllErrors)
        if err != nil {
            panic(err)
        }
        fset = token.NewFileSet()
        r2, err := parser.ParseFile(fset, "", stub2, parser.AllErrors)
        if err != nil {
            panic(err)
        }
        if !reflect.DeepEqual(r1, r2) {
            fmt.Printf("e %v, r %s, ", r1, r2)
        }
    }
    

Playground

【问题讨论】:

  • 你想比较任意树,还是只去你已经解析的树?
  • 只是去源代码/树,因此是去标签

标签: parsing go compare abstract-syntax-tree


【解决方案1】:

嗯,实现此目的的一种简单方法是使用 go/printer 库,它可以让您更好地控制输出格式,基本上就像在源代码上运行 gofmt 一样,标准化两个树:

package main
import (
    "fmt"
    "go/parser"
    "go/token"
    "go/printer"
    //"reflect"
    "bytes"
)

func main() {
    stub1 := `package main
     func myfunc(s string) error {
        return nil  
    }`
    stub2 := `package main

     func myfunc(s string) error {

        return nil

    }`

    fset1 := token.NewFileSet()
    r1, err := parser.ParseFile(fset1, "", stub1, parser.AllErrors)
    if err != nil {
        panic(err)
    }
    fset2 := token.NewFileSet()
    r2, err := parser.ParseFile(fset1, "", stub2, parser.AllErrors)
    if err != nil {
        panic(err)
    }

    // we create two output buffers for each source tree
    out1 := bytes.NewBuffer(nil)
    out2 := bytes.NewBuffer(nil)

    // we use the same printer config for both
    conf := &printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}

    // print to both outputs
    if err := conf.Fprint(out1, fset1, r1); err != nil {
        panic(err)
    }
    if err := conf.Fprint(out2, fset2, r2); err != nil {
        panic(err)
    }


    // they should be identical!
    if string(out1.Bytes()) != string(out2.Bytes()) {
        panic(string(out1.Bytes()) +"\n" + string(out2.Bytes()))
    } else {
        fmt.Println("A-OKAY!")
    }
}

当然,这段代码需要重构才能看起来不那么愚蠢。另一种方法是不使用 DeepEqual,而是自己创建一个树比较函数,它会跳过不相关的节点。

【讨论】:

  • 比较功能可能不是微不足道的,所以我试图避免这种情况。打印机似乎在更“复杂”的结构上失败play.golang.org/p/I9cAVEYLAm
  • @mihai 可能会遍历树并过滤它,然后使用 DeepEqual?
  • 知道如何过滤不相关的节点吗?我试图删除 nil 节点,假设它们代表新的行/空格,但似乎并非如此。 play.golang.org/p/JVVpKIzela
  • A nit: bytes.Buffer 有一个 String() 方法,因此您不需要所有这些转换。
【解决方案2】:

这比我想象的要容易。我所要做的就是删除空的新行(格式化后)。下面是代码。

    package main

    import (
        "fmt"
        "go/format"
        "strings"
    )

    func main() {
        a, err := fmtSource(stub1)
        if err != nil {
            panic(err)
        }
        b, err := fmtSource(stub2)
        if err != nil {
            panic(err)
        }
        if a != b {
            fmt.Printf("a %v, \n b %v", a, b)
        }
    }

func fmtSource(source string) (string, error) {
    if !strings.Contains(source, "package") {
        source = "package main\n" + source
    }
    b, err := format.Source([]byte(source))
    if err != nil {
        return "", err
    }
    // cleanLine replaces double space with one space
    cleanLine := func(s string)string{
        sa := strings.Fields(s)
        return strings.Join(sa, " ")
    }
    lines := strings.Split(string(b), "\n")
    n := 0
    var startLn *int
    for _, line := range lines {
        if line != "" {
            line = cleanLine(line)
            lines[n] = line
            if startLn == nil {
                x := n
                startLn = &x
            }
            n++
        }
    }
    lines = lines[*startLn:n]
    // Add final "" entry to get trailing newline from Join.
    if n > 0 && lines[n-1] != "" {
        lines = append(lines, "")
    }


    // Make it pretty 
    b, err = format.Source([]byte(strings.Join(lines, "\n")))
    if err != nil {
        return "", err
    }
    return string(b), nil
}

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2013-10-19
    • 2018-10-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多