Skip to main content
⚡ Calmops

AST Manipulation and Analysis in Go

AST Manipulation and Analysis in Go

Introduction

The Abstract Syntax Tree (AST) is a tree representation of the syntactic structure of source code. Go’s ast package provides powerful tools for parsing, analyzing, and manipulating Go code programmatically. This enables you to build code analysis tools, linters, code generators, and refactoring utilities.

In this guide, you’ll learn how to parse Go code into AST, traverse and analyze the tree, and transform code programmatically. We’ll cover practical examples from simple analysis to complex code generation.

Core Concepts

What is an AST?

An AST represents the structure of source code as a tree where:

  • Nodes represent language constructs (functions, variables, expressions)
  • Edges represent relationships between constructs
  • Leaves represent terminal symbols (identifiers, literals)

Example AST for x := 5 + 3:

AssignStmt
├── Lhs: Ident(x)
└── Rhs: BinaryExpr
    ├── X: BasicLit(5)
    ├── Op: +
    └── Y: BasicLit(3)

Why Use AST?

  • Code Analysis: Understand code structure and patterns
  • Linting: Detect code quality issues
  • Refactoring: Automatically transform code
  • Code Generation: Generate code from specifications
  • Documentation: Extract and analyze documentation
  • Optimization: Identify optimization opportunities

Go AST Package Structure

The go/ast package, together with its companion packages go/parser, go/token, and go/format, provides:

  • Parsing: parser.ParseFile(), parser.ParseExpr()
  • Traversal: ast.Walk(), ast.Inspect()
  • Printing: ast.Print(), format.Node()
  • Manipulation: Direct node modification

Good: Parsing and Analyzing Go Code

Basic AST Parsing

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// ✅ GOOD: Parse Go source code into AST
//
// ParseGoFile parses the Go source file named by filename and returns its
// AST. The parser runs with parser.AllErrors. On failure the file result is
// nil and the parser's error is returned wrapped behind a "parse error: "
// prefix, so callers can still unwrap it.
func ParseGoFile(filename string) (*ast.File, error) {
	parsed, err := parser.ParseFile(token.NewFileSet(), filename, nil, parser.AllErrors)
	if err == nil {
		return parsed, nil
	}
	return nil, fmt.Errorf("parse error: %w", err)
}

// ✅ GOOD: Parse Go code from string
//
// ParseGoCode parses code as the contents of a file called "main.go" and
// returns its AST. The parser runs with parser.AllErrors. On failure the file
// result is nil and the parser's error is returned wrapped behind a
// "parse error: " prefix.
func ParseGoCode(code string) (*ast.File, error) {
	parsed, err := parser.ParseFile(token.NewFileSet(), "main.go", code, parser.AllErrors)
	if err == nil {
		return parsed, nil
	}
	return nil, fmt.Errorf("parse error: %w", err)
}

// main parses a small in-memory program and prints its package name, or the
// parse error if the snippet is rejected.
func main() {
	const src = `
package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`

	switch file, err := ParseGoCode(src); {
	case err != nil:
		fmt.Println(err)
	default:
		fmt.Printf("Package: %s\n", file.Name.Name)
	}
}

Traversing the AST

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// ✅ GOOD: Traverse AST using ast.Inspect
//
// AnalyzeFunctions parses code and prints one "Function: <name>" line for each
// function declaration, followed by a "  Parameter: <name>" line for each
// named parameter. Unnamed parameters produce no output.
func AnalyzeFunctions(code string) {
	files := token.NewFileSet()
	// The parse error is intentionally discarded: whatever tree the parser
	// returns is analyzed on a best-effort basis.
	tree, _ := parser.ParseFile(files, "main.go", code, 0)

	report := func(n ast.Node) bool {
		decl, isFunc := n.(*ast.FuncDecl)
		if !isFunc {
			return true
		}
		fmt.Printf("Function: %s\n", decl.Name.Name)

		if decl.Type.Params == nil {
			return true
		}
		for _, field := range decl.Type.Params.List {
			for _, ident := range field.Names {
				fmt.Printf("  Parameter: %s\n", ident.Name)
			}
		}
		return true
	}

	ast.Inspect(tree, report)
}

