Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func main() {
log.SetPrefix("")

flag.Usage = usage
var imports paths
flag.Var(&imports, "I", "additional `path` to import")
flag.Parse()

// verify that *outfile contains exactly one occurrence of the %v verb
Expand All @@ -95,5 +97,18 @@ func main() {
}

t := newTemplate(cwd, args[0], args[1])
t.addImports(imports)
t.instantiate()
}

// Parsing of -I flags
type paths []string

func (ps *paths) Set(arg string) error {
*ps = append(*ps, arg)
return nil
}

func (ps *paths) String() string {
return fmt.Sprint([]string(*ps))
}
34 changes: 34 additions & 0 deletions template.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type template struct {
Args []string
NewPackage string
Dir string
importPaths []string
templateName string
templateArgs []string
mappings map[string]string
Expand Down Expand Up @@ -57,6 +58,11 @@ func newTemplate(dir, pkg, templateArgsString string) *template {
}
}

// Add import paths
func (t *template) addImports(paths []string) {
t.importPaths = paths
}

// Add a mapping for identifier
func (t *template) addMapping(name string) {
replacementName := ""
Expand Down Expand Up @@ -192,6 +198,34 @@ func (t *template) parse(inputFile string) {
// Find names which need to be adjusted
namesToMangle := []string{}
newDecls := []ast.Decl{}

// Insert additional imports
if len(t.importPaths) != 0 {
// Use a fake position close to the package declaration to make sure
// the import clause is near the beginning of the file.
pos := f.Package + 10
specs := make([]ast.Spec, len(t.importPaths))
for i, path := range t.importPaths {
debugf("Adding import path %q", path)
spec := &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf("%q", path),
},
EndPos: pos,
}
specs[i] = spec
f.Imports = append(f.Imports, spec)
}
decl := &ast.GenDecl{
Tok: token.IMPORT,
Lparen: pos,
Specs: specs,
Rparen: pos,
}
newDecls = append(newDecls, decl)
}

for _, Decl := range f.Decls {
remove := false
switch d := Decl.(type) {
Expand Down
82 changes: 82 additions & 0 deletions template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type TestTemplate struct {
args string
pkg string
in string
imports []string
outName string
out string
}
Expand Down Expand Up @@ -346,6 +347,86 @@ type (
type (
importantType3Tmpl struct{}
)
`,
},
{
title: "Test with one import",
args: "MySet(time.Duration)",
pkg: "main",
in: basicTest,
imports: []string{"time"},
outName: "gotemplate_MySet.go",
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

import (
"time"
)

import "fmt"

// template type Set(A)

type MySet struct{ a time.Duration }

func NewMySet(a time.Duration) time.Duration { return time.Duration(0) }
func NewSizedMySet(a time.Duration) time.Duration { return time.Duration(1) }
func UtilityFunc1MySet() {}
func utilityFuncMySet() {}
func (a time.Duration) f0() {}
func (a *time.Duration) F1() {}

var AVar1MySet int
var aVar2MySet int
var (
AVar3MySet int
aVar4MySet int
)
`,
},
{
title: "Test with two imports",
args: "myPair(time.Duration, aes.KeySizeError)",
pkg: "main",
in: `package tt

// template type Pair(A, B)
type A int
type B int

type Pair struct {
a A
b B
}

func NewPair(a A, b B) Pair { return Pair(a, b) }

func (p Pair) left() A { return p.a }
func (p Pair) right() B { return p.b }
`,
imports: []string{"time", "crypto/aes"},
outName: "gotemplate_myPair.go",
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

import (
"crypto/aes"
"time"
)

// template type Pair(A, B)

type myPair struct {
a time.Duration
b aes.KeySizeError
}

func newMyPair(a time.Duration, b aes.KeySizeError) myPair { return myPair(a, b) }

func (p myPair) left() time.Duration { return p.a }
func (p myPair) right() aes.KeySizeError { return p.b }
`,
},
}
Expand Down Expand Up @@ -418,6 +499,7 @@ func testTemplate(t *testing.T, test *TestTemplate) {

// Instantiate template
template := newTemplate(output, "input", test.args)
template.addImports(test.imports)
template.instantiate()

// Check output
Expand Down