AST Manipulation and Analysis in Go
Introduction
The Abstract Syntax Tree (AST) is a tree representation of the syntactic structure of source code. Go’s ast package provides powerful tools for parsing, analyzing, and manipulating Go code programmatically. This enables you to build code analysis tools, linters, code generators, and refactoring utilities.
In this guide, you’ll learn how to parse Go code into AST, traverse and analyze the tree, and transform code programmatically. We’ll cover practical examples from simple analysis to complex code generation.
Core Concepts
What is an AST?
An AST represents the structure of source code as a tree where:
- Nodes represent language constructs (functions, variables, expressions)
- Edges represent relationships between constructs
- Leaves represent terminal symbols (identifiers, literals)
Example AST for x := 5 + 3:
AssignStmt
├── Lhs: Ident(x)
└── Rhs: BinaryExpr
    ├── X: BasicLit(5)
    ├── Op: +
    └── Y: BasicLit(3)
Why Use AST?
- Code Analysis: Understand code structure and patterns
- Linting: Detect code quality issues
- Refactoring: Automatically transform code
- Code Generation: Generate code from specifications
- Documentation: Extract and analyze documentation
- Optimization: Identify optimization opportunities
Go AST Package Structure
The standard library splits this work across the go/parser, go/ast, go/token, and go/format packages:
- Parsing: parser.ParseFile(), parser.ParseExpr()
- Traversal: ast.Walk(), ast.Inspect()
- Printing: ast.Print(), format.Node()
- Manipulation: direct modification of AST nodes
Good: Parsing and Analyzing Go Code
Basic AST Parsing
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// ✅ GOOD: Parse Go source code into AST
// ParseGoFile reads filename from disk and parses it as a Go source file,
// reporting all syntax errors (parser.AllErrors) rather than only the first
// few.
//
// The FileSet used for parsing is local and discarded, so positions stored
// in the returned AST cannot be resolved to line/column by the caller.
//
// On failure it returns a nil *ast.File and an error that wraps the
// parser's error with the prefix "parse error: ".
func ParseGoFile(filename string) (*ast.File, error) {
	positions := token.NewFileSet()
	parsed, parseErr := parser.ParseFile(positions, filename, nil, parser.AllErrors)
	if parseErr == nil {
		return parsed, nil
	}
	return nil, fmt.Errorf("parse error: %w", parseErr)
}
// ✅ GOOD: Parse Go code from string
// ParseGoCode parses code as the contents of a Go source file named
// "main.go", reporting all syntax errors (parser.AllErrors).
//
// The FileSet is discarded after parsing, so positions in the returned AST
// cannot be mapped back to line/column by the caller. On failure the
// returned *ast.File is nil and the error is prefixed with "parse error: ".
func ParseGoCode(code string) (*ast.File, error) {
	// Name recorded in positions and in the parser's error messages.
	const virtualName = "main.go"
	tree, err := parser.ParseFile(token.NewFileSet(), virtualName, code, parser.AllErrors)
	if err != nil {
		return nil, fmt.Errorf("parse error: %w", err)
	}
	return tree, nil
}
// main parses a small embedded program and prints its package name.
func main() {
	src := `
package main
import "fmt"
func main() {
fmt.Println("Hello, World!")
}
`
	parsed, err := ParseGoCode(src)
	if err != nil {
		fmt.Println(err)
		return
	}
	pkgName := parsed.Name.Name
	fmt.Printf("Package: %s\n", pkgName)
}
Traversing the AST
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// ✅ GOOD: Traverse AST using ast.Inspect
// AnalyzeFunctions parses code as a Go file and prints, for every function
// or method declaration, a "Function: <name>" line followed by one
// " Parameter: <name>" line per named parameter. Unnamed parameters (as in
// func(int)) and method receivers are not printed.
//
// If code does not parse, the error is printed as "Parse error: ..." and
// nothing is analyzed, rather than silently walking a partial tree.
func AnalyzeFunctions(code string) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "main.go", code, 0)
	if err != nil {
		fmt.Printf("Parse error: %v\n", err)
		return
	}
	ast.Inspect(file, func(node ast.Node) bool {
		fn, ok := node.(*ast.FuncDecl)
		if !ok {
			return true
		}
		fmt.Printf("Function: %s\n", fn.Name.Name)
		if fn.Type.Params != nil {
			for _, param := range fn.Type.Params.List {
				for _, name := range param.Names {
					fmt.Printf(" Parameter: %s\n", name.Name)
				}
			}
		}
		// Function declarations cannot be nested, so there is nothing more
		// to find inside this one; skip its body.
		return false
	})
}
// ✅ GOOD: Custom AST visitor
// FunctionVisitor is an ast.Visitor that records the name of every function
// or method declaration it is walked over, in visit order.
type FunctionVisitor struct {
	functions []string
}

