Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use emmylua_parser::{
BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr,
LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableField, LuaVarExpr, PathTrait,
BinaryOperator, LuaAssignStat, LuaAstNode, LuaExpr, LuaFuncStat, LuaIndexExpr, LuaIndexKey,
LuaLocalFuncStat, LuaLocalStat, LuaNameExpr, LuaTableExpr, LuaTableField, LuaVarExpr,
PathTrait,
};

use crate::{
InFiled, InferFailReason, LuaSemanticDeclId, LuaTypeCache, LuaTypeOwner,
InFiled, InferFailReason, LuaMemberKey, LuaSemanticDeclId, LuaTypeCache, LuaTypeOwner,
compilation::analyzer::{
common::{add_member, bind_type},
unresolve::{UnResolveDecl, UnResolveMember},
},
db_index::{LuaDeclId, LuaMemberId, LuaMemberOwner, LuaType},
db_index::{LuaDeclId, LuaMember, LuaMemberFeature, LuaMemberId, LuaMemberOwner, LuaType},
};

use super::LuaAnalyzer;
Expand Down Expand Up @@ -449,7 +450,53 @@ pub fn analyze_local_func_stat(
Some(())
}

fn register_expr_key_member(analyzer: &mut LuaAnalyzer, field: &LuaTableField) {
// Register expression-key members early so table-decl inference (and pairs)
// can see them even when the table itself has no explicit generic type.
let Some(field_key) = field.get_field_key() else {
return;
};
let LuaIndexKey::Expr(_) = &field_key else {
return;
};
let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id);
if analyzer
.db
.get_member_index()
.get_member(&member_id)
.is_some()
{
return;
}
let cache = analyzer
.context
.infer_manager
.get_infer_cache(analyzer.file_id);
let Ok(member_key) = LuaMemberKey::from_index_key(analyzer.db, cache, &field_key) else {
return;
};
if matches!(member_key, LuaMemberKey::ExprType(ref typ) if typ.is_unknown()) {
return;
}
let Some(table_expr) = field.get_parent::<LuaTableExpr>() else {
return;
};
let owner_id = LuaMemberOwner::Element(InFiled::new(analyzer.file_id, table_expr.get_range()));
let decl_feature = if analyzer.context.metas.contains(&analyzer.file_id) {
LuaMemberFeature::MetaDefine
} else {
LuaMemberFeature::FileDefine
};
let member = LuaMember::new(member_id, member_key, decl_feature, None);
analyzer
.db
.get_member_index_mut()
.add_member(owner_id, member);
}

pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> Option<()> {
register_expr_key_member(analyzer, &field);

if field.is_assign_field() {
let value_expr = field.get_value_expr()?;
let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::sync::Arc;

use crate::{LuaType, LuaUnionType, VirtualWorkspace};
Expand Down Expand Up @@ -104,6 +105,69 @@ mod test {
assert_eq!(b, LuaType::Integer);
}

#[test]
fn test_enum_key_pairs() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
--- @enum Severity
local severity = {
ERROR = 1,
WARN = 2,
INFO = 3,
HINT = 4,
}

local severities = {
[severity.ERROR] = 1,
[severity.WARN] = 2,
[severity.INFO] = 3,
[severity.HINT] = 4,
}

for k in pairs(severities) do
key = k
end
"#,
);

let key_ty = ws.expr_ty("key");
let LuaType::Union(union) = key_ty else {
panic!("expected enum key union, got {:?}", key_ty);
};
let set = union.into_set();
let expected: HashSet<_> = vec![
LuaType::IntegerConst(1),
LuaType::IntegerConst(2),
LuaType::IntegerConst(3),
LuaType::IntegerConst(4),
]
.into_iter()
.collect();
assert_eq!(set, expected);
}

#[test]
fn test_pairs_expr_key_type() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
local key = tostring(1)
local t = {
[key] = 1,
}

for k in pairs(t) do
key_out = k
end
"#,
);

assert_eq!(ws.expr_ty("key_out"), LuaType::String);
}

#[test]
fn test_issue_291() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,45 @@ mod test {
let ty = ws.expr_ty("A");
assert_eq!(ws.humanize_type(ty), "(number|string)");
}

#[test]
fn test_table_expr_key_string() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
local key = tostring(1)
local t = { [key] = 1 }
value = t[key]
"#,
);

let value_ty = ws.expr_ty("value");
assert!(
matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)),
"expected integer type, got {:?}",
value_ty
);
}

#[test]
fn test_table_expr_key_doc_const() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@type 'field'
local key = "field"
local t = { [key] = 1 }
value = t[key]
"#,
);

let value_ty = ws.expr_ty("value");
assert!(
matches!(value_ty, LuaType::Integer | LuaType::IntegerConst(_)),
"expected integer type, got {:?}",
value_ty
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ fn table_generic_tpl_pattern_member_owner_match(
let key_type = match k {
LuaMemberKey::Integer(i) => LuaType::IntegerConst(i),
LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()),
LuaMemberKey::ExprType(typ) => typ,
_ => continue,
};

Expand Down