// ✅ GOOD: Custom AST visitor
//
// FunctionVisitor is an ast.Visitor that records the name of every function
// declaration it is shown, in visiting order.
type FunctionVisitor struct {
	functions []string
}

// Visit appends the declaration's name when node is a function declaration.
// It always returns the receiver, so ast.Walk keeps descending into children.
func (v *FunctionVisitor) Visit(node ast.Node) ast.Visitor {
	switch decl := node.(type) {
	case *ast.FuncDecl:
		v.functions = append(v.functions, decl.Name.Name)
	}
	return v
}

// main walks a two-function snippet with a FunctionVisitor and prints the
// names it collected.
func main() {
	const code = `
package main

func Add(a, b int) int {
	return a + b
}

func Multiply(a, b int) int {
	return a * b
}
`

	// The snippet is a fixed literal, so the parse error is intentionally
	// discarded here.
	file, _ := parser.ParseFile(token.NewFileSet(), "main.go", code, 0)

	var visitor FunctionVisitor
	ast.Walk(&visitor, file)

	fmt.Printf("Functions found: %v\n", visitor.functions)
}

Finding Specific Code Patterns

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

// ✅ GOOD: Find all function calls
//
// FindFunctionCalls parses code and returns the callee of every call
// expression in traversal order (an outer call precedes the calls nested in
// its arguments).
//
// Plain calls are reported by name ("foo"). Qualified and method calls are
// reported with their selector chain ("fmt.Println", "a.b.c"). When the
// receiver is not a plain identifier chain, as in f().Close(), only the
// selected name ("Close") is reported. Calls whose callee has no name at all,
// such as an immediately invoked function literal, are skipped.
func FindFunctionCalls(code string) []string {
	fset := token.NewFileSet()
	// Best effort: this signature has no way to report a parse error, so
	// whatever tree the parser returns is inspected.
	file, _ := parser.ParseFile(fset, "main.go", code, 0)

	var calls []string

	ast.Inspect(file, func(node ast.Node) bool {
		call, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		if name := calleeName(call.Fun); name != "" {
			calls = append(calls, name)
		}
		return true
	})

	return calls
}

// calleeName renders the expression being called as a dotted name, or returns
// "" when the expression is neither an identifier nor a selector.
func calleeName(fun ast.Expr) string {
	switch fn := fun.(type) {
	case *ast.Ident:
		return fn.Name
	case *ast.SelectorExpr:
		if base := calleeName(fn.X); base != "" {
			return base + "." + fn.Sel.Name
		}
		return fn.Sel.Name
	}
	return ""
}

// ✅ GOOD: Find all variable declarations
//
// FindVariables parses code and returns a map from variable name to the
// source text of its declared type.
//
// Both `var` declarations and short variable declarations (`:=`) are
// collected. The type is rendered as Go source ("int", "[]string",
// "map[string]int"). A variable declared without an explicit type
// (`var x = 1`, `x := 1`) maps to the empty string, because the type is
// inferred and not present in the syntax tree. The blank identifier is
// skipped. Names are not scope-aware: a name declared more than once appears
// once, and a `:=` never overwrites a type recorded by an earlier `var`.
// Range variables and function parameters are not included.
func FindVariables(code string) map[string]string {
	fset := token.NewFileSet()
	// Best effort: this signature has no way to report a parse error, so
	// whatever tree the parser returns is inspected.
	file, _ := parser.ParseFile(fset, "main.go", code, 0)

	vars := make(map[string]string)

	ast.Inspect(file, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.GenDecl:
			if n.Tok != token.VAR {
				return true
			}
			for _, spec := range n.Specs {
				vspec, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				typ := ""
				if vspec.Type != nil {
					typ = types.ExprString(vspec.Type)
				}
				for _, name := range vspec.Names {
					if name.Name != "_" {
						vars[name.Name] = typ
					}
				}
			}
		case *ast.AssignStmt:
			if n.Tok != token.DEFINE {
				return true
			}
			for _, lhs := range n.Lhs {
				id, ok := lhs.(*ast.Ident)
				if !ok || id.Name == "_" {
					continue
				}
				if _, seen := vars[id.Name]; !seen {
					vars[id.Name] = ""
				}
			}
		}
		return true
	})

	return vars
}