// Visit appends the declared name when node is an *ast.FuncDecl. It always
// returns the receiver, so ast.Walk keeps descending into every subtree.
func (fv *FunctionVisitor) Visit(node ast.Node) ast.Visitor {
	decl, isFunc := node.(*ast.FuncDecl)
	if !isFunc {
		return fv
	}
	fv.functions = append(fv.functions, decl.Name.Name)
	return fv
}
// main walks a small embedded program with FunctionVisitor and prints the
// names of the functions it declares. A parse error is printed instead of
// being ignored.
func main() {
	code := `
package main
func Add(a, b int) int {
return a + b
}
func Multiply(a, b int) int {
return a * b
}
`
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "main.go", code, 0)
	if err != nil {
		fmt.Println(err)
		return
	}
	visitor := &FunctionVisitor{}
	ast.Walk(visitor, file)
	fmt.Printf("Functions found: %v\n", visitor.functions)
}
Finding Specific Code Patterns
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// ✅ GOOD: Find all function calls
// FindFunctionCalls parses code as a Go file and returns the callee of every
// call expression in depth-first traversal order (an outer call precedes the
// calls in its arguments). Duplicates are kept.
//
// Plain calls are reported by name ("len", "helper"). Calls through a
// selector are reported with their identifier qualifier chain
// ("fmt.Println", "a.b.Method"). When the qualifier is not a plain
// identifier chain (e.g. getObj().Method()), only the selected name
// ("Method") is reported. Type conversions written with an identifier, such
// as int(x), are call expressions in the AST and are reported too. Calls
// whose callee is neither an identifier nor a selector (immediately invoked
// function literals, explicit generic instantiations) are skipped.
//
// The signature has no error result, so parse errors are ignored and calls
// are collected from whatever partial AST the parser recovers.
func FindFunctionCalls(code string) []string {
	fset := token.NewFileSet()
	file, _ := parser.ParseFile(fset, "main.go", code, 0)
	var calls []string
	ast.Inspect(file, func(node ast.Node) bool {
		if call, ok := node.(*ast.CallExpr); ok {
			if name, ok := calleeName(call.Fun); ok {
				calls = append(calls, name)
			}
		}
		return true
	})
	return calls
}

// calleeName renders an identifier or selector callee as text. It reports
// false for any other expression kind.
func calleeName(expr ast.Expr) (string, bool) {
	switch e := expr.(type) {
	case *ast.Ident:
		return e.Name, true
	case *ast.SelectorExpr:
		if prefix, ok := calleeName(e.X); ok {
			return prefix + "." + e.Sel.Name, true
		}
		return e.Sel.Name, true
	}
	return "", false
}
// ✅ GOOD: Find all variable declarations
// FindVariables parses code as a Go file and maps the name of every variable
// declared with a var declaration (at any nesting level) to its declared
// type, rendered as Go source text, e.g. "int", "[]string",
// "map[string]*os.File".
//
// Variables whose type is inferred from an initializer (var x = 5) map to
// "". Variables introduced with := are not included. Type forms that
// typeExprString does not render fall back to the Go type name of the AST
// node (e.g. "*ast.FuncType"). If a name is declared more than once, the
// last declaration visited wins.
//
// The signature has no error result, so parse errors are ignored and
// declarations are collected from whatever partial AST the parser recovers.
func FindVariables(code string) map[string]string {
	fset := token.NewFileSet()
	file, _ := parser.ParseFile(fset, "main.go", code, 0)
	vars := make(map[string]string)
	ast.Inspect(file, func(node ast.Node) bool {
		decl, ok := node.(*ast.GenDecl)
		if !ok || decl.Tok != token.VAR {
			return true
		}
		for _, spec := range decl.Specs {
			vspec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			typ := typeExprString(vspec.Type)
			for _, name := range vspec.Names {
				vars[name.Name] = typ
			}
		}
		return true
	})
	return vars
}

