@@ -64,6 +64,11 @@ public class PythonNextgenClientCodegen extends AbstractPythonCodegen implements
6464
6565 private String testFolder ;
6666
67+ // map of set (model imports)
68+ private HashMap <String , HashSet <String >> circularImports = new HashMap <>();
69+ // map of codegen models
70+ private HashMap <String , CodegenModel > codegenModelMap = new HashMap <>();
71+
6772 public PythonNextgenClientCodegen () {
6873 super ();
6974
@@ -404,14 +409,17 @@ public String getTypeDeclaration(Schema p) {
404409 * @param typingImports typing imports
405410 * @param pydantic pydantic imports
406411 * @param datetimeImports datetime imports
412+ * @param modelImports model imports
413+ * @param classname class name
407414 * @return pydantic type
408415 *
409416 */
410417 private String getPydanticType (CodegenParameter cp ,
411418 Set <String > typingImports ,
412419 Set <String > pydanticImports ,
413420 Set <String > datetimeImports ,
414- Set <String > modelImports ) {
421+ Set <String > modelImports ,
422+ String classname ) {
415423 if (cp == null ) {
416424 // if codegen parameter (e.g. map/dict of undefined type) is null, default to string
417425 LOGGER .warn ("Codegen property is null (e.g. map/dict of undefined type). Default to typing.Any." );
@@ -432,11 +440,12 @@ private String getPydanticType(CodegenParameter cp,
432440 }
433441 pydanticImports .add ("conlist" );
434442 return String .format (Locale .ROOT , "conlist(%s%s)" ,
435- getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports ),
443+ getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports , classname ),
436444 constraints );
437445 } else if (cp .isMap ) {
438446 typingImports .add ("Dict" );
439- return String .format (Locale .ROOT , "Dict[str, %s]" , getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports ));
447+ return String .format (Locale .ROOT , "Dict[str, %s]" ,
448+ getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports , classname ));
440449 } else if (cp .isString || cp .isBinary || cp .isByteArray ) {
441450 if (cp .hasValidation ) {
442451 List <String > fieldCustomization = new ArrayList <>();
@@ -612,7 +621,7 @@ private String getPydanticType(CodegenParameter cp,
612621 CodegenMediaType cmt = contents .get (key );
613622 // TODO process the first one only at the moment
614623 if (cmt != null )
615- return getPydanticType (cmt .getSchema (), typingImports , pydanticImports , datetimeImports , modelImports );
624+ return getPydanticType (cmt .getSchema (), typingImports , pydanticImports , datetimeImports , modelImports , classname );
616625 }
617626 throw new RuntimeException ("Error! Failed to process getPydanticType when getting the content: " + cp );
618627 } else {
@@ -627,14 +636,17 @@ private String getPydanticType(CodegenParameter cp,
627636 * @param typingImports typing imports
628637 * @param pydantic pydantic imports
629638 * @param datetimeImports datetime imports
639+ * @param modelImports model imports
640+ * @param classname class name
630641 * @return pydantic type
631642 *
632643 */
633644 private String getPydanticType (CodegenProperty cp ,
634645 Set <String > typingImports ,
635646 Set <String > pydanticImports ,
636647 Set <String > datetimeImports ,
637- Set <String > modelImports ) {
648+ Set <String > modelImports ,
649+ String classname ) {
638650 if (cp == null ) {
639651 // if codegen property (e.g. map/dict of undefined type) is null, default to string
640652 LOGGER .warn ("Codegen property is null (e.g. map/dict of undefined type). Default to typing.Any." );
@@ -674,11 +686,11 @@ private String getPydanticType(CodegenProperty cp,
674686 pydanticImports .add ("conlist" );
675687 typingImports .add ("List" ); // for return type
676688 return String .format (Locale .ROOT , "conlist(%s%s)" ,
677- getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports ),
689+ getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports , classname ),
678690 constraints );
679691 } else if (cp .isMap ) {
680692 typingImports .add ("Dict" );
681- return String .format (Locale .ROOT , "Dict[str, %s]" , getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports ));
693+ return String .format (Locale .ROOT , "Dict[str, %s]" , getPydanticType (cp .items , typingImports , pydanticImports , datetimeImports , modelImports , classname ));
682694 } else if (cp .isString ) {
683695 if (cp .hasValidation ) {
684696 List <String > fieldCustomization = new ArrayList <>();
@@ -846,10 +858,24 @@ private String getPydanticType(CodegenProperty cp,
846858 typingImports .add ("Any" );
847859 return "Dict[str, Any]" ;
848860 } else if (!cp .isPrimitiveType || cp .isModel ) { // model
849- if (!cp .isCircularReference ) {
850- // skip import if it's a circular reference
861+ // skip import if it's a circular reference
862+ if (classname == null ) {
863+ // for parameter model, import directly
851864 hasModelsToImport = true ;
852865 modelImports .add (cp .dataType );
866+ } else {
867+ if (circularImports .containsKey (cp .dataType )) {
868+ if (circularImports .get (cp .dataType ).contains (classname )) {
869+ // cp.dataType import map of set contains this model (classname), don't import
870+ LOGGER .debug ("Skipped importing {} in {} due to circular import." , cp .dataType , classname );
871+ } else {
872+ // not circular import, so ok to import it
873+ hasModelsToImport = true ;
874+ modelImports .add (cp .dataType );
875+ }
876+ } else {
877+ LOGGER .error ("Failed to look up {} from the imports (map of set) of models." , cp .dataType );
878+ }
853879 }
854880 return cp .dataType ;
855881 } else {
@@ -871,7 +897,7 @@ public OperationsMap postProcessOperationsWithModels(OperationsMap objs, List<Mo
871897
872898 List <CodegenParameter > params = operation .allParams ;
873899 for (CodegenParameter param : params ) {
874- String typing = getPydanticType (param , typingImports , pydanticImports , datetimeImports , modelImports );
900+ String typing = getPydanticType (param , typingImports , pydanticImports , datetimeImports , modelImports , null );
875901 List <String > fields = new ArrayList <>();
876902 String firstField = "" ;
877903
@@ -923,7 +949,7 @@ public OperationsMap postProcessOperationsWithModels(OperationsMap objs, List<Mo
923949 // update typing import for operation return type
924950 if (!StringUtils .isEmpty (operation .returnType )) {
925951 String typing = getPydanticType (operation .returnProperty , typingImports ,
926- new TreeSet <>() /* skip pydantic import for return type */ , datetimeImports , modelImports );
952+ new TreeSet <>() /* skip pydantic import for return type */ , datetimeImports , modelImports , null );
927953 }
928954
929955 }
@@ -983,13 +1009,118 @@ public OperationsMap postProcessOperationsWithModels(OperationsMap objs, List<Mo
9831009 @ Override
9841010 public Map <String , ModelsMap > postProcessAllModels (Map <String , ModelsMap > objs ) {
9851011 final Map <String , ModelsMap > processed = super .postProcessAllModels (objs );
1012+
1013+ for (Map .Entry <String , ModelsMap > entry : objs .entrySet ()) {
1014+ // create hash map of codegen model
1015+ CodegenModel cm = ModelUtils .getModelByName (entry .getKey (), objs );
1016+ codegenModelMap .put (cm .classname , ModelUtils .getModelByName (entry .getKey (), objs ));
1017+ }
1018+
1019+ // create circular import
1020+ for (String m : codegenModelMap .keySet ()) {
1021+ createImportMapOfSet (m , codegenModelMap );
1022+ }
1023+
9861024 for (Map .Entry <String , ModelsMap > entry : processed .entrySet ()) {
9871025 entry .setValue (postProcessModelsMap (entry .getValue ()));
9881026 }
9891027
9901028 return processed ;
9911029 }
9921030
1031+ /**
1032+ * Update circularImports with the model name (key) and its imports gathered recursively
1033+ *
1034+ * @param modelName model name
1035+ * @param codegenModelMap a map of CodegenModel
1036+ */
1037+ void createImportMapOfSet (String modelName , Map <String , CodegenModel > codegenModelMap ) {
1038+ HashSet <String > imports = new HashSet <>();
1039+ circularImports .put (modelName , imports );
1040+
1041+ CodegenModel cm = codegenModelMap .get (modelName );
1042+
1043+ if (cm == null ) {
1044+ LOGGER .warn ("Failed to lookup model in createImportMapOfSet: " + modelName );
1045+ return ;
1046+ }
1047+
1048+ List <CodegenProperty > codegenProperties = null ;
1049+ if (cm .oneOf != null && !cm .oneOf .isEmpty ()) { // oneOf
1050+ codegenProperties = cm .getComposedSchemas ().getOneOf ();
1051+ } else if (cm .anyOf != null && !cm .anyOf .isEmpty ()) { // anyOF
1052+ codegenProperties = cm .getComposedSchemas ().getAnyOf ();
1053+ } else { // typical model
1054+ codegenProperties = cm .vars ;
1055+ }
1056+
1057+ for (CodegenProperty cp : codegenProperties ) {
1058+ String modelNameFromDataType = getModelNameFromDataType (cp );
1059+ if (modelNameFromDataType != null ) { // model
1060+ imports .add (modelNameFromDataType ); // update import
1061+ // go through properties or sub-schemas of the model recursively to identify more (model) import if any
1062+ updateImportsFromCodegenModel (modelNameFromDataType , codegenModelMap .get (modelNameFromDataType ), imports );
1063+ }
1064+ }
1065+ }
1066+
1067+ /**
1068+ * Update set of imports from codegen model recursivly
1069+ *
1070+ * @param modelName model name
1071+ * @param cm codegen model
1072+ * @param imports set of imports
1073+ */
1074+ public void updateImportsFromCodegenModel (String modelName , CodegenModel cm , Set <String > imports ) {
1075+ if (cm == null ) {
1076+ LOGGER .warn ("Failed to lookup model in createImportMapOfSet " + modelName );
1077+ return ;
1078+ }
1079+
1080+ List <CodegenProperty > codegenProperties = null ;
1081+ if (cm .oneOf != null && !cm .oneOf .isEmpty ()) { // oneOfValidationError
1082+ codegenProperties = cm .getComposedSchemas ().getOneOf ();
1083+ } else if (cm .anyOf != null && !cm .anyOf .isEmpty ()) { // anyOF
1084+ codegenProperties = cm .getComposedSchemas ().getAnyOf ();
1085+ } else { // typical model
1086+ codegenProperties = cm .vars ;
1087+ }
1088+
1089+ for (CodegenProperty cp : codegenProperties ) {
1090+ String modelNameFromDataType = getModelNameFromDataType (cp );
1091+ if (modelNameFromDataType != null ) { // model
1092+ if (modelName .equals (modelNameFromDataType )) { // self referencing
1093+ continue ;
1094+ } else if (imports .contains (modelNameFromDataType )) { // circular import
1095+ continue ;
1096+ } else {
1097+ imports .add (modelNameFromDataType ); // update import
1098+ // go through properties of the model recursively to identify more (model) import if any
1099+ updateImportsFromCodegenModel (modelNameFromDataType , codegenModelMap .get (modelNameFromDataType ), imports );
1100+ }
1101+ }
1102+ }
1103+ }
1104+
1105+ /**
1106+ * Returns the model name (if any) from data type of codegen property.
1107+ * Returns null if it's not a model.
1108+ *
1109+ * @param cp Codegen property
1110+ * @return model name
1111+ */
1112+ private String getModelNameFromDataType (CodegenProperty cp ) {
1113+ if (cp .isArray ) {
1114+ return getModelNameFromDataType (cp .items );
1115+ } else if (cp .isMap ) {
1116+ return getModelNameFromDataType (cp .items );
1117+ } else if (!cp .isPrimitiveType || cp .isModel ) {
1118+ return cp .dataType ;
1119+ } else {
1120+ return null ;
1121+ }
1122+ }
1123+
9931124 private ModelsMap postProcessModelsMap (ModelsMap objs ) {
9941125 // process enum in models
9951126 objs = postProcessModelsEnum (objs );
@@ -1044,7 +1175,7 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
10441175
10451176 //loop through properties/schemas to set up typing, pydantic
10461177 for (CodegenProperty cp : codegenProperties ) {
1047- String typing = getPydanticType (cp , typingImports , pydanticImports , datetimeImports , modelImports );
1178+ String typing = getPydanticType (cp , typingImports , pydanticImports , datetimeImports , modelImports , model . classname );
10481179 List <String > fields = new ArrayList <>();
10491180 String firstField = "" ;
10501181
0 commit comments