// ✅ GOOD: Find unused variables
//
// FindUnusedVariables parses code and returns, in order of first declaration,
// the names of variables that are declared (via `var` or `:=`) but never
// referenced anywhere else in the file.
//
// The identifier that introduces a variable is itself an *ast.Ident node, so
// declaring identifiers are remembered and excluded when usages are counted.
// Without that, every declared name would count as used.
//
// This is a name-based heuristic, not a type-checked analysis: scopes and
// shadowing are ignored, so any other identifier with the same name (a struct
// field, a range variable, a plain assignment target) counts as a use. The
// blank identifier is never reported.
func FindUnusedVariables(code string) []string {
	fset := token.NewFileSet()
	// Best effort: this signature has no way to report a parse error, so
	// whatever tree the parser returns is inspected.
	file, _ := parser.ParseFile(fset, "main.go", code, 0)

	declSites := make(map[*ast.Ident]bool) // identifiers that introduce a variable
	declared := make(map[string]bool)
	used := make(map[string]bool)
	var order []string // declared names, first occurrence first

	declare := func(id *ast.Ident) {
		declSites[id] = true
		if id.Name == "_" || declared[id.Name] {
			return
		}
		declared[id.Name] = true
		order = append(order, id.Name)
	}

	// ast.Inspect visits a declaration before the identifiers inside it, so
	// declSites is always populated before those identifiers are reached.
	ast.Inspect(file, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.GenDecl:
			if n.Tok != token.VAR {
				break
			}
			for _, spec := range n.Specs {
				if vspec, ok := spec.(*ast.ValueSpec); ok {
					for _, name := range vspec.Names {
						declare(name)
					}
				}
			}
		case *ast.AssignStmt:
			if n.Tok != token.DEFINE {
				break
			}
			for _, lhs := range n.Lhs {
				if id, ok := lhs.(*ast.Ident); ok {
					declare(id)
				}
			}
		case *ast.Ident:
			if !declSites[n] {
				used[n.Name] = true
			}
		}
		return true
	})

	var unused []string
	for _, name := range order {
		if !used[name] {
			unused = append(unused, name)
		}
	}

	return unused
}

// main runs the three finders over one snippet and prints each result.
func main() {
	const code = `
package main

import "fmt"

func main() {
	x := 10
	y := 20
	fmt.Println(x)
}
`

	calls := FindFunctionCalls(code)
	fmt.Println("Function calls:", calls)

	variables := FindVariables(code)
	fmt.Println("Variables:", variables)

	unused := FindUnusedVariables(code)
	fmt.Println("Unused variables:", unused)
}

Bad: Inefficient AST Analysis

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// ❌ BAD: Parsing the same file multiple times
//
// BadAnalysis is a deliberate anti-pattern: it parses filename three times,
// once per intended analysis, each time with a fresh FileSet and with the
// parse error discarded. Parse once and share the resulting *ast.File instead.
func BadAnalysis(filename string) {
	// Parse for functions
	fset1 := token.NewFileSet()
	file1, _ := parser.ParseFile(fset1, filename, nil, 0)

	// Parse for variables
	fset2 := token.NewFileSet()
	file2, _ := parser.ParseFile(fset2, filename, nil, 0)

	// Parse for calls
	fset3 := token.NewFileSet()
	file3, _ := parser.ParseFile(fset3, filename, nil, 0)

	// Redundant parsing!
	// The three trees stand in for separate analyses; they are discarded
	// explicitly so the example compiles (Go rejects unused local variables).
	_, _, _ = file1, file2, file3
}

