@@ -312,7 +312,43 @@ def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str)
312312 else :
313313 accessors = []
314314 return ql .Stub (name = cls .name , base_import = base_import , import_prefix = generated_import_prefix ,
315- synth_accessors = accessors , ql_internal = "ql_internal" in cls .pragmas )
315+ doc = cls .doc , synth_accessors = accessors ,
316+ ql_internal = "ql_internal" in cls .pragmas )
317+
318+
319+ def _patch_class_qldocs (cls : str , qldoc : str , stub_file : pathlib .Path ):
320+ if not qldoc or not stub_file .exists ():
321+ return
322+ qldoc = "\n " .join (l .rstrip () for l in qldoc .splitlines ())
323+ tmp = stub_file .with_suffix (f'{ stub_file .suffix } .bkp' )
324+ header = "// the following QLdoc is generated: if you need to edit it, do it in the schema file\n "
325+ with open (stub_file ) as input :
326+ qldoc_start = None
327+ qldoc_end = None
328+ class_start = None
329+ for lineno , line in enumerate (input , 1 ):
330+ if line == header :
331+ qldoc_start = lineno
332+ if line .startswith ("/**" ) and lineno - 1 != qldoc_start :
333+ qldoc_start = lineno
334+ if line .endswith (" */\n " ):
335+ qldoc_end = lineno + 1
336+ elif line .startswith (f"class { cls } " ):
337+ class_start = lineno
338+ break
339+ assert class_start , stub_file
340+ assert bool (qldoc_start ) == bool (qldoc_end ), stub_file
341+ if not qldoc_start or qldoc_end != class_start :
342+ qldoc_start = class_start
343+ input .seek (0 )
344+ with open (tmp , 'w' ) as output :
345+ for lineno , line in enumerate (input , 1 ):
346+ if lineno == qldoc_start :
347+ print (header , end = '' , file = output )
348+ print (qldoc , file = output )
349+ if lineno < qldoc_start or lineno >= class_start :
350+ print (line , end = '' , file = output )
351+ tmp .rename (stub_file )
316352
317353
318354def generate (opts , renderer ):
@@ -362,9 +398,13 @@ def generate(opts, renderer):
362398 for c in data .classes .values ():
363399 path = _get_path (c )
364400 stub_file = stub_out / path
401+ base_import = get_import (out / path , opts .root_dir )
402+ stub = _get_stub (c , base_import , generated_import_prefix )
365403 if not renderer .is_customized_stub (stub_file ):
366- base_import = get_import (out / path , opts .root_dir )
367- renderer .render (_get_stub (c , base_import , generated_import_prefix ), stub_file )
404+ renderer .render (stub , stub_file )
405+ else :
406+ qldoc = renderer .render_str (stub , template = 'ql_stub_class_qldoc' )
407+ _patch_class_qldocs (c .name , qldoc , stub_file )
368408
369409 # for example path/to/elements -> path/to/elements.qll
370410 renderer .render (ql .ImportList ([i for name , i in imports .items () if not classes [name ].ql_internal ]),
0 commit comments