77 "errors"
88 "fmt"
99 "go/format"
10+ "path/filepath"
1011 "strings"
1112 "text/template"
1213
@@ -126,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
126127 }
127128
128129 if options .OmitUnusedStructs {
129- enums , structs = filterUnusedStructs (enums , structs , queries )
130+ enums , structs = filterUnusedStructs (options , enums , structs , queries )
130131 }
131132
132133 if err := validate (options , enums , structs , queries ); err != nil {
@@ -216,6 +217,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
216217 "imports" : i .Imports ,
217218 "hasImports" : i .HasImports ,
218219 "hasPrefix" : strings .HasPrefix ,
220+ "trimPrefix" : strings .TrimPrefix ,
219221
220222 // These methods are Go specific, they do not belong in the codegen package
221223 // (as that is language independent)
@@ -237,14 +239,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
237239
238240 output := map [string ]string {}
239241
240- execute := func (name , templateName string ) error {
242+ execute := func (name , packageName , templateName string ) error {
241243 imports := i .Imports (name )
242244 replacedQueries := replaceConflictedArg (imports , queries )
243245
244246 var b bytes.Buffer
245247 w := bufio .NewWriter (& b )
246248 tctx .SourceName = name
247249 tctx .GoQueries = replacedQueries
250+ tctx .Package = packageName
248251 err := tmpl .ExecuteTemplate (w , templateName , & tctx )
249252 w .Flush ()
250253 if err != nil {
@@ -256,8 +259,13 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
256259 return fmt .Errorf ("source error: %w" , err )
257260 }
258261
259- if templateName == "queryFile" && options .OutputFilesSuffix != "" {
260- name += options .OutputFilesSuffix
262+ if templateName == "queryFile" {
263+ if options .OutputQueryFilesDirectory != "" {
264+ name = filepath .Join (options .OutputQueryFilesDirectory , name )
265+ }
266+ if options .OutputFilesSuffix != "" {
267+ name += options .OutputFilesSuffix
268+ }
261269 }
262270
263271 if ! strings .HasSuffix (name , ".go" ) {
@@ -289,24 +297,29 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
289297 batchFileName = options .OutputBatchFileName
290298 }
291299
292- if err := execute (dbFileName , "dbFile" ); err != nil {
300+ modelsPackageName := options .Package
301+ if options .OutputModelsPackage != "" {
302+ modelsPackageName = options .OutputModelsPackage
303+ }
304+
305+ if err := execute (dbFileName , options .Package , "dbFile" ); err != nil {
293306 return nil , err
294307 }
295- if err := execute (modelsFileName , "modelsFile" ); err != nil {
308+ if err := execute (modelsFileName , modelsPackageName , "modelsFile" ); err != nil {
296309 return nil , err
297310 }
298311 if options .EmitInterface {
299- if err := execute (querierFileName , "interfaceFile" ); err != nil {
312+ if err := execute (querierFileName , options . Package , "interfaceFile" ); err != nil {
300313 return nil , err
301314 }
302315 }
303316 if tctx .UsesCopyFrom {
304- if err := execute (copyfromFileName , "copyfromFile" ); err != nil {
317+ if err := execute (copyfromFileName , options . Package , "copyfromFile" ); err != nil {
305318 return nil , err
306319 }
307320 }
308321 if tctx .UsesBatch {
309- if err := execute (batchFileName , "batchFile" ); err != nil {
322+ if err := execute (batchFileName , options . Package , "batchFile" ); err != nil {
310323 return nil , err
311324 }
312325 }
@@ -317,7 +330,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
317330 }
318331
319332 for source := range files {
320- if err := execute (source , "queryFile" ); err != nil {
333+ if err := execute (source , options . Package , "queryFile" ); err != nil {
321334 return nil , err
322335 }
323336 }
@@ -367,7 +380,7 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
367380 return nil
368381}
369382
370- func filterUnusedStructs (enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
383+ func filterUnusedStructs (options * opts. Options , enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
371384 keepTypes := make (map [string ]struct {})
372385
373386 for _ , query := range queries {
@@ -394,16 +407,23 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
394407
395408 keepEnums := make ([]Enum , 0 , len (enums ))
396409 for _ , enum := range enums {
397- _ , keep := keepTypes [enum .Name ]
398- _ , keepNull := keepTypes ["Null" + enum .Name ]
410+ var enumType string
411+ if options .ModelsPackageImportPath != "" {
412+ enumType = options .OutputModelsPackage + "." + enum .Name
413+ } else {
414+ enumType = enum .Name
415+ }
416+
417+ _ , keep := keepTypes [enumType ]
418+ _ , keepNull := keepTypes ["Null" + enumType ]
399419 if keep || keepNull {
400420 keepEnums = append (keepEnums , enum )
401421 }
402422 }
403423
404424 keepStructs := make ([]Struct , 0 , len (structs ))
405425 for _ , st := range structs {
406- if _ , ok := keepTypes [st .Name ]; ok {
426+ if _ , ok := keepTypes [st .Type () ]; ok {
407427 keepStructs = append (keepStructs , st )
408428 }
409429 }
0 commit comments