diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 02aabe390..f5f513a45 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -1,15 +1,16 @@ use emmylua_parser::{ - BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, - LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableField, LuaVarExpr, PathTrait, + BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaIndexKey, + LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableExpr, LuaTableField, LuaVarExpr, + PathTrait, }; use crate::{ - InFiled, InferFailReason, LuaSemanticDeclId, LuaTypeCache, LuaTypeOwner, + InFiled, InferFailReason, LuaMemberKey, LuaSemanticDeclId, LuaTypeCache, LuaTypeOwner, compilation::analyzer::{ common::{add_member, bind_type}, unresolve::{UnResolveDecl, UnResolveMember}, }, - db_index::{LuaDeclId, LuaMemberId, LuaMemberOwner, LuaType}, + db_index::{LuaDeclId, LuaMember, LuaMemberFeature, LuaMemberId, LuaMemberOwner, LuaType}, }; use super::LuaAnalyzer; @@ -449,7 +450,53 @@ pub fn analyze_local_func_stat( Some(()) } +fn register_expr_key_member(analyzer: &mut LuaAnalyzer, field: &LuaTableField) { + // Register expression-key members early so table-decl inference (and pairs) + // can see them even when the table itself has no explicit generic type. + let Some(field_key) = field.get_field_key() else { + return; + }; + let LuaIndexKey::Expr(_) = &field_key else { + return; + }; + let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + if analyzer + .db + .get_member_index() + .get_member(&member_id) + .is_some() + { + return; + } + let cache = analyzer + .context + .infer_manager + .get_infer_cache(analyzer.file_id); + let Ok(member_key) = LuaMemberKey::from_index_key(analyzer.db, cache, &field_key) else { + return; + }; + if matches!(member_key, LuaMemberKey::ExprType(ref typ) if typ.is_unknown()) { + return; + } + let Some(table_expr) = field.get_parent::() else { + return; + }; + let owner_id = LuaMemberOwner::Element(InFiled::new(analyzer.file_id, table_expr.get_range())); + let decl_feature = if analyzer.context.metas.contains(&analyzer.file_id) { + LuaMemberFeature::MetaDefine + } else { + LuaMemberFeature::FileDefine + }; + let member = LuaMember::new(member_id, member_key, decl_feature, None); + analyzer + .db + .get_member_index_mut() + .add_member(owner_id, member); +} + pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> Option<()> { + register_expr_key_member(analyzer, &field); + if field.is_assign_field() { let value_expr = field.get_value_expr()?; let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); diff --git a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs index 9aebaf481..3cc8c5200 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/for_range_var_infer_test.rs @@ -1,5 +1,6 @@ #[cfg(test)] mod test { + use std::collections::HashSet; use std::sync::Arc; use crate::{LuaType, LuaUnionType, VirtualWorkspace}; @@ -104,6 +105,69 @@ mod test { assert_eq!(b, LuaType::Integer); } + #[test] + fn test_enum_key_pairs() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + --- @enum Severity + local severity = { + ERROR = 1, + WARN = 2, + INFO = 3, + HINT = 4, + } + + local severities = { + [severity.ERROR] = 1, + [severity.WARN] = 2, + [severity.INFO] = 3, + [severity.HINT] = 4, + } + + for k in pairs(severities) do + key = k + end + "#, + ); + + let key_ty = ws.expr_ty("key"); + let LuaType::Union(union) = key_ty else { + panic!("expected enum key union, got {:?}", key_ty); + }; + let set = union.into_set(); + let expected: HashSet<_> = vec![ + LuaType::IntegerConst(1), + LuaType::IntegerConst(2), + LuaType::IntegerConst(3), + LuaType::IntegerConst(4), + ] + .into_iter() + .collect(); + assert_eq!(set, expected); + } + + #[test] + fn test_pairs_expr_key_type() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local key = tostring(1) + local t = { + [key] = 1, + } + + for k in pairs(t) do + key_out = k + end + "#, + ); + + assert_eq!(ws.expr_ty("key_out"), LuaType::String); + } + #[test] fn test_issue_291() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 0474e4630..a6b22b828 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -126,4 +126,45 @@ mod test { let ty = ws.expr_ty("A"); assert_eq!(ws.humanize_type(ty), "(number|string)"); } + + #[test] + fn test_table_expr_key_string() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local key = tostring(1) + local t = { [key] = 1 } + value = t[key] + "#, + ); + + let value_ty = ws.expr_ty("value"); + assert!( + matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)), + "expected integer type, got {:?}", + value_ty + ); + } + + #[test] + fn test_table_expr_key_doc_const() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@type 'field' + local key = "field" + local t = { [key] = 1 } + value = t[key] + "#, + ); + + let value_ty = ws.expr_ty("value"); + assert!( + matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)), + "expected integer type, got {:?}", + value_ty + ); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 99c092f9e..9b2ea76cc 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -468,6 +468,7 @@ fn table_generic_tpl_pattern_member_owner_match( let key_type = match k { LuaMemberKey::Integer(i) => LuaType::IntegerConst(i), LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), + LuaMemberKey::ExprType(typ) => typ, _ => continue, };