// ❌ BAD: Inefficient tree traversal
//
// BadTraversal is a deliberate anti-pattern: it walks the whole tree of file
// three separate times where a single ast.Inspect pass could run all the
// analyses together.
func BadTraversal(file *ast.File) {
	// Traversing multiple times for different analyses
	for i := 0; i < 3; i++ {
		ast.Inspect(file, func(node ast.Node) bool {
			// Do analysis
			return true
		})
	}
}

// ❌ BAD: No error handling
//
// BadParsing is a deliberate anti-pattern: it parses code and returns the
// resulting file while discarding the error from parser.ParseFile, so the
// caller cannot tell a clean parse from a failed one.
func BadParsing(code string) *ast.File {
	fset := token.NewFileSet()
	file, _ := parser.ParseFile(fset, "main.go", code, 0)
	return file // Ignores parse errors
}

Advanced Patterns

Building a Custom Linter

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// Issue represents a linting issue
type Issue struct {
	Line    int    // line of the reported position, taken from token.Position.Line
	Column  int    // column of the reported position, taken from token.Position.Column
	Message string // human-readable description of the finding
}

// Linter performs code analysis
type Linter struct {
	issues []Issue        // findings appended by the Check* methods
	fset   *token.FileSet // resolves node positions to line/column; set by Lint before the checks run
}

// CheckUnusedVariables finds unused variables
//
// It records an "unused variable: <name>" issue for every variable declared in
// file (via `var` or `:=`) whose name is never referenced anywhere else in the
// file. Issues are appended in order of first declaration and positioned at
// the declaring identifier. l.fset must be the FileSet that file was parsed
// with.
//
// The identifier that introduces a variable is itself an *ast.Ident node, so
// declaring identifiers are remembered and excluded when usages are counted.
// Without that, every declared name would count as used.
//
// This is a name-based heuristic: scopes and shadowing are ignored, so any
// other identifier with the same name counts as a use. The blank identifier is
// never reported.
func (l *Linter) CheckUnusedVariables(file *ast.File) {
	declSites := make(map[*ast.Ident]bool)   // identifiers that introduce a variable
	firstDecl := make(map[string]*ast.Ident) // name -> first declaring identifier
	used := make(map[string]bool)
	var order []string // declared names, first occurrence first

	declare := func(id *ast.Ident) {
		declSites[id] = true
		if id.Name == "_" {
			return
		}
		if _, dup := firstDecl[id.Name]; !dup {
			firstDecl[id.Name] = id
			order = append(order, id.Name)
		}
	}

	// ast.Inspect visits a declaration before the identifiers inside it, so
	// declSites is always populated before those identifiers are reached.
	ast.Inspect(file, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.GenDecl:
			if n.Tok != token.VAR {
				break
			}
			for _, spec := range n.Specs {
				if vspec, ok := spec.(*ast.ValueSpec); ok {
					for _, name := range vspec.Names {
						declare(name)
					}
				}
			}
		case *ast.AssignStmt:
			if n.Tok != token.DEFINE {
				break
			}
			for _, lhs := range n.Lhs {
				if id, ok := lhs.(*ast.Ident); ok {
					declare(id)
				}
			}
		case *ast.Ident:
			if !declSites[n] {
				used[n.Name] = true
			}
		}
		return true
	})

	// Report unused
	for _, name := range order {
		if used[name] {
			continue
		}
		pos := l.fset.Position(firstDecl[name].Pos())
		l.issues = append(l.issues, Issue{
			Line:    pos.Line,
			Column:  pos.Column,
			Message: fmt.Sprintf("unused variable: %s", name),
		})
	}
}