// typeExprString renders a type expression as Go source for identifiers,
// qualified names, pointers, arrays, slices, maps, channels, variadics, and
// empty interface/struct types. A nil expression renders as ""; any other
// node renders as its Go type name (e.g. "*ast.FuncType").
func typeExprString(expr ast.Expr) string {
	switch e := expr.(type) {
	case nil:
		return ""
	case *ast.Ident:
		return e.Name
	case *ast.BasicLit:
		// Array lengths such as the 4 in [4]int.
		return e.Value
	case *ast.SelectorExpr:
		return typeExprString(e.X) + "." + e.Sel.Name
	case *ast.StarExpr:
		return "*" + typeExprString(e.X)
	case *ast.ParenExpr:
		return "(" + typeExprString(e.X) + ")"
	case *ast.Ellipsis:
		return "..." + typeExprString(e.Elt)
	case *ast.ArrayType:
		// Len is nil for slices, which yields "[]".
		return "[" + typeExprString(e.Len) + "]" + typeExprString(e.Elt)
	case *ast.MapType:
		return "map[" + typeExprString(e.Key) + "]" + typeExprString(e.Value)
	case *ast.ChanType:
		switch e.Dir {
		case ast.SEND:
			return "chan<- " + typeExprString(e.Value)
		case ast.RECV:
			return "<-chan " + typeExprString(e.Value)
		default:
			return "chan " + typeExprString(e.Value)
		}
	case *ast.InterfaceType:
		if e.Methods == nil || len(e.Methods.List) == 0 {
			return "interface{}"
		}
	case *ast.StructType:
		if e.Fields == nil || len(e.Fields.List) == 0 {
			return "struct{}"
		}
	}
	return fmt.Sprintf("%T", expr)
}
// ✅ GOOD: Find unused variables
// FindUnusedVariables parses code as a Go file and returns the names of
// variables that are declared (with var or :=) but never referenced
// anywhere else in the file. Results are in first-declaration order; the
// result is nil when every variable is used.
//
// Matching is by name only and ignores scope, so a name used in one function
// hides an unused variable of the same name in another. The blank
// identifier, range-loop variables, and function parameters are not
// tracked.
//
// The signature has no error result, so parse errors are ignored and
// whatever partial AST the parser recovers is analyzed.
func FindUnusedVariables(code string) []string {
	fset := token.NewFileSet()
	file, _ := parser.ParseFile(fset, "main.go", code, 0)

	// order lists declared names in first-declaration order; declared is the
	// matching membership set.
	var order []string
	declared := make(map[string]bool)
	// declIdents holds every identifier that introduces a variable, so that
	// the declaration itself is not mistaken for a use.
	declIdents := make(map[*ast.Ident]bool)
	used := make(map[string]bool)

	declare := func(id *ast.Ident) {
		declIdents[id] = true
		if id.Name == "_" || declared[id.Name] {
			return
		}
		declared[id.Name] = true
		order = append(order, id.Name)
	}

	ast.Inspect(file, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.GenDecl:
			if n.Tok == token.VAR {
				for _, spec := range n.Specs {
					if vspec, ok := spec.(*ast.ValueSpec); ok {
						for _, name := range vspec.Names {
							declare(name)
						}
					}
				}
			}
		case *ast.AssignStmt:
			if n.Tok == token.DEFINE {
				for _, lhs := range n.Lhs {
					if id, ok := lhs.(*ast.Ident); ok {
						declare(id)
					}
				}
			}
		case *ast.Ident:
			// Inspect visits a declaration before the identifiers it
			// contains, so declaring identifiers are already recorded here.
			if !declIdents[n] {
				used[n.Name] = true
			}
		}
		return true
	})

	var unused []string
	for _, name := range order {
		if !used[name] {
			unused = append(unused, name)
		}
	}
	return unused
}
// main runs each pattern finder over the same embedded program and prints
// the results.
func main() {
	src := `
package main
import "fmt"
func main() {
x := 10
y := 20
fmt.Println(x)
}
`
	calls := FindFunctionCalls(src)
	vars := FindVariables(src)
	unused := FindUnusedVariables(src)

	fmt.Println("Function calls:", calls)
	fmt.Println("Variables:", vars)
	fmt.Println("Unused variables:", unused)
}
Bad: Inefficient AST Analysis
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// ❌ BAD: Parsing the same file multiple times
// BadAnalysis demonstrates an anti-pattern. It parses filename three times,
// each time with a fresh FileSet, to feed three analyses that a single
// parse could serve. It also ignores every parse error.
//
// The blank assignment at the end exists only so the example compiles (Go
// rejects unused local variables). It does not make the redundant parsing
// any less wasteful.
func BadAnalysis(filename string) {
	// Parse for functions
	fset1 := token.NewFileSet()
	file1, _ := parser.ParseFile(fset1, filename, nil, 0)
	// Parse for variables
	fset2 := token.NewFileSet()
	file2, _ := parser.ParseFile(fset2, filename, nil, 0)
	// Parse for calls
	fset3 := token.NewFileSet()
	file3, _ := parser.ParseFile(fset3, filename, nil, 0)
	// Redundant parsing!
	_, _, _ = file1, file2, file3
}
// ❌ BAD: Inefficient tree traversal
// BadTraversal demonstrates an anti-pattern: it walks the whole tree three
// separate times where one ast.Inspect callback could perform every analysis
// in a single pass. The callback does nothing and always returns true, so
// each pass visits every node.
func BadTraversal(file *ast.File) {
	visitAll := func(ast.Node) bool {
		// Do analysis
		return true
	}
	// Traversing multiple times for different analyses
	for pass := 1; pass <= 3; pass++ {
		ast.Inspect(file, visitAll)
	}
}
// ❌ BAD: No error handling
// BadParsing demonstrates an anti-pattern. It parses code as "main.go" and
// returns the resulting AST while discarding the parse error, so callers
// cannot tell a complete tree from a partial one that the parser recovered
// after syntax errors.
func BadParsing(code string) *ast.File {
	tree, _ := parser.ParseFile(token.NewFileSet(), "main.go", code, 0)
	return tree // Ignores parse errors
}
Advanced Patterns
Building a Custom Linter
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// Issue represents a linting issue. Line and Column are the 1-based
// position of the offending node (the column counted in bytes). Both are 0
// when no position is available, for example when the file could not be
// parsed.
type Issue struct {
	Line    int
	Column  int
	Message string
}

