From a7a70724e0445ec03183b5e73b34841e9dd61991 Mon Sep 17 00:00:00 2001 From: Laurent Voisin Date: Sun, 19 Nov 2017 18:32:14 +0100 Subject: [PATCH] Add support for import paths (fixes #10) Add option "-I" for specifying import paths to be added in the generated code. Several such options can be specified. Add corresponding tests. --- main.go | 15 +++++++++ template.go | 34 ++++++++++++++++++++ template_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) diff --git a/main.go b/main.go index 4a38fa1..a55a2a3 100644 --- a/main.go +++ b/main.go @@ -74,6 +74,8 @@ func main() { log.SetPrefix("") flag.Usage = usage + var imports paths + flag.Var(&imports, "I", "additional `path` to import") flag.Parse() // verify that *outfile contains exactly one occurrence of the %v verb @@ -95,5 +97,18 @@ func main() { } t := newTemplate(cwd, args[0], args[1]) + t.addImports(imports) t.instantiate() } + +// Parsing of -I flags +type paths []string + +func (ps *paths) Set(arg string) error { + *ps = append(*ps, arg) + return nil +} + +func (ps *paths) String() string { + return fmt.Sprint([]string(*ps)) +} diff --git a/template.go b/template.go index 1861c48..d456299 100644 --- a/template.go +++ b/template.go @@ -27,6 +27,7 @@ type template struct { Args []string NewPackage string Dir string + importPaths []string templateName string templateArgs []string mappings map[string]string @@ -57,6 +58,11 @@ func newTemplate(dir, pkg, templateArgsString string) *template { } } +// Add import paths +func (t *template) addImports(paths []string) { + t.importPaths = paths +} + // Add a mapping for identifier func (t *template) addMapping(name string) { replacementName := "" @@ -192,6 +198,34 @@ func (t *template) parse(inputFile string) { // Find names which need to be adjusted namesToMangle := []string{} newDecls := []ast.Decl{} + + // Insert additional imports + if len(t.importPaths) != 0 { + // Use a fake position close to the package declaration to make sure + // the import clause is near the beginning of the file. + pos := f.Package + 10 + specs := make([]ast.Spec, len(t.importPaths)) + for i, path := range t.importPaths { + debugf("Adding import path %q", path) + spec := &ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: fmt.Sprintf("%q", path), + }, + EndPos: pos, + } + specs[i] = spec + f.Imports = append(f.Imports, spec) + } + decl := &ast.GenDecl{ + Tok: token.IMPORT, + Lparen: pos, + Specs: specs, + Rparen: pos, + } + newDecls = append(newDecls, decl) + } + for _, Decl := range f.Decls { remove := false switch d := Decl.(type) { diff --git a/template_test.go b/template_test.go index 064afad..e27c9ed 100644 --- a/template_test.go +++ b/template_test.go @@ -18,6 +18,7 @@ type TestTemplate struct { args string pkg string in string + imports []string outName string out string } @@ -346,6 +347,86 @@ type ( type ( importantType3Tmpl struct{} ) +`, + }, + { + title: "Test with one import", + args: "MySet(time.Duration)", + pkg: "main", + in: basicTest, + imports: []string{"time"}, + outName: "gotemplate_MySet.go", + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main + +import ( + "time" +) + +import "fmt" + +// template type Set(A) + +type MySet struct{ a time.Duration } + +func NewMySet(a time.Duration) time.Duration { return time.Duration(0) } +func NewSizedMySet(a time.Duration) time.Duration { return time.Duration(1) } +func UtilityFunc1MySet() {} +func utilityFuncMySet() {} +func (a time.Duration) f0() {} +func (a *time.Duration) F1() {} + +var AVar1MySet int +var aVar2MySet int +var ( + AVar3MySet int + aVar4MySet int +) +`, + }, + { + title: "Test with two imports", + args: "myPair(time.Duration, aes.KeySizeError)", + pkg: "main", + in: `package tt + +// template type Pair(A, B) +type A int +type B int + +type Pair struct { + a A + b B +} + +func NewPair(a A, b B) Pair { return Pair(a, b) } + +func (p Pair) left() A { return p.a } +func (p Pair) right() B { return p.b } +`, + imports: []string{"time", "crypto/aes"}, + outName: "gotemplate_myPair.go", + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main + +import ( + "crypto/aes" + "time" +) + +// template type Pair(A, B) + +type myPair struct { + a time.Duration + b aes.KeySizeError +} + +func newMyPair(a time.Duration, b aes.KeySizeError) myPair { return myPair(a, b) } + +func (p myPair) left() time.Duration { return p.a } +func (p myPair) right() aes.KeySizeError { return p.b } `, }, } @@ -418,6 +499,7 @@ func testTemplate(t *testing.T, test *TestTemplate) { // Instantiate template template := newTemplate(output, "input", test.args) + template.addImports(test.imports) template.instantiate() // Check output