// CheckComplexFunctions finds overly complex functions
//
// It records a "function too complex: <n>" issue, positioned at the function
// declaration, for every function in file whose complexity as computed by
// calculateComplexity exceeds 10. l.fset must be the FileSet that file was
// parsed with.
func (l *Linter) CheckComplexFunctions(file *ast.File) {
	ast.Inspect(file, func(node ast.Node) bool {
		fn, ok := node.(*ast.FuncDecl)
		if !ok {
			return true
		}
		// A declaration without a body (for example one implemented outside
		// Go) has a nil Body. There is nothing to measure, and traversing a
		// nil *ast.BlockStmt would dereference it.
		if fn.Body == nil {
			return true
		}

		complexity := l.calculateComplexity(fn.Body)
		if complexity > 10 {
			pos := l.fset.Position(fn.Pos())
			l.issues = append(l.issues, Issue{
				Line:    pos.Line,
				Column:  pos.Column,
				Message: fmt.Sprintf("function too complex: %d", complexity),
			})
		}
		return true
	})
}

// calculateComplexity returns 1 plus the number of if, for, range, and switch
// statements found anywhere inside stmt, including nested ones.
func (l *Linter) calculateComplexity(stmt ast.Stmt) int {
	branches := 0

	countBranch := func(n ast.Node) bool {
		switch n.(type) {
		case *ast.IfStmt, *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt:
			branches++
		}
		return true
	}
	ast.Inspect(stmt, countBranch)

	return branches + 1
}

// Lint parses the file named by filename, runs every check against it, and
// returns all issues accumulated on the linter so far. It returns nil, and
// leaves the linter untouched, when the file cannot be parsed.
func (l *Linter) Lint(filename string) []Issue {
	files := token.NewFileSet()
	parsed, err := parser.ParseFile(files, filename, nil, 0)
	if err != nil {
		return nil
	}

	// The checks resolve positions through l.fset, so it must be set before
	// any of them runs.
	l.fset = files
	checks := []func(*ast.File){
		l.CheckUnusedVariables,
		l.CheckComplexFunctions,
	}
	for _, check := range checks {
		check(parsed)
	}

	return l.issues
}

// main lints the in-memory snippet below and prints one line per issue.
//
// The snippet is parsed directly instead of going through Lint, because Lint
// takes a file name and would read a file called "main.go" from disk rather
// than this source text.
func main() {
	code := `
package main

func main() {
	x := 10
	if true {
		if true {
			if true {
				if true {
					if true {
						if true {
							if true {
								if true {
									if true {
										if true {
											if true {
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
`

	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "main.go", code, parser.AllErrors)
	if err != nil {
		fmt.Println(err)
		return
	}

	// The checks resolve positions through the linter's FileSet, so it must be
	// the one the snippet was parsed with.
	linter := &Linter{fset: fset}
	linter.CheckUnusedVariables(file)
	linter.CheckComplexFunctions(file)

	for _, issue := range linter.issues {
		fmt.Printf("Line %d: %s\n", issue.Line, issue.Message)
	}
}

Code Generation from AST

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
)

