recursion.go 5.39 KB
Newer Older
Etienne Renault's avatar
Etienne Renault committed
1
2
3
4
5
// Copyright (C) 2020 Laboratoire de Recherche et Developpement
// de l'EPITA (LRDE).
//
// This file is part of Go2Pins, a tool for Golang model-checking
//
Etienne Renault's avatar
Etienne Renault committed
6
// Go2Pins is free software; you can redistribute it and/or modify it
Etienne Renault's avatar
Etienne Renault committed
7
8
9
10
// under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 3 of the License, or
// (at your option) any later version.
//
Etienne Renault's avatar
Etienne Renault committed
11
// Go2Pins is distributed in the hope that it will be useful, but WITHOUT
Etienne Renault's avatar
Etienne Renault committed
12
13
14
15
16
17
18
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
// or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
// License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

19
20
21
package tools

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"log"
	"sort"
	"strconv"
)

32
33
34
35
36
// FuncInfo summarizes the call behaviour of one top-level function
// declaration of the analysed Go source.
//
// Field order is significant for positional composite literals of this
// type; do not reorder the fields.
type FuncInfo struct {
	// Name is the declared name of the function (for a method, the method
	// name without its receiver).
	Name string

	// IsRecursive reports whether the declaration contains a call whose
	// callee is a bare identifier equal to Name.
	IsRecursive bool

	// TransitivelyRecursive reports whether Name can be reached again by
	// following Calls through the call graph. It is only meaningful once
	// BuildCallgraph has filled it in.
	TransitivelyRecursive bool

	// Calls lists, without duplicates and in order of first appearance,
	// the bare identifiers used as call targets in the declaration
	// (int(...) conversions excluded). Selector calls such as pkg.F(...) or
	// x.M(...) are not recorded; builtins and other conversions may be.
	Calls []string
}

39
40
41
42
43
44
45
46
47
// collectFuncInfo gathers the call information of a single function
// declaration: whether it calls itself directly and which bare identifiers
// it calls.
//
// Only calls whose callee is a bare identifier, as in f(...), are
// considered; selector calls such as pkg.F(...) or x.M(...) are ignored.
// The whole declaration is walked, so calls inside function literals in the
// body are included. The returned Calls slice is non-nil, free of
// duplicates and ordered by first appearance.
func collectFuncInfo(node *ast.FuncDecl) *FuncInfo {
	name := node.Name.Name
	isRecursive := false
	calls := []string{}
	// seen gives O(1) de-duplication while calls keeps first-appearance order.
	seen := make(map[string]bool)

	ast.Inspect(node, func(n ast.Node) bool {
		call, ok := n.(*ast.CallExpr)
		if !ok {
			return true
		}
		callee, ok := call.Fun.(*ast.Ident)
		if !ok {
			return true
		}
		// FIXME int(...) is a conversion rather than a call. It is the only
		// conversion currently recognised (and therefore skipped).
		if callee.Name == "int" {
			return true
		}
		if callee.Name == name {
			isRecursive = true
		}
		if !seen[callee.Name] {
			seen[callee.Name] = true
			calls = append(calls, callee.Name)
		}
		return true
	})

	// TransitivelyRecursive stays false: computing it needs the whole call
	// graph, not just this declaration.
	return &FuncInfo{Name: name, IsRecursive: isRecursive, Calls: calls}
}
79

80
81
func BuildCallgraph(inputfile string) map[string]*FuncInfo {
	fset := token.NewFileSet()
82

83
84
85
86
	node, err := parser.ParseFile(fset, "src.go", inputfile, 0)
	if err != nil {
		log.Fatal(err)
	}
87

88
89
90
91
92
93
94
95
	callgraph := make(map[string]*FuncInfo)
	// Walk the file to compute relevant information about recursion
	ast.Inspect(node, func(n ast.Node) bool {
		// check if we're on a `go` node
		decl, ok := n.(*ast.FuncDecl)
		if !ok {
			return true
		}
96

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
		// FIXME HANDLE GOROUTINE + LAMBDA Function

		// get position in file
		callgraph[decl.Name.Name] = collectFuncInfo(decl)
		return true
	})

	// Fill the transitively recursive field
	for key, element := range callgraph {
		visited := make(map[string]bool)
		trans_rec := false
		var indirect func(string, string)
		indirect = func(from string, lookup string) {
			visited[from] = true
			for _, call := range callgraph[from].Calls {
				if lookup == call {
					trans_rec = true
114
				}
115
116
117
				_, ok := visited[call]
				if !ok {
					indirect(call, lookup)
hmoreau's avatar
hmoreau committed
118
				}
119
120
			}
		}
121
122
123
124
125
126
127
128
129
130
131
132
		indirect(element.Name, element.Name)
		callgraph[key].TransitivelyRecursive = trans_rec
	}
	return callgraph
}