// defaultMaxComplexity is the complexity threshold used when
// Linter.MaxComplexity is zero or negative.
const defaultMaxComplexity = 10

// Linter performs code analysis on a single Go file.
//
// The zero value is ready to use through Lint. The Check* methods may also
// be called directly; in that case fset should be set to the FileSet the
// file was parsed with, otherwise issues are reported without positions.
type Linter struct {
	issues []Issue
	fset   *token.FileSet

	// MaxComplexity is the highest complexity a function may have before it
	// is reported. Zero or negative means defaultMaxComplexity (10).
	MaxComplexity int
}

// position resolves pos through the linter's FileSet. It returns the zero
// Position (line and column 0) when no FileSet has been set.
func (l *Linter) position(pos token.Pos) token.Position {
	if l.fset == nil {
		return token.Position{}
	}
	return l.fset.Position(pos)
}

// report records an issue with message msg at pos.
func (l *Linter) report(pos token.Pos, msg string) {
	p := l.position(pos)
	l.issues = append(l.issues, Issue{Line: p.Line, Column: p.Column, Message: msg})
}

// CheckUnusedVariables reports variables declared with var or := whose name
// is never referenced anywhere else in file. Each issue points at the
// declaring identifier, and issues are emitted in declaration order.
//
// Matching is by name only and ignores scope, so a name used in one function
// hides an unused variable of the same name in another. The blank
// identifier, range-loop variables, and function parameters are not
// tracked.
func (l *Linter) CheckUnusedVariables(file *ast.File) {
	// decls holds the first declaring identifier for each name, in order;
	// seen is the matching membership set.
	var decls []*ast.Ident
	seen := make(map[string]bool)
	// declIdents holds every identifier that introduces a variable, so that
	// the declaration itself is not counted as a use.
	declIdents := make(map[*ast.Ident]bool)
	used := make(map[string]bool)

	declare := func(id *ast.Ident) {
		declIdents[id] = true
		if id.Name == "_" || seen[id.Name] {
			return
		}
		seen[id.Name] = true
		decls = append(decls, id)
	}

	ast.Inspect(file, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.GenDecl:
			if n.Tok == token.VAR {
				for _, spec := range n.Specs {
					if vspec, ok := spec.(*ast.ValueSpec); ok {
						for _, name := range vspec.Names {
							declare(name)
						}
					}
				}
			}
		case *ast.AssignStmt:
			if n.Tok == token.DEFINE {
				for _, lhs := range n.Lhs {
					if id, ok := lhs.(*ast.Ident); ok {
						declare(id)
					}
				}
			}
		case *ast.Ident:
			// Inspect visits a declaration before the identifiers it
			// contains, so declaring identifiers are already recorded here.
			if !declIdents[n] {
				used[n.Name] = true
			}
		}
		return true
	})

	for _, id := range decls {
		if !used[id.Name] {
			l.report(id.Pos(), fmt.Sprintf("unused variable: %s", id.Name))
		}
	}
}