// ✅ GOOD: Generate code from AST
//
// GenerateStringer assembles an AST by hand for a file in package main that
// declares a String method on structName, and returns the formatted source:
//
//	func (s <structName>) String() string {
//		return fmt.Sprintf("<structName>{F1: %v, F2: %v}", s.F1, s.F2)
//	}
//
// One "<field>: %v" entry is emitted per element of fields, in order, and the
// generated file imports "fmt". With no fields the method returns the constant
// "<structName>{}" and no import is emitted.
//
// structName and every field must be valid Go identifiers; they are inserted
// into the tree without validation. The empty string is returned if the tree
// cannot be formatted.
func GenerateStringer(structName string, fields []string) string {
	const recvName = "s"

	// Build the format string together with the matching s.<field> arguments.
	var layout bytes.Buffer
	layout.WriteString(structName + "{")
	fieldArgs := make([]ast.Expr, 0, len(fields))
	for i, field := range fields {
		if i > 0 {
			layout.WriteString(", ")
		}
		layout.WriteString(field + ": %v")
		fieldArgs = append(fieldArgs, &ast.SelectorExpr{
			X:   ast.NewIdent(recvName),
			Sel: ast.NewIdent(field),
		})
	}
	layout.WriteString("}")

	// %q renders the text as a double-quoted, escaped Go string literal.
	literal := &ast.BasicLit{
		Kind:  token.STRING,
		Value: fmt.Sprintf("%q", layout.String()),
	}

	// Create a new file
	file := &ast.File{
		Name: ast.NewIdent("main"),
	}

	// With fields the method calls fmt.Sprintf, so the file must import fmt;
	// without fields the literal is returned directly.
	var result ast.Expr = literal
	if len(fields) > 0 {
		file.Decls = append(file.Decls, &ast.GenDecl{
			Tok: token.IMPORT,
			Specs: []ast.Spec{
				&ast.ImportSpec{
					Path: &ast.BasicLit{Kind: token.STRING, Value: `"fmt"`},
				},
			},
		})
		result = &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   ast.NewIdent("fmt"),
				Sel: ast.NewIdent("Sprintf"),
			},
			Args: append([]ast.Expr{literal}, fieldArgs...),
		}
	}

	// Create String method. The receiver goes in Recv; placing it in
	// Type.Params would declare a plain function taking one argument.
	method := &ast.FuncDecl{
		Recv: &ast.FieldList{
			List: []*ast.Field{
				{
					Names: []*ast.Ident{ast.NewIdent(recvName)},
					Type:  ast.NewIdent(structName),
				},
			},
		},
		Name: ast.NewIdent("String"),
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{
				List: []*ast.Field{
					{
						Type: ast.NewIdent("string"),
					},
				},
			},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ReturnStmt{
					Results: []ast.Expr{result},
				},
			},
		},
	}

	file.Decls = append(file.Decls, method)

	// Format and return
	var buf bytes.Buffer
	if err := format.Node(&buf, token.NewFileSet(), file); err != nil {
		return ""
	}
	return buf.String()
}

// main prints the String method generated for a Person with two fields.
func main() {
	fmt.Println(GenerateStringer("Person", []string{"Name", "Age"}))
}

Best Practices

1. Reuse FileSet

// ✅ GOOD: Single FileSet for multiple files
// Both files are registered in the same FileSet, so a position from either
// tree can be resolved through fset.
fset := token.NewFileSet()
file1, _ := parser.ParseFile(fset, "file1.go", nil, 0)
file2, _ := parser.ParseFile(fset, "file2.go", nil, 0)

// ❌ BAD: Multiple FileSets
// Each tree's positions can only be resolved through its own FileSet, which
// then has to be carried around alongside the tree.
fset1 := token.NewFileSet()
file1, _ := parser.ParseFile(fset1, "file1.go", nil, 0)
fset2 := token.NewFileSet()
file2, _ := parser.ParseFile(fset2, "file2.go", nil, 0)

2. Handle Parse Errors

// ✅ GOOD: Check for errors
// parser.AllErrors asks the parser to report every error it finds.
file, err := parser.ParseFile(fset, "main.go", code, parser.AllErrors)
if err != nil {
	fmt.Printf("Parse error: %v\n", err)
	return
}

// ❌ BAD: Ignore errors
// The error is discarded, so a failed parse goes unnoticed.
file, _ := parser.ParseFile(fset, "main.go", code, 0)

3. Use Appropriate Visitor Pattern

// ✅ GOOD: Use ast.Inspect for simple traversal
ast.Inspect(file, func(node ast.Node) bool {
	// Process node
	return true // true: keep descending into this node's children
})

// ✅ GOOD: Use ast.Walk for custom visitor
// MyVisitor can carry state across nodes in its fields.
type MyVisitor struct{}
func (v *MyVisitor) Visit(node ast.Node) ast.Visitor {
	// Process node
	return v // returning the visitor continues the walk into the children
}
ast.Walk(&MyVisitor{}, file)

Resources

Summary

AST manipulation is a powerful technique for building code analysis tools, linters, and generators. By understanding how to parse, traverse, and analyze Go code, you can:

  • Build linters to enforce code quality
  • Generate code automatically
  • Refactor code programmatically
  • Analyze patterns in your codebase
  • Create development tools that understand Go syntax

Start with simple traversals using ast.Inspect(), then progress to custom visitors and code generation as your needs grow.

Comments