// Callgraph prints, in Graphviz DOT syntax, the call graph of the Go source
// text inputfile as computed by BuildCallgraph: one "caller -> callee" edge
// per recorded call. Callers are emitted in sorted order, and each caller's
// callees in their recorded order, so the output is identical across runs.
func Callgraph(inputfile string) {
	callgraph := BuildCallgraph(inputfile)

	// Map iteration order is random; sort the callers for stable output.
	callers := make([]string, 0, len(callgraph))
	for name := range callgraph {
		callers = append(callers, name)
	}
	sort.Strings(callers)

	var out bytes.Buffer
	out.WriteString("digraph callgraph {\n")
	for _, caller := range callers {
		for _, callee := range callgraph[caller].Calls {
			out.WriteString("  " + caller + " -> " + callee + ";\n")
		}
	}
	out.WriteString("}\n")
	fmt.Println(out.String())
}

139
140
141
142
143
144
145
146
// parseFunc parses Go source text and returns the first top-level function
// declaration named functionname, together with the file set holding its
// positions. Despite its name, filename carries the source text itself: it
// is passed as the src argument of parser.ParseFile.
//
// Every call parses afresh into a new file set, so callers may mutate the
// returned AST without affecting other calls. It panics if the source does
// not parse or declares no such function.
func parseFunc(filename, functionname string) (fun *ast.FuncDecl, fset *token.FileSet) {
	fset = token.NewFileSet()
	file, err := parser.ParseFile(fset, "src.go", filename, 0)
	if err != nil {
		// Surface the real cause rather than a misleading "not found".
		panic("parsing source: " + err.Error())
	}
	for _, d := range file.Decls {
		if f, ok := d.(*ast.FuncDecl); ok && f.Name.Name == functionname {
			return f, fset
		}
	}
	panic("function not found " + functionname)
}

152
153
func RewriteRecursion(inputfile string, unroll_level int) (string, bool) {
	fset := token.NewFileSet()
154

155
156
157
	node, err := parser.ParseFile(fset, "src.go", inputfile, 0)
	if err != nil {
		log.Fatal(err)
158
159
	}

160
161
162
	result := ""
	rec := false
	for key, element := range BuildCallgraph(inputfile) {
163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
		if !(element.IsRecursive || element.TransitivelyRecursive) {
			continue
		}

		rec = true

		// FIXME modify internal calls

		for i := 1; i <= unroll_level; i++ {
			funcAST, fset := parseFunc(inputfile, key)
			funcAST.Name.Name = funcAST.Name.Name + "_" +
				strconv.Itoa(i)

			if i != unroll_level {
				ast.Inspect(funcAST.Body, func(n ast.Node) bool {
					callexp, ok := n.(*ast.CallExpr)
					if !ok {
						return true
					}

					switch f := callexp.Fun.(type) {
					case *ast.Ident:
						f.Name = f.Name + "_" + strconv.Itoa(i+1)
					}

					return true
				})
191
			} else {
192
193
194
195
196
197
198
199
200
201
202
203
204
205
				funcAST.Body.List = []ast.Stmt{}

				funcAST.Body.List = append(funcAST.Body.List,
					&ast.ExprStmt{
						X: &ast.CallExpr{
							Fun: ast.NewIdent("panic"),
							Args: []ast.Expr{
								&ast.BasicLit{
									Kind:  token.STRING,
									Value: "\"Max Depth Reached\"",
								},
							},
						},
					})
206
			}
207
208
209
210

			var buf bytes.Buffer
			printer.Fprint(&buf, fset, funcAST)
			result = result + "\n" + buf.String() + "\n"
211
		}
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238

		ast.Inspect(node, func(n ast.Node) bool {
			// check if we're on a `go` node
			decl, ok := n.(*ast.FuncDecl)
			if !ok {
				return true
			}
			if decl.Name.Name != key {
				return true
			}
			ast.Inspect(decl.Body, func(n ast.Node) bool {
				callexp, ok := n.(*ast.CallExpr)
				if !ok {
					return true
				}

				switch f := callexp.Fun.(type) {
				case *ast.Ident:
					f.Name = f.Name + "_1"
				}

				return true
			})

			return true
		})

239
	}
240
241
242
243
244
245

	var buf bytes.Buffer
	printer.Fprint(&buf, fset, node)
	result = buf.String() + result

	return result, rec
246
}