// CheckComplexFunctions reports every function or method whose complexity
// (see calculateComplexity) exceeds MaxComplexity, or 10 when MaxComplexity
// is not positive. Declarations without a body are skipped.
func (l *Linter) CheckComplexFunctions(file *ast.File) {
	limit := l.MaxComplexity
	if limit <= 0 {
		limit = defaultMaxComplexity
	}
	// Function declarations only occur at the top level of a file.
	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Body == nil {
			// A nil Body would make ast.Inspect panic.
			continue
		}
		if complexity := l.calculateComplexity(fn.Body); complexity > limit {
			l.report(fn.Pos(), fmt.Sprintf("function too complex: %d", complexity))
		}
	}
}

// calculateComplexity returns 1 plus the number of branching statements
// (if, for, range, switch, type switch, select) inside stmt, including those
// in nested function literals. stmt must be non-nil.
func (l *Linter) calculateComplexity(stmt ast.Stmt) int {
	complexity := 1
	ast.Inspect(stmt, func(node ast.Node) bool {
		switch node.(type) {
		case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt,
			*ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt:
			complexity++
		}
		return true
	})
	return complexity
}

// Lint parses filename and runs every check on it, returning the issues
// found. Issues from previous calls are discarded, so a Linter can be
// reused.
//
// If the file cannot be read or parsed, Lint returns a single position-less
// issue describing the failure instead of an empty result that would look
// like a clean file.
func (l *Linter) Lint(filename string) []Issue {
	l.issues = nil
	l.fset = token.NewFileSet()
	file, err := parser.ParseFile(l.fset, filename, nil, 0)
	if err != nil {
		l.issues = append(l.issues, Issue{Message: fmt.Sprintf("could not parse file: %v", err)})
		return l.issues
	}
	l.CheckUnusedVariables(file)
	l.CheckComplexFunctions(file)
	return l.issues
}
// main lints an embedded program whose only function nests eleven if
// statements, which exceeds the complexity limit of 10.
func main() {
	code := `
package main
func main() {
x := 10
if true {
if true {
if true {
if true {
if true {
if true {
if true {
if true {
if true {
if true {
if true {
}
}
}
}
}
}
}
}
}
}
}
}
`
	// Lint reads its input from disk, so parse the in-memory source here
	// and run the individual checks directly. The linter must share this
	// FileSet so that reported positions refer to this source.
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "main.go", code, 0)
	if err != nil {
		fmt.Println(err)
		return
	}
	linter := &Linter{fset: fset}
	linter.CheckUnusedVariables(file)
	linter.CheckComplexFunctions(file)
	for _, issue := range linter.issues {
		fmt.Printf("Line %d: %s\n", issue.Line, issue.Message)
	}
}
Code Generation from AST
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
)
// ✅ GOOD: Generate code from AST
// GenerateStringer builds, as an AST, a package main file containing a
// String method for the struct named structName, and returns it formatted
// by go/format. For GenerateStringer("Person", []string{"Name", "Age"}) the
// method is:
//
//	func (s Person) String() string {
//		return fmt.Sprintf("Person{Name: %v, Age: %v}", s.Name, s.Age)
//	}
//
// Each name in fields becomes a "Name: %v" entry, in order, and an
// `import "fmt"` declaration is generated. With no fields the method returns
// the constant "Person{}" and no import is generated. structName and fields
// are assumed to be valid Go identifiers; they are not validated.
//
// It panics if go/format rejects the constructed AST, because that would
// indicate a bug in the construction below rather than bad input.
func GenerateStringer(structName string, fields []string) string {
	// quote builds a string literal node; %q yields valid Go string syntax.
	quote := func(s string) *ast.BasicLit {
		return &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", s)}
	}

	file := &ast.File{Name: ast.NewIdent("main")}

	var result ast.Expr
	if len(fields) == 0 {
		result = quote(structName + "{}")
	} else {
		layout := structName + "{"
		// args[0] is reserved for the format string, filled in after the loop.
		args := make([]ast.Expr, 1, len(fields)+1)
		for i, field := range fields {
			if i > 0 {
				layout += ", "
			}
			layout += field + ": %v"
			args = append(args, &ast.SelectorExpr{X: ast.NewIdent("s"), Sel: ast.NewIdent(field)})
		}
		args[0] = quote(layout + "}")
		result = &ast.CallExpr{
			Fun:  &ast.SelectorExpr{X: ast.NewIdent("fmt"), Sel: ast.NewIdent("Sprintf")},
			Args: args,
		}
		file.Decls = append(file.Decls, &ast.GenDecl{
			Tok:   token.IMPORT,
			Specs: []ast.Spec{&ast.ImportSpec{Path: quote("fmt")}},
		})
	}

	// A receiver (Recv) is what makes this a method rather than a plain
	// function taking the struct as an argument.
	method := &ast.FuncDecl{
		Recv: &ast.FieldList{
			List: []*ast.Field{
				{
					Names: []*ast.Ident{ast.NewIdent("s")},
					Type:  ast.NewIdent(structName),
				},
			},
		},
		Name: ast.NewIdent("String"),
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{
				List: []*ast.Field{
					{Type: ast.NewIdent("string")},
				},
			},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ReturnStmt{Results: []ast.Expr{result}},
			},
		},
	}
	file.Decls = append(file.Decls, method)

	var buf bytes.Buffer
	if err := format.Node(&buf, token.NewFileSet(), file); err != nil {
		panic(fmt.Sprintf("GenerateStringer: formatting generated AST: %v", err))
	}
	return buf.String()
}
// main generates a String method for a Person struct, checks that the
// output parses as Go, and prints it.
func main() {
	code := GenerateStringer("Person", []string{"Name", "Age"})
	// Round-tripping the generated source through the parser is a cheap way
	// to catch malformed output before anyone tries to compile it.
	if _, err := parser.ParseFile(token.NewFileSet(), "generated.go", code, 0); err != nil {
		fmt.Println("generated code does not parse:", err)
		return
	}
	fmt.Println(code)
}
Best Practices
1. Reuse FileSet
// ✅ GOOD: Single FileSet for multiple files
fset := token.NewFileSet()
file1, _ := parser.ParseFile(fset, "file1.go", nil, 0)
file2, _ := parser.ParseFile(fset, "file2.go", nil, 0)
// ❌ BAD: Multiple FileSets
fset1 := token.NewFileSet()
file1, _ := parser.ParseFile(fset1, "file1.go", nil, 0)
fset2 := token.NewFileSet()
file2, _ := parser.ParseFile(fset2, "file2.go", nil, 0)
2. Handle Parse Errors
// ✅ GOOD: Check for errors
file, err := parser.ParseFile(fset, "main.go", code, parser.AllErrors)
if err != nil {
fmt.Printf("Parse error: %v\n", err)
return
}
// ❌ BAD: Ignore errors
file, _ := parser.ParseFile(fset, "main.go", code, 0)
3. Use Appropriate Visitor Pattern
// ✅ GOOD: Use ast.Inspect for simple traversal
ast.Inspect(file, func(node ast.Node) bool {
// Process node
return true
})
// ✅ GOOD: Use ast.Walk for custom visitor
// MyVisitor is a skeleton ast.Visitor; put per-node logic in Visit.
type MyVisitor struct{}

// Visit is called by ast.Walk for each node. Returning the receiver (rather
// than nil) tells Walk to continue into the node's children.
func (mv *MyVisitor) Visit(node ast.Node) ast.Visitor {
	// Process node
	return mv
}
ast.Walk(&MyVisitor{}, file)
Resources
- Go AST Package: https://pkg.go.dev/go/ast
- Go Parser Package: https://pkg.go.dev/go/parser
- Go Format Package: https://pkg.go.dev/go/format
- AST Visualization: https://astexplorer.net/
- Writing Go analysis tools (go/analysis framework): https://pkg.go.dev/golang.org/x/tools/go/analysis
Summary
AST manipulation is a powerful technique for building code analysis tools, linters, and generators. By understanding how to parse, traverse, and analyze Go code, you can:
- Build linters to enforce code quality
- Generate code automatically
- Refactor code programmatically
- Analyze patterns in your codebase
- Create development tools that understand Go syntax
Start with simple traversals using ast.Inspect(), then progress to custom visitors and code generation as your needs grow.
Comments