@@ -146,7 +146,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str =
146146 return ql .Property (** args )
147147
148148
149- def get_ql_class (cls : schema .Class ):
149+ def get_ql_class (cls : schema .Class ) -> ql . Class :
150150 pragmas = {k : True for k in cls .pragmas if k .startswith ("ql" )}
151151 prev_child = ""
152152 properties = []
@@ -228,6 +228,10 @@ def format(codeql, files):
228228 log .debug (line .strip ())
229229
230230
231+ def _get_path (cls : schema .Class ) -> pathlib .Path :
232+ return pathlib .Path (cls .group or "" , cls .name ).with_suffix (".qll" )
233+
234+
231235def _get_all_properties (cls : schema .Class , lookup : typing .Dict [str , schema .Class ],
232236 already_seen : typing .Optional [typing .Set [int ]] = None ) -> \
233237 typing .Iterable [typing .Tuple [schema .Class , schema .Property ]]:
@@ -283,6 +287,29 @@ def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class
283287 cls , lookup )
284288
285289
290+ def _get_stub (cls : schema .Class , base_import : str ) -> ql .Stub :
291+ if isinstance (cls .ipa , schema .IpaInfo ):
292+ if cls .ipa .from_class is not None :
293+ accessors = [
294+ ql .IpaUnderlyingAccessor (
295+ argument = "Entity" ,
296+ type = _to_db_type (cls .ipa .from_class ),
297+ constructorparams = ["result" ]
298+ )
299+ ]
300+ elif cls .ipa .on_arguments is not None :
301+ accessors = [
302+ ql .IpaUnderlyingAccessor (
303+ argument = inflection .camelize (arg ),
304+ type = _to_db_type (type ),
305+ constructorparams = ["result" if a == arg else "_" for a in cls .ipa .on_arguments ]
306+ ) for arg , type in cls .ipa .on_arguments .items ()
307+ ]
308+ else :
309+ accessors = []
310+ return ql .Stub (name = cls .name , base_import = base_import , ipa_accessors = accessors )
311+
312+
286313def generate (opts , renderer ):
287314 input = opts .schema
288315 out = opts .ql_output
@@ -323,10 +350,13 @@ def generate(opts, renderer):
323350 qll = out / c .path .with_suffix (".qll" )
324351 c .imports = [imports [t ] for t in get_classes_used_by (c )]
325352 renderer .render (c , qll )
326- stub_file = stub_out / c .path .with_suffix (".qll" )
353+
354+ for c in data .classes .values ():
355+ path = _get_path (c )
356+ stub_file = stub_out / path
327357 if not renderer .is_customized_stub (stub_file ):
328- stub = ql . Stub ( name = c . name , base_import = get_import (qll , opts .swift_dir ) )
329- renderer .render (stub , stub_file )
358+ base_import = get_import (out / path , opts .swift_dir )
359+ renderer .render (_get_stub ( c , base_import ) , stub_file )
330360
331361 # for example path/to/elements -> path/to/elements.qll
332362 renderer .render (ql .ImportList ([i for name , i in imports .items () if not classes [name ].ql_internal ]),
0 commit comments