@@ -17,7 +17,7 @@ use std::fmt;
1717
1818use hir_def:: {
1919 AdtId , ConstId , EnumId , EnumVariantId , FunctionId , HasModule , ItemContainerId , Lookup ,
20- ModuleDefId , ModuleId , StaticId , StructId , TraitId , TypeAliasId , attrs:: AttrFlags ,
20+ ModuleDefId , ModuleId , StaticId , StructId , TraitId , TypeAliasId , UnionId , attrs:: AttrFlags ,
2121 db:: DefDatabase , hir:: Pat , item_tree:: FieldsShape , signatures:: StaticFlags , src:: HasSource ,
2222} ;
2323use hir_expand:: {
@@ -77,6 +77,7 @@ pub enum IdentType {
7777 Structure ,
7878 Trait ,
7979 TypeAlias ,
80+ Union ,
8081 Variable ,
8182 Variant ,
8283}
@@ -94,6 +95,7 @@ impl fmt::Display for IdentType {
9495 IdentType :: Structure => "Structure" ,
9596 IdentType :: Trait => "Trait" ,
9697 IdentType :: TypeAlias => "Type alias" ,
98+ IdentType :: Union => "Union" ,
9799 IdentType :: Variable => "Variable" ,
98100 IdentType :: Variant => "Variant" ,
99101 } ;
@@ -146,9 +148,7 @@ impl<'a> DeclValidator<'a> {
146148 match adt {
147149 AdtId :: StructId ( struct_id) => self . validate_struct ( struct_id) ,
148150 AdtId :: EnumId ( enum_id) => self . validate_enum ( enum_id) ,
149- AdtId :: UnionId ( _) => {
150- // FIXME: Unions aren't yet supported by this validator.
151- }
151+ AdtId :: UnionId ( union_id) => self . validate_union ( union_id) ,
152152 }
153153 }
154154
@@ -383,6 +383,94 @@ impl<'a> DeclValidator<'a> {
383383 }
384384 }
385385
386+ fn validate_union ( & mut self , union_id : UnionId ) {
387+ // Check the union name.
388+ let data = self . db . union_signature ( union_id) ;
389+
390+ // rustc implementation excuses repr(C) since C unions predominantly don't
391+ // use camel case.
392+ let has_repr_c = AttrFlags :: repr ( self . db , union_id. into ( ) ) . is_some_and ( |repr| repr. c ( ) ) ;
393+ if !has_repr_c {
394+ self . create_incorrect_case_diagnostic_for_item_name (
395+ union_id,
396+ & data. name ,
397+ CaseType :: UpperCamelCase ,
398+ IdentType :: Union ,
399+ ) ;
400+ }
401+
402+ // Check the field names.
403+ self . validate_union_fields ( union_id) ;
404+ }
405+
406+ /// Check incorrect names for union fields.
407+ fn validate_union_fields ( & mut self , union_id : UnionId ) {
408+ let data = union_id. fields ( self . db ) ;
409+ let edition = self . edition ( union_id) ;
410+ let mut union_fields_replacements = data
411+ . fields ( )
412+ . iter ( )
413+ . filter_map ( |( _, field) | {
414+ to_lower_snake_case ( & field. name . display_no_db ( edition) . to_smolstr ( ) ) . map (
415+ |new_name| Replacement {
416+ current_name : field. name . clone ( ) ,
417+ suggested_text : new_name,
418+ expected_case : CaseType :: LowerSnakeCase ,
419+ } ,
420+ )
421+ } )
422+ . peekable ( ) ;
423+
424+ // XXX: Only look at sources if we do have incorrect names.
425+ if union_fields_replacements. peek ( ) . is_none ( ) {
426+ return ;
427+ }
428+
429+ let union_loc = union_id. lookup ( self . db ) ;
430+ let union_src = union_loc. source ( self . db ) ;
431+
432+ let Some ( union_fields_list) = union_src. value . record_field_list ( ) else {
433+ always ! (
434+ union_fields_replacements. peek( ) . is_none( ) ,
435+ "Replacements ({:?}) were generated for a union fields \
436+ which had no fields list: {:?}",
437+ union_fields_replacements. collect:: <Vec <_>>( ) ,
438+ union_src
439+ ) ;
440+ return ;
441+ } ;
442+ let mut union_fields_iter = union_fields_list. fields ( ) ;
443+ for field_replacement in union_fields_replacements {
444+ // We assume that parameters in replacement are in the same order as in the
445+ // actual params list, but just some of them (ones that named correctly) are skipped.
446+ let field = loop {
447+ if let Some ( field) = union_fields_iter. next ( ) {
448+ let Some ( field_name) = field. name ( ) else {
449+ continue ;
450+ } ;
451+ if field_name. as_name ( ) == field_replacement. current_name {
452+ break field;
453+ }
454+ } else {
455+ never ! (
456+ "Replacement ({:?}) was generated for a union field \
457+ which was not found: {:?}",
458+ field_replacement,
459+ union_src
460+ ) ;
461+ return ;
462+ }
463+ } ;
464+
465+ self . create_incorrect_case_diagnostic_for_ast_node (
466+ field_replacement,
467+ union_src. file_id ,
468+ & field,
469+ IdentType :: Field ,
470+ ) ;
471+ }
472+ }
473+
386474 fn validate_enum ( & mut self , enum_id : EnumId ) {
387475 // Check the enum name.
388476 let data = self . db . enum_signature ( enum_id) ;
0 commit comments