diff --git a/ls/src/features/completion.rs b/ls/src/features/completion.rs index 39b8a9bd4..6ed11abd1 100644 --- a/ls/src/features/completion.rs +++ b/ls/src/features/completion.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{ CompletionContext, CompletionItem, CompletionItemKind, CompletionItemLabelDetails, CompletionTriggerKind, InsertTextFormat, @@ -75,7 +77,7 @@ const CONDITION_SUGGESTIONS: [(&str, Option<&str>); 16] = [ ]; pub fn completion( - document: &Document, + document: Arc, pos: Position, context: Option, ) -> Option> { diff --git a/ls/src/features/diagnostics.rs b/ls/src/features/diagnostics.rs index ce23cced2..38f7d485a 100644 --- a/ls/src/features/diagnostics.rs +++ b/ls/src/features/diagnostics.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{ Diagnostic, DiagnosticRelatedInformation, Location, Range, }; @@ -23,7 +25,7 @@ pub struct Patch { /// Returns a diagnostic vector for the given source code. #[allow(unused_variables)] -pub fn diagnostics(document: &Document) -> Vec { +pub fn diagnostics(document: Arc) -> Vec { #[allow(unused_mut)] let mut diagnostics: Vec = Vec::new(); @@ -40,7 +42,7 @@ pub fn diagnostics(document: &Document) -> Vec { /// comprehensive feedback including type checking, semantic analysis, /// and pattern validation - not just syntax errors. #[cfg(feature = "full-compiler")] -pub fn compiler_diagnostics(document: &Document) -> Vec { +pub fn compiler_diagnostics(document: Arc) -> Vec { let source_code = SourceCode::from(document.text.as_str()) .with_origin(document.uri.clone()); diff --git a/ls/src/features/document_highlight.rs b/ls/src/features/document_highlight.rs index 8585b73ec..d8a91e3b7 100644 --- a/ls/src/features/document_highlight.rs +++ b/ls/src/features/document_highlight.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{ DocumentHighlight, DocumentHighlightKind, Position, }; @@ -17,7 +19,7 @@ use crate::utils::position::{node_to_range, token_to_range}; /// specified position is contained in a symbol, the response contains the /// ranges of all occurrences of that symbol in the source code. pub fn document_highlight( - document: &Document, + document: Arc, pos: Position, ) -> Option> { let cst = &document.cst; diff --git a/ls/src/features/document_symbol.rs b/ls/src/features/document_symbol.rs index 5f4b9f1f8..78e937d87 100644 --- a/ls/src/features/document_symbol.rs +++ b/ls/src/features/document_symbol.rs @@ -1,10 +1,13 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{DocumentSymbol, SymbolKind}; use yara_x_parser::ast::{Item, WithSpan, AST}; use crate::document::Document; -pub fn document_symbol(document: &Document, ast: AST) -> Vec { +pub fn document_symbol(document: Arc) -> Vec { let line_index = &document.line_index; + let ast = AST::new(document.text.as_bytes(), document.cst.iter()); let mut symbols = Vec::new(); for item in ast.items { match item { diff --git a/ls/src/features/goto.rs b/ls/src/features/goto.rs index cd87f6828..539037ee9 100644 --- a/ls/src/features/goto.rs +++ b/ls/src/features/goto.rs @@ -1,4 +1,5 @@ use std::path::PathBuf; +use std::sync::Arc; use async_lsp::lsp_types::{Location, Position, Url}; use yara_x_parser::cst::SyntaxKind; @@ -13,7 +14,7 @@ use crate::utils::position::node_to_range; /// Given a position that points some identifier, returns the range /// of source code that contains the definition of that identifier. pub fn go_to_definition( - document: &Document, + document: Arc, pos: Position, ) -> Option { let token = ident_at_position(&document.cst, pos)?; @@ -38,7 +39,7 @@ pub fn go_to_definition( } fn go_to_rule_definition( - document: &Document, + document: Arc, ident: &str, ) -> Option { // Check if the rule is defined in the current document @@ -82,9 +83,9 @@ fn go_to_rule_definition( }; let uri = Url::from_file_path(abs_included_path).ok()?; - let document = Document::read(uri).ok()?; + let document = Arc::new(Document::read(uri).ok()?); - if let Some(location) = go_to_rule_definition(&document, ident) { + if let Some(location) = go_to_rule_definition(document, ident) { return Some(location); } } diff --git a/ls/src/features/hover.rs b/ls/src/features/hover.rs index a4ad74de1..a3eff42a9 100644 --- a/ls/src/features/hover.rs +++ b/ls/src/features/hover.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{ HoverContents, MarkupContent, MarkupKind, Position, }; @@ -71,7 +73,7 @@ impl RuleHoverBuilder { } } -pub fn hover(document: &Document, pos: Position) -> Option { +pub fn hover(document: Arc, pos: Position) -> Option { // Find the token at the position where the user is hovering. let token = token_at_position(&document.cst, pos)?; diff --git a/ls/src/features/references.rs b/ls/src/features/references.rs index a4e25c2a8..6101aa9ef 100644 --- a/ls/src/features/references.rs +++ b/ls/src/features/references.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{Position, Range}; use yara_x_parser::cst::SyntaxKind; @@ -11,7 +13,7 @@ use crate::utils::position::{node_to_range, token_to_range}; /// Finds all references of a symbol at the given position in the text. pub fn find_references( - document: &Document, + document: Arc, pos: Position, ) -> Option> { let cst = &document.cst; diff --git a/ls/src/features/selection_range.rs b/ls/src/features/selection_range.rs index aa48c4ace..3b7bfed2e 100644 --- a/ls/src/features/selection_range.rs +++ b/ls/src/features/selection_range.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_lsp::lsp_types::{Position, SelectionRange}; use yara_x_parser::cst::{Immutable, Node}; @@ -9,7 +11,7 @@ use crate::utils::position::{node_to_range, token_to_range}; /// Provides selection ranges from the given positions in the text /// based on the given CST of this document. pub fn selection_range( - document: &Document, + document: Arc, positions: Vec, ) -> Option> { let mut result: Vec = Vec::new(); diff --git a/ls/src/features/semantic_tokens.rs b/ls/src/features/semantic_tokens.rs index 9b11d33b8..040c1622a 100644 --- a/ls/src/features/semantic_tokens.rs +++ b/ls/src/features/semantic_tokens.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::sync::Arc; use async_lsp::lsp_types; use async_lsp::lsp_types::{ @@ -178,7 +179,7 @@ struct SemanticTokensIter { } impl SemanticTokensIter { - fn new(document: &Document, range: Option) -> Self { + fn new(document: Arc, range: Option) -> Self { let first_token = if let Some(range) = range { document.cst.root().token_at_position::(( range.start.line as usize, @@ -330,7 +331,7 @@ impl Iterator for SemanticTokensIter { /// An optional range can be specified, in which case only the tokens in that /// range will be returned. pub fn semantic_tokens( - document: &Document, + document: Arc, range: Option, ) -> SemanticTokens { let tokens = SemanticTokensIter::new(document, range); diff --git a/ls/src/server.rs b/ls/src/server.rs index fbf6afc58..ebb397e84 100644 --- a/ls/src/server.rs +++ b/ls/src/server.rs @@ -9,6 +9,7 @@ defines how the server should process various LSP requests and notifications. use std::collections::HashMap; use std::io::Cursor; use std::ops::ControlFlow; +use std::sync::Arc; use async_lsp::lsp_types::request::{ Request, SemanticTokensFullRequest, SemanticTokensRangeRequest, @@ -39,7 +40,6 @@ use async_lsp::{ClientSocket, LanguageClient, LanguageServer, ResponseError}; use futures::future::BoxFuture; use yara_x_fmt::Indentation; -use yara_x_parser::ast::AST; use crate::features::code_action::code_actions; use crate::features::completion::completion; @@ -58,7 +58,7 @@ use crate::features::semantic_tokens::{ use crate::document::Document; pub struct DocumentStore { - documents: HashMap, + documents: HashMap>, } impl DocumentStore { @@ -66,15 +66,15 @@ impl DocumentStore { Self { documents: HashMap::new() } } - fn get(&self, url: &Url) -> Option<&Document> { - self.documents.get(url) + fn get(&self, url: &Url) -> Option> { + self.documents.get(url).cloned() } - fn insert(&mut self, url: Url, document: Document) { + fn insert(&mut self, url: Url, document: Arc) { self.documents.insert(url, document); } - fn remove(&mut self, url: &Url) -> Option { + fn remove(&mut self, url: &Url) -> Option> { self.documents.remove(url) } } @@ -201,10 +201,10 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let result = hover(document, position) - .map(|contents| Hover { contents, range: None }); - - Box::pin(async move { Ok(result) }) + Box::pin(async move { + Ok(hover(document, position) + .map(|contents| Hover { contents, range: None })) + }) } /// This method is called when the user requests to go to the definition @@ -224,10 +224,10 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let definition = go_to_definition(document, position) - .map(GotoDefinitionResponse::Scalar); - - Box::pin(async move { Ok(definition) }) + Box::pin(async move { + Ok(go_to_definition(document, position) + .map(GotoDefinitionResponse::Scalar)) + }) } /// This method is called when the user requests to find all references @@ -246,15 +246,14 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let references = match find_references(document, position) { - Some(references) => references - .into_iter() - .map(|range| Location { uri: uri.clone(), range }) - .collect(), - None => return Box::pin(async { Ok(None) }), - }; - - Box::pin(async move { Ok(Some(references)) }) + Box::pin(async move { + Ok(find_references(document, position).map(|references| { + references + .into_iter() + .map(|range| Location { uri: uri.clone(), range }) + .collect() + })) + }) } /// This method is called when the user requests code actions for a range. @@ -295,10 +294,10 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let completions = completion(document, position, context) - .map(CompletionResponse::Array); - - Box::pin(async move { Ok(completions) }) + Box::pin(async move { + Ok(completion(document, position, context) + .map(CompletionResponse::Array)) + }) } /// This method is called when the user requests to highlight occurrences @@ -319,9 +318,7 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let highlights = document_highlight(document, position); - - Box::pin(async move { Ok(highlights) }) + Box::pin(async move { Ok(document_highlight(document, position)) }) } /// This method is called when the client requests a list of all symbols @@ -340,12 +337,9 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let ast = AST::new(document.text.as_bytes(), document.cst.iter()); - let symbols = document_symbol(document, ast); - - Box::pin( - async move { Ok(Some(DocumentSymbolResponse::Nested(symbols))) }, - ) + Box::pin(async move { + Ok(Some(DocumentSymbolResponse::Nested(document_symbol(document)))) + }) } /// This method is called to provide semantic highlighting for the document. @@ -364,9 +358,11 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let tokens = semantic_tokens(document, None); - - Box::pin(async move { Ok(Some(SemanticTokensResult::Tokens(tokens))) }) + Box::pin(async move { + Ok(Some(SemanticTokensResult::Tokens(semantic_tokens( + document, None, + )))) + }) } /// This method is called to provide semantic highlighting for a specific @@ -388,11 +384,12 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let tokens = semantic_tokens(document, Some(range)); - - Box::pin( - async move { Ok(Some(SemanticTokensRangeResult::Tokens(tokens))) }, - ) + Box::pin(async move { + Ok(Some(SemanticTokensRangeResult::Tokens(semantic_tokens( + document, + Some(range), + )))) + }) } /// This method is called when the user wants to rename a symbol. @@ -410,12 +407,14 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let changes = rename(&document.cst, params.new_name, position) - .map(|changes| HashMap::from([(uri, changes)])) - .map(WorkspaceEdit::new) - .unwrap_or_default(); + Box::pin(async move { + let changes = rename(&document.cst, params.new_name, position) + .map(|changes| HashMap::from([(uri, changes)])) + .map(WorkspaceEdit::new) + .unwrap_or_default(); - Box::pin(async move { Ok(Some(changes)) }) + Ok(Some(changes)) + }) } /// This method is called to determine the range of the symbol at the @@ -434,9 +433,9 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let ranges = selection_range(document, params.positions); - - Box::pin(async move { Ok(ranges) }) + Box::pin( + async move { Ok(selection_range(document, params.positions)) }, + ) } /// This method is called to provide diagnostic information for a document. @@ -485,28 +484,34 @@ impl LanguageServer for YARALanguageServer { None => return Box::pin(async { Ok(None) }), }; - let src = document.text.as_str(); - let line_count = src.lines().count() as u32; - let input = Cursor::new(src); - let mut output = Vec::new(); - - let indentation = if params.options.insert_spaces { - Indentation::Spaces(params.options.tab_size as usize) - } else { - Indentation::Tabs - }; - - let formatter = yara_x_fmt::Formatter::new().indentation(indentation); - - let result = match formatter.format(input, &mut output) { - Ok(changed) if changed => Some(vec![TextEdit::new( - Range::new(Position::new(0, 0), Position::new(line_count, 0)), - output.try_into().unwrap(), - )]), - _ => None, - }; + Box::pin(async move { + let src = document.text.as_str(); + let line_count = src.lines().count() as u32; + let input = Cursor::new(src); + let mut output = Vec::new(); + + let indentation = if params.options.insert_spaces { + Indentation::Spaces(params.options.tab_size as usize) + } else { + Indentation::Tabs + }; + + let formatter = + yara_x_fmt::Formatter::new().indentation(indentation); + + let result = match formatter.format(input, &mut output) { + Ok(changed) if changed => Some(vec![TextEdit::new( + Range::new( + Position::new(0, 0), + Position::new(line_count, 0), + ), + output.try_into().unwrap(), + )]), + _ => None, + }; - Box::pin(async move { Ok(result) }) + Ok(result) + }) } /// This method is called when a document is opened. @@ -519,7 +524,7 @@ impl LanguageServer for YARALanguageServer { ) -> Self::NotifyResult { let uri = params.text_document.uri; let text = params.text_document.text; - let document = Document::new(uri.clone(), text); + let document = Arc::new(Document::new(uri.clone(), text)); self.documents.insert(uri.clone(), document); self.publish_diagnostics(&uri); ControlFlow::Continue(()) @@ -535,7 +540,7 @@ impl LanguageServer for YARALanguageServer { ) -> Self::NotifyResult { if let Some(text) = params.text { let uri = params.text_document.uri; - let document = Document::new(uri.clone(), text); + let document = Arc::new(Document::new(uri.clone(), text)); self.documents.insert(uri.clone(), document); self.publish_diagnostics(&uri); } @@ -553,7 +558,7 @@ impl LanguageServer for YARALanguageServer { ) -> Self::NotifyResult { let uri = params.text_document.uri; for change in params.content_changes.into_iter() { - let document = Document::new(uri.clone(), change.text); + let document = Arc::new(Document::new(uri.clone(), change.text)); self.documents.insert(uri.clone(), document); } self.publish_diagnostics(&uri); @@ -606,7 +611,7 @@ impl YARALanguageServer { let _ = self.client.publish_diagnostics( PublishDiagnosticsParams { uri: uri.clone(), - diagnostics: diagnostics(document), + diagnostics: diagnostics(document.clone()), version: None, }, ); diff --git a/parser/src/cst/mod.rs b/parser/src/cst/mod.rs index 0df8a6679..c447e7414 100644 --- a/parser/src/cst/mod.rs +++ b/parser/src/cst/mod.rs @@ -248,13 +248,13 @@ impl rowan::Language for YARA { /// NOTE: This API is still unstable and should not be used by third-party code. #[doc(hidden)] pub struct CST { - tree: rowan::SyntaxNode, + root: rowan::GreenNode, errors: Vec<(Span, String)>, } impl Debug for CST { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:#?}", self.tree)?; + write!(f, "{:#?}", self.root())?; if !self.errors.is_empty() { writeln!(f, "\nERRORS:")?; for (span, err) in &self.errors { @@ -271,13 +271,14 @@ impl CST { /// The node is initially immutable, but it can be converted into a mutable /// one by calling [`Node::into_mut`]. pub fn root(&self) -> Node { - Node::new(self.tree.clone()) + Node::new(rowan::SyntaxNode::new_root(self.root.clone())) } /// Returns the parsed source code as an iterator of [`Event`]. pub fn iter(&self) -> impl Iterator + '_ { CSTIter { - iter: self.tree.preorder_with_tokens(), + iter: rowan::SyntaxNode::new_root(self.root.clone()) + .preorder_with_tokens(), errors: self.errors.iter().cloned(), } } @@ -342,10 +343,7 @@ where } } - Ok(Self { - tree: rowan::SyntaxNode::new_root(builder.finish()), - errors, - }) + Ok(Self { root: builder.finish(), errors }) } } @@ -932,12 +930,18 @@ impl From> for rowan::SyntaxElement { /// /// NOTE: This API is still unstable and should not be used by third-party code. #[doc(hidden)] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq)] pub struct Node { inner: rowan::SyntaxNode, _mutability: PhantomData, } +impl Debug for Node { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", self.inner) + } +} + impl Node { fn new(inner: rowan::SyntaxNode) -> Self { Self { inner, _mutability: PhantomData }